diff --git a/.gitignore b/.gitignore index 5112ca62ed..09398e2bf3 100644 --- a/.gitignore +++ b/.gitignore @@ -43,3 +43,7 @@ # IDEs *.idea *.vscode + + +build/* +cmake-build* diff --git a/examples/device/CMakeLists.txt b/examples/device/CMakeLists.txt index 14da71efae..febe78c57a 100644 --- a/examples/device/CMakeLists.txt +++ b/examples/device/CMakeLists.txt @@ -25,7 +25,7 @@ if(TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP) - foreach(_exec device_task ta_dense_device ta_cc_abcd_device ta_vector_device ta_reduce_device) + foreach(_exec device_task ta_dense_device ta_cc_abcd_device ta_vector_device ta_reduce_device ta_dense_um_tensor ta_vector_um_tensor) # Add executable add_ta_executable(${_exec} "${_exec}.cpp" "tiledarray") diff --git a/examples/device/ta_dense_um_tensor.cpp b/examples/device/ta_dense_um_tensor.cpp new file mode 100644 index 0000000000..6b94e315f3 --- /dev/null +++ b/examples/device/ta_dense_um_tensor.cpp @@ -0,0 +1,203 @@ +/* + * This file is a part of TiledArray. + * Copyright (C) 2026 Virginia Tech + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + * Ajay Melekamburath + * Department of Chemistry, Virginia Tech + */ + +// Dense matrix-multiply benchmark using the native UMTensor tile type +// (TA::Tensor backed by device_um_allocator). 
Companion to the btas-based +// ta_dense_device.cpp; same shape + reporting, but the tile is bare +// `UMTensor` -- no `TA::Tile<>` wrapper -- and the data flows through +// the device tile-op overloads in src/TiledArray/device/tensor.h. +// +// Usage: +// ta_dense_um_tensor Nm Bm Nn Bn Nk Bk [nrepeat=5] +// +// Computes c(Nm,Nn) = a(Nm,Nk) * b(Nk,Nn) with each dimension blocked by +// Bm/Bn/Bk. Default scalar type is double; nrepeat iterations are timed +// for an average GFLOPS reading. + +#include +#include + +#ifdef TILEDARRAY_HAS_CUDA +#include +#endif + +#include +#include +#include + +namespace { + +template +void run(TiledArray::World &world, long Nm, long Bm, long Nn, long Bn, long Nk, + long Bk, long nrepeat) { + using TA::DistArray; + using TA::TiledRange; + using TA::TiledRange1; + using TA::UMTensor; + using TileT = UMTensor; + using ArrayT = DistArray; + + constexpr bool complex_T = TA::detail::is_complex_v; + // GEMM flops: 2 * M * N * K (8 * for complex). + const std::int64_t nflops = + (complex_T ? 
8 : 2) * static_cast(Nm) * + static_cast(Nn) * static_cast(Nk); + + auto blocking = [](long N, long B) { + std::vector v; + for (long i = 0; i <= N; i += B) v.push_back(static_cast(i)); + return v; + }; + auto blk_m = blocking(Nm, Bm); + auto blk_n = blocking(Nn, Bn); + auto blk_k = blocking(Nk, Bk); + + TiledRange trange_a({TiledRange1(blk_m.begin(), blk_m.end()), + TiledRange1(blk_k.begin(), blk_k.end())}); + TiledRange trange_b({TiledRange1(blk_k.begin(), blk_k.end()), + TiledRange1(blk_n.begin(), blk_n.end())}); + TiledRange trange_c({TiledRange1(blk_m.begin(), blk_m.end()), + TiledRange1(blk_n.begin(), blk_n.end())}); + + if (world.rank() == 0) + std::cout << "TiledArray UMTensor dense matrix multiply\n" + << " Nodes = " << world.size() << "\n" + << " A = " << Nm << " x " << Nk << " (" + << double(Nm * Nk * sizeof(T)) / 1.0e9 << " GB)\n" + << " B = " << Nk << " x " << Nn << " (" + << double(Nk * Nn * sizeof(T)) / 1.0e9 << " GB)\n" + << " C = " << Nm << " x " << Nn << " (" + << double(Nm * Nn * sizeof(T)) / 1.0e9 << " GB)\n" + << " Tile A,B,C = " << Bm << "x" << Bk << ", " << Bk << "x" + << Bn << ", " << Bm << "x" << Bn << "\n" + << " Iterations = " << nrepeat << "\n"; + + ArrayT a(world, trange_a); + ArrayT b(world, trange_b); + ArrayT c(world, trange_c); + + const T val_a = T(0.03); + const T val_b = T(0.02); + a.fill(val_a); + b.fill(val_b); + world.gop.fence(); + + // Prefetch inputs to the device once before the timed loop -- the per-tile + // ops will also prefetch lazily, but doing it up front keeps the timing + // focused on the GEMM kernel cost. 
+ TA::to_device(a); + TA::to_device(b); + +#ifdef TILEDARRAY_HAS_CUDA + cudaProfilerStart(); +#endif + + double total_time = 0.0; + double total_gflops = 0.0; + for (long i = 0; i < nrepeat; ++i) { + const double t0 = madness::wall_time(); + c("m,n") = a("m,k") * b("k,n"); + world.gop.fence(); + const double t1 = madness::wall_time(); + const double dt = t1 - t0; + const double gflops = static_cast(nflops) / (dt * 1.0e9); + total_time += dt; + total_gflops += gflops; + if (world.rank() == 0) + std::cout << " iter " << (i + 1) << " time=" << dt + << " s gflops=" << gflops << "\n"; + } + +#ifdef TILEDARRAY_HAS_CUDA + cudaProfilerStop(); +#endif + + if (world.rank() == 0) + std::cout << " Average time = " << (total_time / double(nrepeat)) + << " s\n Average gflops = " << (total_gflops / double(nrepeat)) + << "\n"; + + // Verify: every result element should be Nk * val_a * val_b. + const T expected = T(Nk) * val_a * val_b; + const auto eps = std::numeric_limits>::epsilon(); + const auto tolerance = std::abs(expected) * static_cast(Nk) * + static_cast(8) * eps; + TA::to_host(c); + bool ok = true; + for (auto it = c.begin(); it != c.end(); ++it) { + const auto tile = it->get(); + for (std::size_t k = 0; k < tile.size(); ++k) { + if (std::abs(tile.data()[k] - expected) > tolerance) { + ok = false; + if (world.rank() == 0) + std::cout << " MISMATCH at tile " << it.index() << " element " << k + << ": got " << tile.data()[k] << " expected " << expected + << "\n"; + break; + } + } + if (!ok) break; + } + if (world.rank() == 0) + std::cout << (ok ? 
" Verification PASSED\n" : " Verification FAILED\n"); +} + +} // namespace + +int try_main(int argc, char **argv) { + TiledArray::World &world = TA_SCOPED_INITIALIZE(argc, argv); + + if (argc < 7) { + if (world.rank() == 0) + std::cerr + << "Usage: " << argv[0] << " Nm Bm Nn Bn Nk Bk [nrepeat=5]\n" + << " Computes c(Nm,Nn) = a(Nm,Nk) * b(Nk,Nn) with UMTensor tiles\n"; + return 1; + } + const long Nm = std::atol(argv[1]); + const long Bm = std::atol(argv[2]); + const long Nn = std::atol(argv[3]); + const long Bn = std::atol(argv[4]); + const long Nk = std::atol(argv[5]); + const long Bk = std::atol(argv[6]); + const long nrepeat = (argc >= 8 ? std::atol(argv[7]) : 5); + if (Nm <= 0 || Nn <= 0 || Nk <= 0 || Bm <= 0 || Bn <= 0 || Bk <= 0 || + nrepeat <= 0) { + if (world.rank() == 0) + std::cerr << "All sizes / blocks / nrepeat must be positive\n"; + return 1; + } + + run(world, Nm, Bm, Nn, Bn, Nk, Bk, nrepeat); + return 0; +} + +int main(int argc, char **argv) { + try { + return try_main(argc, argv); + } catch (const std::exception &e) { + std::cerr << "exception: " << e.what() << "\n"; + return 1; + } catch (...) { + std::cerr << "unknown exception\n"; + return 1; + } +} diff --git a/examples/device/ta_vector_um_tensor.cpp b/examples/device/ta_vector_um_tensor.cpp new file mode 100644 index 0000000000..603a3bd15f --- /dev/null +++ b/examples/device/ta_vector_um_tensor.cpp @@ -0,0 +1,156 @@ +/* + * This file is a part of TiledArray. + * Copyright (C) 2026 Virginia Tech + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + * Ajay Melekamburath + * Department of Chemistry, Virginia Tech + */ + +// Element-wise vector-op benchmarks (add, scale, permute, Hadamard) using +// the native UMTensor tile type. Companion to ta_vector_device.cpp. +// +// Usage: +// ta_vector_um_tensor Nm Bm Nn Bn [nrepeat=5] +// +// Times each op for nrepeat iterations and reports the average wall time +// and effective bandwidth (counting one read + one write per element for +// in-place ops, two reads + one write for binary ops). + +#include +#include + +#include +#include +#include + +namespace { + +template +void run(TiledArray::World &world, long Nm, long Bm, long Nn, long Bn, + long nrepeat) { + using TA::DistArray; + using TA::TiledRange; + using TA::TiledRange1; + using TA::UMTensor; + using TileT = UMTensor; + using ArrayT = DistArray; + + auto blocking = [](long N, long B) { + std::vector v; + for (long i = 0; i <= N; i += B) v.push_back(static_cast(i)); + return v; + }; + auto blk_m = blocking(Nm, Bm); + auto blk_n = blocking(Nn, Bn); + + TiledRange trange({TiledRange1(blk_m.begin(), blk_m.end()), + TiledRange1(blk_n.begin(), blk_n.end())}); + TiledRange trange_T({TiledRange1(blk_n.begin(), blk_n.end()), + TiledRange1(blk_m.begin(), blk_m.end())}); + + if (world.rank() == 0) + std::cout << "TiledArray UMTensor vector-op benchmark\n" + << " Nodes = " << world.size() << "\n" + << " Matrix = " << Nm << " x " << Nn << " (" + << double(Nm * Nn * sizeof(T)) / 1.0e9 << " GB)\n" + << " Tile = " << Bm << " x " << Bn << "\n" + << " Iterations = " << nrepeat << "\n"; + + ArrayT a(world, trange); + ArrayT b(world, trange); + ArrayT c(world, trange); + ArrayT t(world, trange_T); // transposed-shape result for permute test + + a.fill(T(0.03)); + b.fill(T(0.02)); + c.fill(T(0.0)); + t.fill(T(0.0)); + world.gop.fence(); + TA::to_device(a); 
+ TA::to_device(b); + + const double bytes_per_elem = static_cast(sizeof(T)); + const double n_elems = static_cast(Nm) * static_cast(Nn); + + auto bench = [&](const char *name, double bytes_per_iter, auto &&op) { + double total_time = 0.0; + for (long i = 0; i < nrepeat; ++i) { + const double t0 = madness::wall_time(); + op(); + world.gop.fence(); + const double t1 = madness::wall_time(); + total_time += t1 - t0; + } + const double avg = total_time / static_cast(nrepeat); + const double bw_gbs = bytes_per_iter / (avg * 1.0e9); + if (world.rank() == 0) + std::cout << " " << name << ": avg=" << avg << " s bw=" << bw_gbs + << " GB/s\n"; + }; + + // Binary read-read-write: 3 element accesses per element. + const double rw3_bytes = 3.0 * n_elems * bytes_per_elem; + // Unary read-write: 2 element accesses per element. + const double rw2_bytes = 2.0 * n_elems * bytes_per_elem; + + bench("add(c=a+b)", rw3_bytes, [&] { c("m,n") = a("m,n") + b("m,n"); }); + bench("subt(c=a-b)", rw3_bytes, [&] { c("m,n") = a("m,n") - b("m,n"); }); + bench("scale(c=2*a)", rw2_bytes, [&] { c("m,n") = 2.0 * a("m,n"); }); + bench("hadamard(c=a*b)", rw3_bytes, [&] { c("m,n") = a("m,n") * b("m,n"); }); + bench("permute(t=a^T)", rw2_bytes, [&] { t("n,m") = a("m,n"); }); + bench("axpy(c+=a)", rw3_bytes, [&] { c("m,n") += a("m,n"); }); + + world.gop.fence(); +} + +} // namespace + +int try_main(int argc, char **argv) { + TiledArray::World &world = TA_SCOPED_INITIALIZE(argc, argv); + + if (argc < 5) { + if (world.rank() == 0) + std::cerr + << "Usage: " << argv[0] << " Nm Bm Nn Bn [nrepeat=5]\n" + << " Times element-wise vector ops on Nm x Nn UMTensor matrices\n"; + return 1; + } + const long Nm = std::atol(argv[1]); + const long Bm = std::atol(argv[2]); + const long Nn = std::atol(argv[3]); + const long Bn = std::atol(argv[4]); + const long nrepeat = (argc >= 6 ? 
std::atol(argv[5]) : 5); + if (Nm <= 0 || Nn <= 0 || Bm <= 0 || Bn <= 0 || nrepeat <= 0) { + if (world.rank() == 0) + std::cerr << "All sizes / blocks / nrepeat must be positive\n"; + return 1; + } + + run(world, Nm, Bm, Nn, Bn, nrepeat); + return 0; +} + +int main(int argc, char **argv) { + try { + return try_main(argc, argv); + } catch (const std::exception &e) { + std::cerr << "exception: " << e.what() << "\n"; + return 1; + } catch (...) { + std::cerr << "unknown exception\n"; + return 1; + } +} diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e0edad8c7e..28c1eb0bf1 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -234,6 +234,7 @@ if(TILEDARRAY_HAS_HIP OR TILEDARRAY_HAS_CUDA) TiledArray/device/blas.h TiledArray/device/btas.h TiledArray/device/btas_um_tensor.h + TiledArray/device/tensor.h TiledArray/device/device_task_fn.h TiledArray/device/kernel/mult_kernel.h TiledArray/device/kernel/reduce_kernel.h @@ -267,6 +268,7 @@ if(TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP) set(TILEDARRAY_DEVICE_SOURCE_FILES TiledArray/device/btas_um_tensor.cpp + TiledArray/device/tensor.cpp ) if(TILEDARRAY_HAS_CUDA) diff --git a/src/TiledArray/device/tensor.cpp b/src/TiledArray/device/tensor.cpp new file mode 100644 index 0000000000..94dc4d3cc0 --- /dev/null +++ b/src/TiledArray/device/tensor.cpp @@ -0,0 +1,42 @@ +/* + * This file is a part of TiledArray. + * Copyright (C) 2026 Virginia Tech + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. 
+ * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + * Ajay Melekamburath + * Department of Chemistry, Virginia Tech + */ + +#include + +#include + +namespace TiledArray { + +// Explicit instantiations of the UMTensor class for the standard numeric +// types. Without these, every TU including device/tensor.h would instantiate +// the full TA::Tensor> class body. +// Mirrors the host-side set in tensor/tensor.cpp; paired with the +// `extern template` declarations in device/tensor.h. +template class Tensor>; +template class Tensor>; +template class Tensor, + device_um_allocator>>; +template class Tensor, + device_um_allocator>>; +template class Tensor>; +template class Tensor>; + +} // namespace TiledArray diff --git a/src/TiledArray/device/tensor.h b/src/TiledArray/device/tensor.h new file mode 100644 index 0000000000..4488734191 --- /dev/null +++ b/src/TiledArray/device/tensor.h @@ -0,0 +1,946 @@ +/* + * This file is a part of TiledArray. + * Copyright (C) 2026 Virginia Tech + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . 
+ * + * Ajay Melekamburath + * Department of Chemistry, Virginia Tech + */ + +#ifndef TILEDARRAY_DEVICE_TENSOR_H +#define TILEDARRAY_DEVICE_TENSOR_H + +#include + +#ifdef TILEDARRAY_HAS_DEVICE + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace TiledArray { +namespace detail { + +/// UMTensor lives in unified memory; it is identified as a device_tile and +/// the expression engine must route its tile ops through +/// madness::add_device_task. +template +struct is_device_tile> + : public std::bool_constant> {}; + +/// Prefetch a UMTensor's storage to the device associated with its tile range. +template + requires TiledArray::detail::is_numeric_v +inline void to_device(const TiledArray::UMTensor& tile) { + if (tile.empty()) return; + auto stream = device::stream_for(tile.range()); + if (deviceEnv::instance()->concurrent_managed_access()) { + DeviceSafeCall(device::memPrefetchAsync(tile.data(), + tile.total_size() * sizeof(T), + stream.device, stream.stream)); + } +} + +/// Prefetch a UMTensor's storage back to the host. +template + requires TiledArray::detail::is_numeric_v +inline void to_host(const TiledArray::UMTensor& tile) { + if (tile.empty()) return; + auto stream = device::stream_for(tile.range()); + if (deviceEnv::instance()->concurrent_managed_access()) { + DeviceSafeCall( + device::memPrefetchAsync(tile.data(), tile.total_size() * sizeof(T), + device::CpuDeviceId, stream.stream)); + } +} + +} // namespace detail + +// clang-format off +/// Tile-op overloads for UMTensor. +/// +/// Each overload sits in `namespace TiledArray` so ADL finds it from the +/// expression engine and from the tile_op layer's free-function defaults. 
+/// More-specialized concrete-type overloads win against the generic +/// forwarder in `tile_op/tile_interface.h`: +/// \code +/// template +/// auto add(Left&& left, Right&& right) { +/// return left.add(right); +/// } +/// \endcode +/// so we never fall back to the CPU member functions for UMTensor. +/// +/// All overloads follow the stream/queue contract: +/// 1. Resolve a queue via `blasqueue_for(range)`. Inside a device task +/// this is the same queue everyone else in the task uses (see +/// `external/device.h:899-907`); outside one, it round-robins. +/// 2. Prefetch every input + the result to the device. +/// 3. Call into BLAS++ / device kernels on that queue. +/// 4. `sync_madness_task_with(stream)` so the enclosing MADNESS device +/// task waits for the queue to drain before completing. +/// +/// In-place ops provide both an lvalue and an rvalue overload: the lvalue +/// overload does the work, the rvalue overload forwards to it. +/// +/// nbatch_ > 1 is not yet supported; the host-side tile +/// ops don't support them either. +// clang-format on + +/// result[i] = arg[i] +template + requires TiledArray::detail::is_numeric_v +inline UMTensor clone(const UMTensor& arg) { + TA_ASSERT(!arg.empty()); + TA_ASSERT(arg.nbatch() == 1); + + auto& queue = blasqueue_for(arg.range()); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + UMTensor result(arg.range()); + + detail::to_device(arg); + detail::to_device(result); + + blas::copy(result.size(), arg.data(), 1, result.data(), 1, queue); + + device::sync_madness_task_with(stream); + return result; +} + +namespace detail { + +/// Apply a scaling factor in-place on the device, replicating the +/// ComplexConjugate handling from device/btas.h::scale. Real-valued kernels +/// reduce to a single `blas::scal`; conjugation+scale on complex tiles +/// requires a custom kernel that we have not implemented yet. 
+template +inline void apply_scale_factor(T* data, std::size_t n, const Scalar factor, + ::blas::Queue& queue) { + if constexpr (TiledArray::detail::is_blas_numeric_v || + std::is_arithmetic_v) { + ::blas::scal(n, factor, data, 1, queue); + } else if constexpr (TiledArray::detail::is_complex_v) { + TA_EXCEPTION( + "UMTensor scale with ComplexConjugate factor on complex T is not " + "implemented (requires a fused conjugation kernel)"); + } else if constexpr (std::is_same_v< + Scalar, + TiledArray::detail::ComplexConjugate>) { + // conjugation on a real tensor is a no-op + } else if constexpr (std::is_same_v>) { + ::blas::scal(n, static_cast(-1), data, 1, queue); + } +} + +} // namespace detail + +/// result[i] = arg[i] * factor +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +inline UMTensor scale(const UMTensor& arg, const Scalar factor) { + auto result = clone(arg); + auto& queue = blasqueue_for(result.range()); + const device::Stream stream(queue.device(), queue.stream()); + detail::apply_scale_factor(result.data(), result.size(), factor, queue); + device::sync_madness_task_with(stream); + return result; +} + +/// result[i] *= factor (in-place) +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +inline UMTensor& scale_to(UMTensor& result, const Scalar factor) { + TA_ASSERT(!result.empty()); + TA_ASSERT(result.nbatch() == 1); + auto& queue = blasqueue_for(result.range()); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + detail::to_device(result); + detail::apply_scale_factor(result.data(), result.size(), factor, queue); + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +inline UMTensor& scale_to(UMTensor&& result, const Scalar factor) { + return scale_to(result, factor); +} + +/// result[i] = -arg[i] +template 
+ requires TiledArray::detail::is_numeric_v +inline UMTensor neg(const UMTensor& arg) { + return scale(arg, T(-1)); +} + +/// arg[i] = -arg[i] (in-place) +template + requires TiledArray::detail::is_numeric_v +inline UMTensor& neg_to(UMTensor& arg) { + return scale_to(arg, T(-1)); +} + +template + requires TiledArray::detail::is_numeric_v +inline UMTensor& neg_to(UMTensor&& arg) { + return neg_to(arg); +} + +/// result[i] = arg1[i] + arg2[i] +template + requires TiledArray::detail::is_numeric_v +inline UMTensor add(const UMTensor& arg1, const UMTensor& arg2) { + TA_ASSERT(!arg1.empty()); + TA_ASSERT(!arg2.empty()); + TA_ASSERT(arg1.nbatch() == 1 && arg2.nbatch() == 1); + + auto& queue = blasqueue_for(arg1.range()); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + UMTensor result(arg1.range()); + + detail::to_device(arg1); + detail::to_device(arg2); + detail::to_device(result); + + ::blas::copy(result.size(), arg1.data(), 1, result.data(), 1, queue); + ::blas::axpy(result.size(), T(1), arg2.data(), 1, result.data(), 1, queue); + + device::sync_madness_task_with(stream); + return result; +} + +/// result[i] = (arg1[i] + arg2[i]) * factor +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +inline UMTensor add(const UMTensor& arg1, const UMTensor& arg2, + const Scalar factor) { + auto result = add(arg1, arg2); + return scale_to(result, factor); +} + +/// result[i] += arg[i] +template + requires TiledArray::detail::is_numeric_v +inline UMTensor& add_to(UMTensor& result, const UMTensor& arg) { + TA_ASSERT(!result.empty()); + TA_ASSERT(!arg.empty()); + TA_ASSERT(result.nbatch() == 1 && arg.nbatch() == 1); + + auto& queue = blasqueue_for(result.range()); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + detail::to_device(result); + detail::to_device(arg); + + ::blas::axpy(result.size(), T(1), 
arg.data(), 1, result.data(), 1, queue); + + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v +inline UMTensor& add_to(UMTensor&& result, const UMTensor& arg) { + return add_to(result, arg); +} + +/// result[i] = (result[i] + arg[i]) * factor +/// Matches TA::Tensor::add_to(right, factor) semantics: `(l += r) *= factor`. +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +inline UMTensor& add_to(UMTensor& result, const UMTensor& arg, + const Scalar factor) { + add_to(result, arg); + return scale_to(result, factor); +} + +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +inline UMTensor& add_to(UMTensor&& result, const UMTensor& arg, + const Scalar factor) { + return add_to(result, arg, factor); +} + +/// result[i] = arg1[i] - arg2[i] +template + requires TiledArray::detail::is_numeric_v +inline UMTensor subt(const UMTensor& arg1, const UMTensor& arg2) { + TA_ASSERT(!arg1.empty()); + TA_ASSERT(!arg2.empty()); + TA_ASSERT(arg1.nbatch() == 1 && arg2.nbatch() == 1); + + auto& queue = blasqueue_for(arg1.range()); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + UMTensor result(arg1.range()); + + detail::to_device(arg1); + detail::to_device(arg2); + detail::to_device(result); + + ::blas::copy(result.size(), arg1.data(), 1, result.data(), 1, queue); + ::blas::axpy(result.size(), T(-1), arg2.data(), 1, result.data(), 1, queue); + + device::sync_madness_task_with(stream); + return result; +} + +/// result[i] = (arg1[i] - arg2[i]) * factor +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +inline UMTensor subt(const UMTensor& arg1, const UMTensor& arg2, + const Scalar factor) { + auto result = subt(arg1, arg2); + return scale_to(result, factor); +} + +/// result[i] -= arg[i] +template + requires 
TiledArray::detail::is_numeric_v +inline UMTensor& subt_to(UMTensor& result, const UMTensor& arg) { + TA_ASSERT(!result.empty()); + TA_ASSERT(!arg.empty()); + TA_ASSERT(result.nbatch() == 1 && arg.nbatch() == 1); + + auto& queue = blasqueue_for(result.range()); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + detail::to_device(result); + detail::to_device(arg); + + ::blas::axpy(result.size(), T(-1), arg.data(), 1, result.data(), 1, queue); + + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v +inline UMTensor& subt_to(UMTensor&& result, const UMTensor& arg) { + return subt_to(result, arg); +} + +/// result[i] = (result[i] - arg[i]) * factor +/// Matches TA::Tensor::subt_to(right, factor) semantics: `(l -= r) *= factor`. +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +inline UMTensor& subt_to(UMTensor& result, const UMTensor& arg, + const Scalar factor) { + subt_to(result, arg); + return scale_to(result, factor); +} + +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +inline UMTensor& subt_to(UMTensor&& result, const UMTensor& arg, + const Scalar factor) { + return subt_to(result, arg, factor); +} + +/// dot product: scalar = sum_i arg1[i] * arg2[i] +template + requires TiledArray::detail::is_numeric_v +inline T dot(const UMTensor& arg1, const UMTensor& arg2) { + TA_ASSERT(!arg1.empty()); + TA_ASSERT(!arg2.empty()); + TA_ASSERT(arg1.nbatch() == 1 && arg2.nbatch() == 1); + TA_ASSERT(arg1.size() == arg2.size()); + + auto& queue = blasqueue_for(arg1.range()); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + detail::to_device(arg1); + detail::to_device(arg2); + + T result(0); + ::blas::dot(arg1.size(), arg1.data(), 1, arg2.data(), 1, &result, queue); + + 
device::sync_madness_task_with(stream); + return result; +} + +/// scalar = sum_i arg[i] * arg[i] +template + requires TiledArray::detail::is_numeric_v +inline auto squared_norm(const UMTensor& arg) { + return dot(arg, arg); +} + +/// scalar = sqrt(squared_norm(arg)) +template + requires TiledArray::detail::is_numeric_v +inline auto norm(const UMTensor& arg) { + using std::sqrt; + using ResultType = TiledArray::detail::scalar_t; + return static_cast(sqrt(squared_norm(arg))); +} + +/// result[perm(i)] = arg[i] +template + requires TiledArray::detail::is_numeric_v +inline UMTensor permute(const UMTensor& arg, + const TiledArray::Permutation& perm) { + TA_ASSERT(!arg.empty()); + TA_ASSERT(arg.nbatch() == 1); + TA_ASSERT(perm.size() == arg.range().rank()); + + auto result_range = perm * arg.range(); + auto& queue = blasqueue_for(result_range); + const device::Stream stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + UMTensor result(result_range); + + detail::to_device(arg); + detail::to_device(result); + + // librett operates on the original (unpermuted) range and writes into the + // permuted layout; pointers go in as-is. + librett_permute(const_cast(arg.data()), result.data(), arg.range(), perm, + stream.stream); + + device::sync_madness_task_with(stream); + return result; +} + +/// BipartitePermutation -> plain Permutation forward. 
+template <typename T>
+  requires TiledArray::detail::is_numeric_v<T>
+inline UMTensor<T> permute(const UMTensor<T>& arg,
+                           const TiledArray::BipartitePermutation& perm) {
+  // UMTensor is a non-nested tile, so the inner part of the bipartite
+  // permutation must be empty; only the outer part is applied.
+  TA_ASSERT(inner_size(perm) == 0);
+  return permute(arg, outer(perm));
+}
+
+/// result[perm(i)] = arg[i] * factor
+/// Implemented as the non-permuting `scale` followed by `permute`.
+template <typename T, typename Scalar, typename Perm>
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_numeric_v<Scalar> &&
+           TiledArray::detail::is_permutation_v<Perm>
+inline UMTensor<T> scale(const UMTensor<T>& arg, const Scalar factor,
+                         const Perm& perm) {
+  auto scaled = scale(arg, factor);
+  return permute(scaled, perm);
+}
+
+/// result[perm(i)] = -arg[i]
+/// Implemented as the non-permuting `neg` followed by `permute`.
+template <typename T, typename Perm>
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_permutation_v<Perm>
+inline UMTensor<T> neg(const UMTensor<T>& arg, const Perm& perm) {
+  return permute(neg(arg), perm);
+}
+
+/// result[perm(i)] = arg1[i] + arg2[i]
+/// Implemented as the non-permuting `add` followed by `permute`.
+template <typename T, typename Perm>
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_permutation_v<Perm>
+inline UMTensor<T> add(const UMTensor<T>& arg1, const UMTensor<T>& arg2,
+                       const Perm& perm) {
+  return permute(add(arg1, arg2), perm);
+}
+
+/// result[perm(i)] = (arg1[i] + arg2[i]) * factor
+/// Implemented as the scaled, non-permuting `add` followed by `permute`.
+template <typename T, typename Scalar, typename Perm>
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_numeric_v<Scalar> &&
+           TiledArray::detail::is_permutation_v<Perm>
+inline UMTensor<T> add(const UMTensor<T>& arg1, const UMTensor<T>& arg2,
+                       const Scalar factor, const Perm& perm) {
+  return permute(add(arg1, arg2, factor), perm);
+}
+
+/// result[perm(i)] = arg1[i] - arg2[i]
+/// Implemented as the non-permuting `subt` followed by `permute`.
+template <typename T, typename Perm>
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_permutation_v<Perm>
+inline UMTensor<T> subt(const UMTensor<T>& arg1, const UMTensor<T>& arg2,
+                        const Perm& perm) {
+  return permute(subt(arg1, arg2), perm);
+}
+
+/// result[perm(i)] = (arg1[i] - arg2[i]) * factor
+/// Implemented as the scaled, non-permuting `subt` followed by `permute`.
+template <typename T, typename Scalar, typename Perm>
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_numeric_v<Scalar> &&
+           TiledArray::detail::is_permutation_v<Perm>
+inline UMTensor<T> subt(const UMTensor<T>& arg1, const UMTensor<T>& arg2,
+                        const Scalar factor, const Perm& perm) {
+  return permute(subt(arg1, arg2, factor), perm);
+}
+
+/// shift: result has arg's data, range shifted by bound_shift.
+/// A new tile is allocated over the shifted range and filled with a BLAS
+/// copy issued on the queue selected for that range. `arg` must be
+/// non-empty and have a single batch.
+template <typename T, typename Index>
+  requires TiledArray::detail::is_numeric_v<T>
+inline UMTensor<T> shift(const UMTensor<T>& arg, const Index& bound_shift) {
+  TA_ASSERT(!arg.empty());
+  TA_ASSERT(arg.nbatch() == 1);
+
+  TiledArray::Range result_range(arg.range());
+  result_range.inplace_shift(bound_shift);
+
+  // The queue (and therefore the device and stream) is chosen from the
+  // result range; the device is made current before allocating the result.
+  auto& queue = blasqueue_for(result_range);
+  const device::Stream stream(queue.device(), queue.stream());
+  DeviceSafeCall(device::setDevice(stream.device));
+
+  UMTensor<T> result(result_range);
+
+  detail::to_device(arg);
+  detail::to_device(result);
+
+  ::blas::copy(result.size(), arg.data(), 1, result.data(), 1, queue);
+
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+/// shift_to: in-place range shift, no data movement.
+/// Forwards to the tile's own `shift_to` member and returns `arg`.
+template <typename T, typename Index>
+  requires TiledArray::detail::is_numeric_v<T>
+inline UMTensor<T>& shift_to(UMTensor<T>& arg, const Index& bound_shift) {
+  return arg.shift_to(bound_shift);
+}
+
+/// Rvalue overload of shift_to; forwards to the lvalue overload. The
+/// returned reference refers to the caller's temporary.
+template <typename T, typename Index>
+  requires TiledArray::detail::is_numeric_v<T>
+inline UMTensor<T>& shift_to(UMTensor<T>&& arg, const Index& bound_shift) {
+  return shift_to(arg, bound_shift);
+}
+
+/// result[i] = arg1[i] * arg2[i]  (element-wise / Hadamard)
+/// Both arguments must be non-empty, single-batch and of equal size. The
+/// result takes `arg1`'s range.
+template <typename T>
+  requires TiledArray::detail::is_numeric_v<T>
+inline UMTensor<T> mult(const UMTensor<T>& arg1, const UMTensor<T>& arg2) {
+  TA_ASSERT(!arg1.empty());
+  TA_ASSERT(!arg2.empty());
+  TA_ASSERT(arg1.size() == arg2.size());
+  TA_ASSERT(arg1.nbatch() == 1 && arg2.nbatch() == 1);
+
+  auto& queue = blasqueue_for(arg1.range());
+  const device::Stream stream(queue.device(), queue.stream());
+  DeviceSafeCall(device::setDevice(stream.device));
+
+  UMTensor<T> result(arg1.range());
+
+  detail::to_device(arg1);
+  detail::to_device(arg2);
+  detail::to_device(result);
+
+  device::mult_kernel(result.data(), arg1.data(), arg2.data(), arg1.size(),
+                      stream);
+
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+/// result[i] = arg1[i] * arg2[i] * factor
+/// Computes the element-wise product, then scales it in place.
+template <typename T, typename Scalar>
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_numeric_v<Scalar>
+inline UMTensor<T> mult(const UMTensor<T>& arg1, const UMTensor<T>& arg2,
+                        const Scalar factor) {
+  auto result = mult(arg1, arg2);
+  return scale_to(result, factor);
+}
+
+/// result[perm(i)] = arg1[i] * arg2[i]
+/// Implemented as the non-permuting `mult` followed by `permute`.
+template <typename T, typename Perm>
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_permutation_v<Perm>
+inline UMTensor<T> mult(const UMTensor<T>& arg1, const UMTensor<T>& arg2,
+                        const Perm& perm) {
+  return permute(mult(arg1, arg2), perm);
+}
+
+/// result[perm(i)] = arg1[i] * arg2[i] * factor
+/// Implemented as the scaled, non-permuting `mult` followed by `permute`.
+template <typename T, typename Scalar, typename Perm>
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_numeric_v<Scalar> &&
+           TiledArray::detail::is_permutation_v<Perm>
+inline UMTensor<T> mult(const UMTensor<T>& arg1, const UMTensor<T>& arg2,
+                        const Scalar factor, const Perm& perm) {
+  return permute(mult(arg1, arg2, factor), perm);
+}
+
+/// result[i] *= arg[i]
+/// Both tiles must be non-empty, single-batch and of equal size. Returns
+/// `result`.
+template <typename T>
+  requires TiledArray::detail::is_numeric_v<T>
+inline UMTensor<T>& mult_to(UMTensor<T>& result, const UMTensor<T>& arg) {
+  TA_ASSERT(!result.empty());
+  TA_ASSERT(!arg.empty());
+  TA_ASSERT(result.size() == arg.size());
+  TA_ASSERT(result.nbatch() == 1 && arg.nbatch() == 1);
+
+  auto& queue = blasqueue_for(result.range());
+  const device::Stream stream(queue.device(), queue.stream());
+  DeviceSafeCall(device::setDevice(stream.device));
+
+  detail::to_device(result);
+  detail::to_device(arg);
+
+  device::mult_to_kernel(result.data(), arg.data(), result.size(), stream);
+
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+/// Rvalue overload of mult_to; forwards to the lvalue overload.
+template <typename T>
+  requires TiledArray::detail::is_numeric_v<T>
+inline UMTensor<T>& mult_to(UMTensor<T>&& result, const UMTensor<T>& arg) {
+  return mult_to(result, arg);
+}
+
+/// result[i] = (result[i] * arg[i]) * factor
+/// Matches TA::Tensor::mult_to(right, factor) semantics: `(l *= r) *= factor`.
+template <typename T, typename Scalar>
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_numeric_v<Scalar>
+inline UMTensor<T>& mult_to(UMTensor<T>& result, const UMTensor<T>& arg,
+                            const Scalar factor) {
+  // Element-wise multiply first, then scale the product in place.
+  mult_to(result, arg);
+  return scale_to(result, factor);
+}
+
+/// Rvalue overload of the scaled mult_to; forwards to the lvalue overload.
+template <typename T, typename Scalar>
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_numeric_v<Scalar>
+inline UMTensor<T>& mult_to(UMTensor<T>&& result, const UMTensor<T>& arg,
+                            const Scalar factor) {
+  return mult_to(result, arg, factor);
+}
+
+/// gemm: result = factor * left * right
+/// Allocates the result over the range produced by `gemm_helper` and issues
+/// one BLAS gemm (beta = 0) on the queue selected for that range. Both
+/// operands must be non-empty, single-batch, of the ranks `gemm_helper`
+/// expects, and congruent along the contracted dimensions.
+template <typename T, typename Scalar>
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_numeric_v<Scalar>
+inline UMTensor<T> gemm(const UMTensor<T>& left, const UMTensor<T>& right,
+                        const Scalar factor,
+                        const TiledArray::math::GemmHelper& gemm_helper) {
+  TA_ASSERT(!left.empty());
+  TA_ASSERT(!right.empty());
+  TA_ASSERT(left.range().rank() == gemm_helper.left_rank());
+  TA_ASSERT(right.range().rank() == gemm_helper.right_rank());
+  TA_ASSERT(left.nbatch() == 1 && right.nbatch() == 1);
+
+  // Extents must always agree; bounds are only compared when tile
+  // positions are not being ignored.
+  TA_ASSERT(gemm_helper.left_right_congruent(left.range().extent_data(),
+                                             right.range().extent_data()));
+  TA_ASSERT(ignore_tile_position() ||
+            gemm_helper.left_right_congruent(left.range().lobound_data(),
+                                             right.range().lobound_data()));
+  TA_ASSERT(ignore_tile_position() ||
+            gemm_helper.left_right_congruent(left.range().upbound_data(),
+                                             right.range().upbound_data()));
+
+  auto result_range = gemm_helper.template make_result_range<TiledArray::Range>(
+      left.range(), right.range());
+
+  auto& queue = blasqueue_for(result_range);
+  const device::Stream stream(queue.device(), queue.stream());
+  DeviceSafeCall(device::setDevice(stream.device));
+
+  UMTensor<T> result(result_range);
+  TA_ASSERT(result.nbatch() == 1);
+
+  detail::to_device(left);
+  detail::to_device(right);
+  detail::to_device(result);
+
+  using TiledArray::math::blas::integer;
+  integer m = 1, n = 1, k = 1;
+  gemm_helper.compute_matrix_sizes(m, n, k, left.range(), right.range());
+
+  // Leading dimensions of the row-major operands, clamped to at least 1.
+  const integer lda = std::max(
+      integer{1},
+      gemm_helper.left_op() == TiledArray::math::blas::Op::NoTrans ? k : m);
+  const integer ldb = std::max(
+      integer{1},
+      gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k);
+  const integer ldc = std::max(integer{1}, n);
+
+  const T factor_t = T(factor);
+  const T zero(0);
+
+  // Match btas device gemm (device/btas.h): col-major view with right/left
+  // swapped reproduces TA::Tensor's row-major layout under cublas.
+  ::blas::gemm(::blas::Layout::ColMajor, gemm_helper.right_op(),
+               gemm_helper.left_op(), n, m, k, factor_t, right.data(), ldb,
+               left.data(), lda, zero, result.data(), ldc, queue);
+
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+/// gemm: result += factor * left * right
+/// Accumulating variant: `result` must already exist, be single-batch and be
+/// congruent with both operands; the BLAS gemm is issued with beta = 1 on
+/// the queue selected for `result.range()`.
+template <typename T, typename Scalar>
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_numeric_v<Scalar>
+inline void gemm(UMTensor<T>& result, const UMTensor<T>& left,
+                 const UMTensor<T>& right, const Scalar factor,
+                 const TiledArray::math::GemmHelper& gemm_helper) {
+  TA_ASSERT(!result.empty());
+  TA_ASSERT(!left.empty());
+  TA_ASSERT(!right.empty());
+  TA_ASSERT(result.range().rank() == gemm_helper.result_rank());
+  TA_ASSERT(left.range().rank() == gemm_helper.left_rank());
+  TA_ASSERT(right.range().rank() == gemm_helper.right_rank());
+  TA_ASSERT(left.nbatch() == 1 && right.nbatch() == 1 && result.nbatch() == 1);
+
+  // left vs. result
+  TA_ASSERT(gemm_helper.left_result_congruent(left.range().extent_data(),
+                                              result.range().extent_data()));
+  TA_ASSERT(ignore_tile_position() ||
+            gemm_helper.left_result_congruent(left.range().lobound_data(),
+                                              result.range().lobound_data()));
+  TA_ASSERT(ignore_tile_position() ||
+            gemm_helper.left_result_congruent(left.range().upbound_data(),
+                                              result.range().upbound_data()));
+  // right vs. result
+  TA_ASSERT(gemm_helper.right_result_congruent(right.range().extent_data(),
+                                               result.range().extent_data()));
+  TA_ASSERT(ignore_tile_position() ||
+            gemm_helper.right_result_congruent(right.range().lobound_data(),
+                                               result.range().lobound_data()));
+  TA_ASSERT(ignore_tile_position() ||
+            gemm_helper.right_result_congruent(right.range().upbound_data(),
+                                               result.range().upbound_data()));
+  // left vs. right
+  TA_ASSERT(gemm_helper.left_right_congruent(left.range().extent_data(),
+                                             right.range().extent_data()));
+  TA_ASSERT(ignore_tile_position() ||
+            gemm_helper.left_right_congruent(left.range().lobound_data(),
+                                             right.range().lobound_data()));
+  TA_ASSERT(ignore_tile_position() ||
+            gemm_helper.left_right_congruent(left.range().upbound_data(),
+                                             right.range().upbound_data()));
+
+  auto& queue = blasqueue_for(result.range());
+  const device::Stream stream(queue.device(), queue.stream());
+  DeviceSafeCall(device::setDevice(stream.device));
+
+  detail::to_device(left);
+  detail::to_device(right);
+  detail::to_device(result);
+
+  using TiledArray::math::blas::integer;
+  integer m = 1, n = 1, k = 1;
+  gemm_helper.compute_matrix_sizes(m, n, k, left.range(), right.range());
+
+  // Leading dimensions of the row-major operands, clamped to at least 1.
+  const integer lda = std::max(
+      integer{1},
+      gemm_helper.left_op() == TiledArray::math::blas::Op::NoTrans ? k : m);
+  const integer ldb = std::max(
+      integer{1},
+      gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k);
+  const integer ldc = std::max(integer{1}, n);
+
+  const T factor_t = T(factor);
+  const T one(1);
+
+  // Same swapped-operand, column-major call as the allocating overload, but
+  // with beta = 1 so the product is added to the existing result.
+  ::blas::gemm(::blas::Layout::ColMajor, gemm_helper.right_op(),
+               gemm_helper.left_op(), n, m, k, factor_t, right.data(), ldb,
+               left.data(), lda, one, result.data(), ldc, queue);
+
+  device::sync_madness_task_with(stream);
+}
+
+/// Array-level helpers: bulk to-host / to-device prefetch and conversions
+/// between UMTensor-backed and host-Tensor-backed DistArrays.
+
+/// Prefetch every local tile of `array` to the host. Fences on the
+/// containing world and globally synchronizes the device on exit.
+template <typename T, typename Policy>
+  requires TiledArray::detail::is_numeric_v<T>
+inline void to_host(TiledArray::DistArray<UMTensor<T>, Policy>& array) {
+  // Per-tile task body: move one tile to the host on the stream chosen for
+  // its range, then tie the MADNESS task to that stream.
+  auto prefetch = [](UMTensor<T>& tile) {
+    auto stream = device::stream_for(tile.range());
+    detail::to_host(tile);
+    device::sync_madness_task_with(stream);
+  };
+  auto& world = array.world();
+  for (auto it = array.pmap()->begin(); it != array.pmap()->end(); ++it) {
+    // Tiles the array reports as zero are skipped; every other tile in the
+    // process map gets its own task.
+    if (!array.is_zero(*it)) world.taskq.add(prefetch, array.find(*it));
+  }
+  world.gop.fence();
+  DeviceSafeCall(device::deviceSynchronize());
+}
+
+/// Prefetch every local tile of `array` to the device. Fences on the
+/// containing world and globally synchronizes the device on exit.
+template <typename T, typename Policy>
+  requires TiledArray::detail::is_numeric_v<T>
+inline void to_device(TiledArray::DistArray<UMTensor<T>, Policy>& array) {
+  // Per-tile task body: move one tile to the device on the stream chosen
+  // for its range, then tie the MADNESS task to that stream.
+  auto prefetch = [](UMTensor<T>& tile) {
+    auto stream = device::stream_for(tile.range());
+    detail::to_device(tile);
+    device::sync_madness_task_with(stream);
+  };
+  auto& world = array.world();
+  for (auto it = array.pmap()->begin(); it != array.pmap()->end(); ++it) {
+    // Tiles the array reports as zero are skipped; every other tile in the
+    // process map gets its own task.
+    if (!array.is_zero(*it)) world.taskq.add(prefetch, array.find(*it));
+  }
+  world.gop.fence();
+  DeviceSafeCall(device::deviceSynchronize());
+}
+
+/// Convert a UMTensor-backed `DistArray` to one backed by a host tile type.
+// NOTE(review): template parameter order (UMTile, HostTile, Policy) is
+// reconstructed from the signature -- confirm against callers that pass
+// explicit template arguments.
+template <typename UMTile, typename HostTile, typename Policy>
+inline std::enable_if_t<!std::is_same_v<UMTile, HostTile>,
+                        TiledArray::DistArray<HostTile, Policy>>
+um_tensor_to_ta_tensor(const TiledArray::DistArray<UMTile, Policy>& um_array) {
+  auto convert_tile = [](const UMTile& tile) {
+    // The host tile below is sized from tile.range() alone, while the copy
+    // moves tile.total_size() elements; require a single batch so the copy
+    // cannot run past the destination buffer.
+    TA_ASSERT(tile.nbatch() == 1);
+    auto stream = device::stream_for(tile.range());
+    detail::to_host(tile);
+    device::sync_madness_task_with(stream);
+    HostTile result(tile.range());
+    std::copy_n(tile.data(), tile.total_size(), result.data());
+    return result;
+  };
+  auto out = to_new_tile_type(um_array, convert_tile);
+  um_array.world().gop.fence();
+  return out;
+}
+
+/// Same-tile-type overload: nothing to convert, the input array is returned.
+template <typename UMTile, typename HostTile, typename Policy>
+inline std::enable_if_t<std::is_same_v<UMTile, HostTile>,
+                        TiledArray::DistArray<UMTile, Policy>>
+um_tensor_to_ta_tensor(const TiledArray::DistArray<UMTile, Policy>& um_array) {
+  return um_array;
+}
+
+/// Convert a host-tile-backed `DistArray` to a UMTensor-backed one.
+/// Each tile is copied element-wise into a freshly allocated UM tile, which
+/// is then handed to `detail::to_device`; the world is fenced before return.
+template <typename UMTile, typename HostTile, typename Policy>
+inline std::enable_if_t<!std::is_same_v<UMTile, HostTile>,
+                        TiledArray::DistArray<UMTile, Policy>>
+ta_tensor_to_um_tensor(
+    const TiledArray::DistArray<HostTile, Policy>& host_array) {
+  auto convert_tile = [](const HostTile& tile) {
+    // The UM tile is sized from tile.range() alone, while the copy moves
+    // tile.total_size() elements; require a single batch so the copy cannot
+    // run past the destination buffer.
+    TA_ASSERT(tile.nbatch() == 1);
+    UMTile result(tile.range());
+    std::copy_n(tile.data(), tile.total_size(), result.data());
+    detail::to_device(result);
+    return result;
+  };
+  auto out = to_new_tile_type(host_array, convert_tile);
+  host_array.world().gop.fence();
+  return out;
+}
+
+/// Same-tile-type overload: nothing to convert, the input array is returned.
+template <typename UMTile, typename HostTile, typename Policy>
+inline std::enable_if_t<std::is_same_v<UMTile, HostTile>,
+                        TiledArray::DistArray<UMTile, Policy>>
+ta_tensor_to_um_tensor(
+    const TiledArray::DistArray<HostTile, Policy>& host_array) {
+  return host_array;
+}
+
+}  // namespace TiledArray
+
+/// MADNESS archive specializations for UMTensor.
+///
+/// `TA::Tensor::serialize(ar)` works on any allocator (the member just walks
+/// `data() + range().volume() * nbatch()`), but UM data may be stale on the
+/// host if a device kernel is in flight. The Store specialization prefetches
+/// the tile back to the host before reading.
+namespace madness {
+namespace archive {
+
+/// Store: writes an `empty` flag and, for non-empty tiles, the range, the
+/// batch count and `range().volume() * nbatch()` elements.
+template <class Archive, typename T>
+struct ArchiveStoreImpl<Archive, TiledArray::UMTensor<T>> {
+  static inline void store(const Archive& ar,
+                           const TiledArray::UMTensor<T>& t) {
+    // For numeric element types, bring a non-empty tile back to the host
+    // and sync with its stream before the archive reads the buffer.
+    if constexpr (TiledArray::detail::is_numeric_v<T>) {
+      if (!t.empty()) {
+        auto stream = TiledArray::device::stream_for(t.range());
+        TiledArray::detail::to_host(t);
+        TiledArray::device::sync_madness_task_with(stream);
+      }
+    }
+    const bool empty = t.empty();
+    ar & empty;
+    if (!empty) {
+      ar & t.range();
+      ar & t.nbatch();
+      ar& madness::archive::wrap(t.data(), t.range().volume() * t.nbatch());
+    }
+  }
+};
+
+/// Load: reads the fields in the order Store wrote them. A non-empty record
+/// replaces `t` with a tile of the stored range and batch count before the
+/// elements are read into it; an empty record resets `t` to a
+/// default-constructed tile.
+template <class Archive, typename T>
+struct ArchiveLoadImpl<Archive, TiledArray::UMTensor<T>> {
+  static inline void load(const Archive& ar, TiledArray::UMTensor<T>& t) {
+    bool empty = false;
+    ar & empty;
+    if (!empty) {
+      TiledArray::Range range;
+      std::size_t nbatch = 1;
+      ar & range;
+      ar & nbatch;
+      t = TiledArray::UMTensor<T>(
+          std::move(range), typename TiledArray::UMTensor<T>::nbatches(nbatch));
+      ar& madness::archive::wrap(t.data(), t.range().volume() * t.nbatch());
+    } else {
+      t = TiledArray::UMTensor<T>();
+    }
+  }
+};
+
+}  // namespace archive
+}  // namespace madness
+
+/// extern template declarations for the UMTensor class.
+namespace TiledArray { + +extern template class Tensor>; +extern template class Tensor>; +extern template class Tensor, + device_um_allocator>>; +extern template class Tensor, + device_um_allocator>>; +extern template class Tensor>; +extern template class Tensor>; + +} // namespace TiledArray + +#endif // TILEDARRAY_HAS_DEVICE + +#endif // TILEDARRAY_DEVICE_TENSOR_H diff --git a/src/TiledArray/fwd.h b/src/TiledArray/fwd.h index 00c36a5092..70aee401bf 100644 --- a/src/TiledArray/fwd.h +++ b/src/TiledArray/fwd.h @@ -142,6 +142,10 @@ template using btasUMTensorVarray = ::btas::Tensor>; +/// TA::Tensor backed by the unified-memory allocator, usable as a device tile +template +using UMTensor = Tensor>; + #endif // TILEDARRAY_HAS_DEVICE template diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 308a3dec1e..9b4a2061ea 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -114,7 +114,7 @@ set(ta_test_src_files ta_test.cpp ) if(TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP) - list(APPEND ta_test_src_files librett.cpp expressions_device_um.cpp tensor_um.cpp) + list(APPEND ta_test_src_files librett.cpp expressions_device_um.cpp tensor_um.cpp expressions_device_tensor.cpp) endif() # if using C++20 must use Boost 1.74 or later: diff --git a/tests/expressions_device_tensor.cpp b/tests/expressions_device_tensor.cpp new file mode 100644 index 0000000000..126e17c3b4 --- /dev/null +++ b/tests/expressions_device_tensor.cpp @@ -0,0 +1,937 @@ +/* + * This file is a part of TiledArray. + * Copyright (C) 2026 Virginia Tech + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + * Ajay Melekamburath + * Department of Chemistry, Virginia Tech + */ + +#include + +#ifdef TILEDARRAY_HAS_DEVICE + +#include +#include +#include + +#include +#include + +#include "unit_test_config.h" + +using namespace TiledArray; + +// Expression-engine tests for the native UMTensor tile type (TA::Tensor +// backed by device_um_allocator). +// +// All correctness checks use a host side TiledArray::Tensor mirror +// of the input arrays. The expression runs through the engine for both sides; +// we then compare elements after. + +struct DeviceTensorExpressionsFixture : public TiledRangeFixture { + using TileD = UMTensor; + using TArrayD = TiledArray::DistArray; + + using HostTile = TiledArray::Tensor; + using HostArray = TiledArray::DistArray; + + static constexpr double tolerance = 5.0e-14; + + DeviceTensorExpressionsFixture() + : a(*GlobalFixture::world, tr), + b(*GlobalFixture::world, tr), + c(*GlobalFixture::world, tr), + a_h(*GlobalFixture::world, tr), + b_h(*GlobalFixture::world, tr), + c_h(*GlobalFixture::world, tr) { + fill_with_seed(a, a_h, 7); + fill_with_seed(b, b_h, 11); + GlobalFixture::world->gop.fence(); + } + + ~DeviceTensorExpressionsFixture() { GlobalFixture::world->gop.fence(); } + + // Fill paired device + host arrays with the same deterministic data so + // the host array is an exact reference for the device expression result. 
+ template + static void fill_with_seed(DeviceArray& d, HostArrayT& h, int seed) { + auto pmap_d = d.pmap(); + for (auto it = pmap_d->begin(); it != pmap_d->end(); ++it) { + const auto tile_range = d.trange().make_tile_range(*it); + const auto vol = tile_range.volume(); + + std::mt19937 rng(seed + static_cast(*it)); + std::uniform_real_distribution dist(-5.0, 5.0); + + typename DeviceArray::value_type d_tile(tile_range); + typename HostArrayT::value_type h_tile(tile_range); + for (std::size_t k = 0; k < vol; ++k) { + const double v = dist(rng); + d_tile.data()[k] = v; + h_tile.data()[k] = v; + } + d.set(*it, d_tile); + h.set(*it, h_tile); + } + } + + // Compare every element of two DistArrays with matching tiles. + template + static void check_close(const DeviceArrayT& d, const HostArrayT& h_ref, + double tol) { + GlobalFixture::world->gop.fence(); + for (auto it = d.begin(); it != d.end(); ++it) { + auto d_tile = it->get(); + auto h_tile = h_ref.find(it.index()).get(); + BOOST_REQUIRE_EQUAL(d_tile.range(), h_tile.range()); + for (std::size_t k = 0; k < d_tile.size(); ++k) { + BOOST_CHECK_CLOSE_FRACTION(d_tile.data()[k], h_tile.data()[k], tol); + } + } + } + + TArrayD a, b, c; + HostArray a_h, b_h, c_h; +}; + +BOOST_FIXTURE_TEST_SUITE(device_tensor_expressions_suite, + DeviceTensorExpressionsFixture) + +BOOST_AUTO_TEST_CASE(is_device_tile_classification) { + using detail::is_device_tile_v; + BOOST_CHECK(is_device_tile_v>); + BOOST_CHECK(is_device_tile_v>); + BOOST_CHECK(is_device_tile_v>>); + BOOST_CHECK(is_device_tile_v>>); + BOOST_CHECK(is_device_tile_v); + BOOST_CHECK(!is_device_tile_v); +} + +BOOST_AUTO_TEST_CASE(direct_assign) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c")); + c_h("a,b,c") = a_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(permute) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("c,b,a")); + c_h("a,b,c") = a_h("c,b,a"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(scale) { + 
BOOST_REQUIRE_NO_THROW(c("a,b,c") = 2.5 * a("a,b,c")); + c_h("a,b,c") = 2.5 * a_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(neg) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = -a("a,b,c")); + c_h("a,b,c") = -a_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(add) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c") + b("a,b,c")); + c_h("a,b,c") = a_h("a,b,c") + b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(add_with_permute) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("c,b,a") + b("a,b,c")); + c_h("a,b,c") = a_h("c,b,a") + b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(add_to) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c")); + c_h("a,b,c") = a_h("a,b,c"); + BOOST_REQUIRE_NO_THROW(c("a,b,c") += b("a,b,c")); + c_h("a,b,c") += b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(subt) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c") - b("a,b,c")); + c_h("a,b,c") = a_h("a,b,c") - b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(subt_to) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c")); + c_h("a,b,c") = a_h("a,b,c"); + BOOST_REQUIRE_NO_THROW(c("a,b,c") -= b("a,b,c")); + c_h("a,b,c") -= b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(scaled_subt_right) { + // Isolate: scale-on-right only. `c = a - 3*b`. + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c") - 3.0 * b("a,b,c")); + c_h("a,b,c") = a_h("a,b,c") - 3.0 * b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(scaled_subt_left) { + // Isolate: scale-on-left only. `c = 2*a - b`. 
+ BOOST_REQUIRE_NO_THROW(c("a,b,c") = 2.0 * a("a,b,c") - b("a,b,c")); + c_h("a,b,c") = 2.0 * a_h("a,b,c") - b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(mixed_linear_combination) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 2.0 * a("a,b,c") - 3.0 * b("a,b,c")); + c_h("a,b,c") = 2.0 * a_h("a,b,c") - 3.0 * b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(hadamard) { + // C(ijk) = A(ijk) .* B(ijk), element-wise multiplication + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c") * b("a,b,c")); + c_h("a,b,c") = a_h("a,b,c") * b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(contraction) { + // C(i,k) = A(i,j) * B(j,k) requires rank-2 arrays; build them on the fly + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + TArrayD c2; + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + HostArray c2_h; + fill_with_seed(a2, a2_h, 13); + fill_with_seed(b2, b2_h, 17); + GlobalFixture::world->gop.fence(); + + BOOST_REQUIRE_NO_THROW(c2("i,k") = a2("i,j") * b2("j,k")); + c2_h("i,k") = a2_h("i,j") * b2_h("j,k"); + // GEMM tolerance: float-add reordering between BLAS and CPU Eigen path. + check_close(c2, c2_h, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(norm2_value) { + const double dev_norm = TA::norm2(a); + const double host_norm = TA::norm2(a_h); + GlobalFixture::world->gop.fence(); + BOOST_CHECK_CLOSE_FRACTION(dev_norm, host_norm, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(dot_value) { + // dot expression: scalar = a . 
b + double dev_dot = static_cast(a("a,b,c") * b("a,b,c")); + double host_dot = static_cast(a_h("a,b,c") * b_h("a,b,c")); + GlobalFixture::world->gop.fence(); + BOOST_CHECK_CLOSE_FRACTION(dev_dot, host_dot, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(reduce_factories) { + // Expression-level reductions through the DSL: scalar = a("...").reduce() + GlobalFixture::world->gop.fence(); + + BOOST_CHECK_CLOSE_FRACTION(a("a,b,c").sum().get(), a_h("a,b,c").sum().get(), + 1.0e-12); + BOOST_CHECK_CLOSE_FRACTION(a("a,b,c").squared_norm().get(), + a_h("a,b,c").squared_norm().get(), 1.0e-12); + BOOST_CHECK_CLOSE_FRACTION(a("a,b,c").norm().get(), a_h("a,b,c").norm().get(), + 1.0e-12); + BOOST_CHECK_EQUAL(a("a,b,c").min().get(), a_h("a,b,c").min().get()); + BOOST_CHECK_EQUAL(a("a,b,c").max().get(), a_h("a,b,c").max().get()); + BOOST_CHECK_EQUAL(a("a,b,c").abs_min().get(), a_h("a,b,c").abs_min().get()); + BOOST_CHECK_EQUAL(a("a,b,c").abs_max().get(), a_h("a,b,c").abs_max().get()); + BOOST_CHECK_NO_THROW(auto _ = a("a,b,c").product().get()); +} + +BOOST_AUTO_TEST_CASE(reuse_stress) { + const double host_ref = static_cast(a_h("a,b,c") * a_h("a,b,c")); + GlobalFixture::world->gop.fence(); + for (int iter = 0; iter < 8; ++iter) { + const double d = static_cast(a("a,b,c") * a("a,b,c")); + GlobalFixture::world->gop.fence(); + BOOST_CHECK_CLOSE_FRACTION(d, host_ref, 1.0e-12); + } +} + +/// In-place expression operators (+=, -=, *=) +BOOST_AUTO_TEST_CASE(plus_equal_expr) { + c("a,b,c") = a("a,b,c"); + c_h("a,b,c") = a_h("a,b,c"); + c("a,b,c") += b("a,b,c"); + c_h("a,b,c") += b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(plus_equal_with_permute) { + c("a,b,c") = a("a,b,c"); + c_h("a,b,c") = a_h("a,b,c"); + c("a,b,c") += b("c,b,a"); + c_h("a,b,c") += b_h("c,b,a"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(minus_equal_expr) { + c("a,b,c") = a("a,b,c"); + c_h("a,b,c") = a_h("a,b,c"); + c("a,b,c") -= b("a,b,c"); + c_h("a,b,c") -= b_h("a,b,c"); + 
check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(times_equal_expr) { + c("a,b,c") = a("a,b,c"); + c_h("a,b,c") = a_h("a,b,c"); + c("a,b,c") *= b("a,b,c"); + c_h("a,b,c") *= b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(neg_scaled_sum) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = -(2.0 * (a("a,b,c") + b("a,b,c")))); + c_h("a,b,c") = -(2.0 * (a_h("a,b,c") + b_h("a,b,c"))); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(neg_permuted) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = -a("c,b,a")); + c_h("a,b,c") = -a_h("c,b,a"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(scale_with_permute) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 2.5 * a("c,b,a")); + c_h("a,b,c") = 2.5 * a_h("c,b,a"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(subt_with_permute) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("c,b,a") - b("a,b,c")); + c_h("a,b,c") = a_h("c,b,a") - b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(mult_with_permute) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("c,b,a") * b("a,b,c")); + c_h("a,b,c") = a_h("c,b,a") * b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +// .conj() on a real-valued tensor is a compile-time no-op in +// `apply_scale_factor`, but the expression DSL still has to parse and route +// through the conjugation factory. This verifies it does. Complex-typed +// fixtures are out of scope here. 
+BOOST_AUTO_TEST_CASE(conj_real) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c").conj()); + c_h("a,b,c") = a_h("a,b,c").conj(); + check_close(c, c_h, tolerance); +} + +/// Multi-step chains +BOOST_AUTO_TEST_CASE(multi_step_chain) { + TArrayD t(*GlobalFixture::world, tr); + HostArray t_h(*GlobalFixture::world, tr); + t("a,b,c") = a("a,b,c") + b("a,b,c"); + t_h("a,b,c") = a_h("a,b,c") + b_h("a,b,c"); + c("a,b,c") = 2.0 * t("a,b,c") - a("a,b,c"); + c_h("a,b,c") = 2.0 * t_h("a,b,c") - a_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +/// Block expressions +BOOST_AUTO_TEST_CASE(block_assign) { + const std::array lo{3, 3, 3}; + const std::array up{5, 5, 5}; + + // Result range matches the block's element range; build small companion + // arrays to receive the result. + const TiledRange ctr{TiledRange1{lo[0], up[0]}, TiledRange1{lo[1], up[1]}, + TiledRange1{lo[2], up[2]}}; + TArrayD blk_d(*GlobalFixture::world, ctr); + HostArray blk_h(*GlobalFixture::world, ctr); + blk_d("a,b,c") = a("a,b,c").block(lo, up); + blk_h("a,b,c") = a_h("a,b,c").block(lo, up); + check_close(blk_d, blk_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(block_add_then_scale) { + const std::array lo{3, 3, 3}; + const std::array up{5, 5, 5}; + const TiledRange ctr{TiledRange1{lo[0], up[0]}, TiledRange1{lo[1], up[1]}, + TiledRange1{lo[2], up[2]}}; + TArrayD blk_d(*GlobalFixture::world, ctr); + HostArray blk_h(*GlobalFixture::world, ctr); + blk_d("a,b,c") = 2.0 * (a("a,b,c").block(lo, up) + b("a,b,c").block(lo, up)); + blk_h("a,b,c") = + 2.0 * (a_h("a,b,c").block(lo, up) + b_h("a,b,c").block(lo, up)); + check_close(blk_d, blk_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(block_accumulate) { + const std::array lo{3, 3, 3}; + const std::array up{5, 5, 5}; + const TiledRange ctr{TiledRange1{lo[0], up[0]}, TiledRange1{lo[1], up[1]}, + TiledRange1{lo[2], up[2]}}; + TArrayD blk_d(*GlobalFixture::world, ctr); + HostArray blk_h(*GlobalFixture::world, ctr); + blk_d("a,b,c") = a("a,b,c").block(lo, up); + 
blk_h("a,b,c") = a_h("a,b,c").block(lo, up); + blk_d("a,b,c") += b("a,b,c").block(lo, up); + blk_h("a,b,c") += b_h("a,b,c").block(lo, up); + check_close(blk_d, blk_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(const_block) { + const auto& ca = a; + const auto& ca_h = a_h; + const std::array lo{3, 3, 3}; + const std::array up{5, 5, 5}; + const TiledRange ctr{TiledRange1{lo[0], up[0]}, TiledRange1{lo[1], up[1]}, + TiledRange1{lo[2], up[2]}}; + TArrayD blk_d(*GlobalFixture::world, ctr); + HostArray blk_h(*GlobalFixture::world, ctr); + blk_d("a,b,c") = ca("a,b,c").block(lo, up); + blk_h("a,b,c") = ca_h("a,b,c").block(lo, up); + check_close(blk_d, blk_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(scal_block) { + const std::array lo{3, 3, 3}; + const std::array up{5, 5, 5}; + const TiledRange ctr{TiledRange1{lo[0], up[0]}, TiledRange1{lo[1], up[1]}, + TiledRange1{lo[2], up[2]}}; + TArrayD blk_d(*GlobalFixture::world, ctr); + HostArray blk_h(*GlobalFixture::world, ctr); + blk_d("a,b,c") = 2.0 * a("a,b,c").block(lo, up); + blk_h("a,b,c") = 2.0 * a_h("a,b,c").block(lo, up); + check_close(blk_d, blk_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(permute_block) { + const std::array lo{3, 3, 3}; + const std::array up{5, 5, 5}; + const TiledRange ctr{TiledRange1{lo[0], up[0]}, TiledRange1{lo[1], up[1]}, + TiledRange1{lo[2], up[2]}}; + TArrayD blk_d(*GlobalFixture::world, ctr); + HostArray blk_h(*GlobalFixture::world, ctr); + // Permute the source annotation before slicing. 
+ blk_d("a,b,c") = a("c,b,a").block(lo, up); + blk_h("a,b,c") = a_h("c,b,a").block(lo, up); + check_close(blk_d, blk_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(assign_sub_block) { + c.fill_local(0.0); + c_h.fill_local(0.0); + GlobalFixture::world->gop.fence(); + + const std::array lo{3, 3, 3}; + const std::array up{5, 5, 5}; + BOOST_REQUIRE_NO_THROW(c("a,b,c").block(lo, up) = a("a,b,c").block(lo, up)); + c_h("a,b,c").block(lo, up) = a_h("a,b,c").block(lo, up); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(block_contract) { + const TiledRange tr_w{tr.data()[0], tr.data()[1]}; + TArrayD w(*GlobalFixture::world, tr_w); + HostArray w_h(*GlobalFixture::world, tr_w); + w.fill_local(0.0); + w_h.fill_local(0.0); + GlobalFixture::world->gop.fence(); + + const std::array alo{3, 2, 3}; + const std::array aup{5, 5, 5}; + const std::array blo{2, 3, 3}; + const std::array bup{5, 5, 5}; + + BOOST_REQUIRE_NO_THROW(w("a,b") = a("a,c,d").block(alo, aup) * + b("c,d,b").block(blo, bup)); + w_h("a,b") = a_h("a,c,d").block(alo, aup) * b_h("c,d,b").block(blo, bup); + check_close(w, w_h, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(block_permute_contract) { + const TiledRange tr_w{tr.data()[0], tr.data()[1]}; + TArrayD w(*GlobalFixture::world, tr_w); + HostArray w_h(*GlobalFixture::world, tr_w); + w.fill_local(0.0); + w_h.fill_local(0.0); + GlobalFixture::world->gop.fence(); + + const std::array alo{3, 3, 2}; + const std::array aup{5, 5, 5}; + const std::array blo{2, 3, 3}; + const std::array bup{5, 5, 5}; + + BOOST_REQUIRE_NO_THROW(w("a,b") = a("a,d,c").block(alo, aup) * + b("c,d,b").block(blo, bup)); + w_h("a,b") = a_h("a,d,c").block(alo, aup) * b_h("c,d,b").block(blo, bup); + check_close(w, w_h, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(outer_product) { + const TiledRange tr_u{tr.data()[0]}; + const TiledRange tr_v{tr.data()[1]}; + const TiledRange tr_w{tr.data()[0], tr.data()[1]}; + + TArrayD u(*GlobalFixture::world, tr_u); + TArrayD v(*GlobalFixture::world, tr_v); + TArrayD w; + 
HostArray u_h(*GlobalFixture::world, tr_u); + HostArray v_h(*GlobalFixture::world, tr_v); + HostArray w_h; + fill_with_seed(u, u_h, 31); + fill_with_seed(v, v_h, 37); + GlobalFixture::world->gop.fence(); + + BOOST_REQUIRE_NO_THROW(w("i,j") = u("i") * v("j")); + w_h("i,j") = u_h("i") * v_h("j"); + check_close(w, w_h, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(contraction_permuted_result) { + // c(k,i) = a(i,j) * b(j,k) + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + TArrayD c2; + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + HostArray c2_h; + fill_with_seed(a2, a2_h, 41); + fill_with_seed(b2, b2_h, 43); + GlobalFixture::world->gop.fence(); + BOOST_REQUIRE_NO_THROW(c2("k,i") = a2("i,j") * b2("j,k")); + c2_h("k,i") = a2_h("i,j") * b2_h("j,k"); + // Looser tolerance for permuted GEMM: BLAS sums in different + // tile-internal order than the CPU reference path. + check_close(c2, c2_h, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(contraction_with_transpose_on_right) { + // c(i,k) = a(i,j) * b(k,j) + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + TArrayD c2; + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + HostArray c2_h; + fill_with_seed(a2, a2_h, 47); + fill_with_seed(b2, b2_h, 53); + GlobalFixture::world->gop.fence(); + BOOST_REQUIRE_NO_THROW(c2("i,k") = a2("i,j") * b2("k,j")); + c2_h("i,k") = a2_h("i,j") * b2_h("k,j"); + check_close(c2, c2_h, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(contraction_rank4_via_two_indices) { + // r(a,c) = t(a,b,k,l) * v(c,b,k,l) + const TiledRange tr4{tr.data()[0], tr.data()[1], tr.data()[2], tr.data()[2]}; + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD t(*GlobalFixture::world, tr4); + TArrayD v(*GlobalFixture::world, tr4); + TArrayD r; + HostArray t_h(*GlobalFixture::world, 
tr4); + HostArray v_h(*GlobalFixture::world, tr4); + HostArray r_h; + fill_with_seed(t, t_h, 59); + fill_with_seed(v, v_h, 61); + GlobalFixture::world->gop.fence(); + BOOST_REQUIRE_NO_THROW(r("a,c") = t("a,b,k,l") * v("c,b,k,l")); + r_h("a,c") = t_h("a,b,k,l") * v_h("c,b,k,l"); + check_close(r, r_h, 1.0e-12); +} + +/// TA::einsum +BOOST_AUTO_TEST_CASE(einsum_matmul) { + // c(i,k) = a(i,j) * b(j,k) via einsum + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + fill_with_seed(a2, a2_h, 67); + fill_with_seed(b2, b2_h, 71); + GlobalFixture::world->gop.fence(); + + auto c2 = TiledArray::einsum(a2("i,j"), b2("j,k"), "i,k"); + auto c2_h = TiledArray::einsum(a2_h("i,j"), b2_h("j,k"), "i,k"); + check_close(c2, c2_h, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(einsum_hadamard) { + // c(i,j) = a(i,j) * b(i,j) via einsum: Hadamard / element-wise multiply + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + fill_with_seed(a2, a2_h, 73); + fill_with_seed(b2, b2_h, 79); + GlobalFixture::world->gop.fence(); + + auto c2 = TiledArray::einsum(a2("i,j"), b2("i,j"), "i,j"); + auto c2_h = TiledArray::einsum(a2_h("i,j"), b2_h("i,j"), "i,j"); + check_close(c2, c2_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(einsum_contraction_over_two_indices) { + // c(a,c) = t(a,b,k) * v(c,b,k) via einsum: contraction over (b, k), + const TiledRange tr3{tr.data()[0], tr.data()[1], tr.data()[2]}; + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD t(*GlobalFixture::world, tr3); + TArrayD v(*GlobalFixture::world, tr3); + HostArray t_h(*GlobalFixture::world, tr3); + HostArray v_h(*GlobalFixture::world, tr3); + fill_with_seed(t, t_h, 83); + fill_with_seed(v, 
v_h, 89); + GlobalFixture::world->gop.fence(); + + auto r = TiledArray::einsum(t("a,b,k"), v("c,b,k"), "a,c"); + auto r_h = TiledArray::einsum(t_h("a,b,k"), v_h("c,b,k"), "a,c"); + check_close(r, r_h, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(rank2_transpose) { + // a2T(j,i) = a2(i,j): pure index permutation, no arithmetic. + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + HostArray a2_h(*GlobalFixture::world, tr2); + fill_with_seed(a2, a2_h, 97); + GlobalFixture::world->gop.fence(); + + TArrayD a2T; + HostArray a2T_h; + a2T("j,i") = a2("i,j"); + a2T_h("j,i") = a2_h("i,j"); + check_close(a2T, a2T_h, tolerance); +} + +/// Scaled and permuted variants of the elementary arithmetic ops. +/// Each test evaluates one expression on the fixture arrays a, b, c and the +/// same expression on a_h, b_h, c_h, then compares c with c_h. +BOOST_AUTO_TEST_CASE(scale_add) { + // c = 5 * (a + b) + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5.0 * (a("a,b,c") + b("a,b,c"))); + c_h("a,b,c") = 5.0 * (a_h("a,b,c") + b_h("a,b,c")); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(scale_add_permute) { + // As parenthesized, the outer 5.0 multiplies only the permuted `a` term: + // c = (5 * 2 * a^T) + (3 * b), where a^T is a read with indices "c,b,a". + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 5.0 * (2.0 * a("c,b,a")) + (3.0 * b("a,b,c"))); + c_h("a,b,c") = 5.0 * (2.0 * a_h("c,b,a")) + (3.0 * b_h("a,b,c")); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(subt_permute) { + // c = a^T - b, with the left operand read as "c,b,a". + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("c,b,a") - b("a,b,c")); + c_h("a,b,c") = a_h("c,b,a") - b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(scale_subt) { + // c = 5 * (a - b) + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5.0 * (a("a,b,c") - b("a,b,c"))); + c_h("a,b,c") = 5.0 * (a_h("a,b,c") - b_h("a,b,c")); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(scale_subt_permute) { + // As parenthesized, the outer 5.0 multiplies only the permuted `a` term: + // c = (5 * 2 * a^T) - (3 * b). + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 5.0 * (2.0 * a("c,b,a")) - (3.0 * b("a,b,c"))); + c_h("a,b,c") = 5.0 * (2.0 * a_h("c,b,a")) - (3.0 * b_h("a,b,c")); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(mult_permute) { + // Hadamard (element-wise) product with the left operand read as "c,b,a". 
+ BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("c,b,a") * b("a,b,c")); + c_h("a,b,c") = a_h("c,b,a") * b_h("a,b,c"); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(scale_mult) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5.0 * (a("a,b,c") * b("a,b,c"))); + c_h("a,b,c") = 5.0 * (a_h("a,b,c") * b_h("a,b,c")); + check_close(c, c_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(scale_mult_permute) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 5.0 * (2.0 * a("c,b,a")) * (3.0 * b("a,b,c"))); + c_h("a,b,c") = 5.0 * (2.0 * a_h("c,b,a")) * (3.0 * b_h("a,b,c")); + check_close(c, c_h, tolerance); +} + +/// Scaled contraction variants: rank-2 operands built on the first two modes +/// of the fixture's tr, contracted over j and scaled by 5. +BOOST_AUTO_TEST_CASE(scale_cont) { + // c(i,k) = 5 * a(i,j) * b(j,k) + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + TArrayD c2; + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + HostArray c2_h; + fill_with_seed(a2, a2_h, 101); + fill_with_seed(b2, b2_h, 103); + GlobalFixture::world->gop.fence(); + + BOOST_REQUIRE_NO_THROW(c2("i,k") = 5.0 * (a2("i,j") * b2("j,k"))); + c2_h("i,k") = 5.0 * (a2_h("i,j") * b2_h("j,k")); + check_close(c2, c2_h, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(scale_cont_permute) { + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + TArrayD c2; + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + HostArray c2_h; + fill_with_seed(a2, a2_h, 107); + fill_with_seed(b2, b2_h, 109); + GlobalFixture::world->gop.fence(); + + // c(k,i) = 5 * a(i,j) * b(j,k): same product as scale_cont, with the + // result written in permuted (k,i) order. + BOOST_REQUIRE_NO_THROW(c2("k,i") = 5.0 * (a2("i,j") * b2("j,k"))); + c2_h("k,i") = 5.0 * (a2_h("i,j") * b2_h("j,k")); + check_close(c2, c2_h, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(scale_cont_with_input_transpose) { + // c(i,k) = 5 * a(i,j) * b(k,j): the contracted index j is the trailing + // index of both operands. + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, 
tr2); + TArrayD c2; + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + HostArray c2_h; + fill_with_seed(a2, a2_h, 113); + fill_with_seed(b2, b2_h, 127); + GlobalFixture::world->gop.fence(); + + BOOST_REQUIRE_NO_THROW(c2("i,k") = 5.0 * (a2("i,j") * b2("k,j"))); + c2_h("i,k") = 5.0 * (a2_h("i,j") * b2_h("k,j")); + check_close(c2, c2_h, 1.0e-12); +} + +/// Non-uniform tile sizes for contraction. +/// Both tests compute out(x,y) = 5 * lhs(x,i,j,k) * rhs(y,i,j,k), contracting +/// over (i,j,k); they differ only in which modes receive which 1-d tiling. +BOOST_AUTO_TEST_CASE(cont_non_uniform_split_inner) { + // tiling1 lists six values (0..5) and tiling2 two values (0, 40); + // presumably tile boundaries, i.e. five unit tiles vs. one tile of 40 -- + // TODO confirm against TiledRange1. + // NOTE(review): std::array appears without template arguments in this + // text; confirm against the upstream file. + std::array tiling1 = {{0, 1, 2, 3, 4, 5}}; + std::array tiling2 = {{0, 40}}; + TiledRange1 tr1_1(tiling1.begin(), tiling1.end()); + TiledRange1 tr1_2(tiling2.begin(), tiling2.end()); + // Mode order (x,i,j,k) = (tr1_1, tr1_2, tr1_1, tr1_1): only i uses tr1_2. + std::array tiling4 = {{tr1_1, tr1_2, tr1_1, tr1_1}}; + TiledRange tr_irr(tiling4.begin(), tiling4.end()); + + TArrayD lhs(*GlobalFixture::world, tr_irr); + TArrayD rhs(*GlobalFixture::world, tr_irr); + TArrayD out; + HostArray lhs_h(*GlobalFixture::world, tr_irr); + HostArray rhs_h(*GlobalFixture::world, tr_irr); + HostArray out_h; + fill_with_seed(lhs, lhs_h, 131); + fill_with_seed(rhs, rhs_h, 137); + GlobalFixture::world->gop.fence(); + + BOOST_REQUIRE_NO_THROW(out("x,y") = 5.0 * (lhs("x,i,j,k") * rhs("y,i,j,k"))); + out_h("x,y") = 5.0 * (lhs_h("x,i,j,k") * rhs_h("y,i,j,k")); + check_close(out, out_h, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(cont_non_uniform_split_two_inner) { + // Same 1-d tilings as cont_non_uniform_split_inner. + std::array tiling1 = {{0, 1, 2, 3, 4, 5}}; + std::array tiling2 = {{0, 40}}; + TiledRange1 tr1_1(tiling1.begin(), tiling1.end()); + TiledRange1 tr1_2(tiling2.begin(), tiling2.end()); + // Mode order (x,i,j,k) = (tr1_1, tr1_1, tr1_2, tr1_2): j and k use tr1_2. + std::array tiling4 = {{tr1_1, tr1_1, tr1_2, tr1_2}}; + TiledRange tr_irr(tiling4.begin(), tiling4.end()); + + TArrayD lhs(*GlobalFixture::world, tr_irr); + TArrayD rhs(*GlobalFixture::world, tr_irr); + TArrayD out; + HostArray lhs_h(*GlobalFixture::world, tr_irr); + HostArray rhs_h(*GlobalFixture::world, tr_irr); + HostArray out_h; + fill_with_seed(lhs, lhs_h, 139); + fill_with_seed(rhs, rhs_h, 149); + GlobalFixture::world->gop.fence(); + + 
BOOST_REQUIRE_NO_THROW(out("x,y") = 5.0 * (lhs("x,i,j,k") * rhs("y,i,j,k"))); + out_h("x,y") = 5.0 * (lhs_h("x,i,j,k") * rhs_h("y,i,j,k")); + check_close(out, out_h, 1.0e-12); +} + +/// Contraction-plus-reduction: contract first, then compare TA::norm2 of the +/// TArrayD result with TA::norm2 of the HostArray result. +BOOST_AUTO_TEST_CASE(cont_plus_reduce) { + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + fill_with_seed(a2, a2_h, 151); + fill_with_seed(b2, b2_h, 157); + GlobalFixture::world->gop.fence(); + + TArrayD c2; + HostArray c2_h; + // c(i,k) = a(i,j) * b(j,k); only the norms are compared here, not the + // individual elements. + c2("i,k") = a2("i,j") * b2("j,k"); + c2_h("i,k") = a2_h("i,j") * b2_h("j,k"); + const double dev_n = TA::norm2(c2); + const double host_n = TA::norm2(c2_h); + GlobalFixture::world->gop.fence(); + BOOST_CHECK_CLOSE_FRACTION(dev_n, host_n, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(no_alias_plus_reduce) { + // Unlike cont_plus_reduce, both result arrays are constructed and + // zero-filled up front and then assigned through no_alias(); the results + // are compared element-wise and by norm. + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + TArrayD c2(*GlobalFixture::world, TiledRange{tr.data()[0], tr.data()[1]}); + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + HostArray c2_h(*GlobalFixture::world, TiledRange{tr.data()[0], tr.data()[1]}); + fill_with_seed(a2, a2_h, 163); + fill_with_seed(b2, b2_h, 167); + c2.fill_local(0.0); + c2_h.fill_local(0.0); + GlobalFixture::world->gop.fence(); + + BOOST_REQUIRE_NO_THROW(c2("i,k").no_alias() = a2("i,j") * b2("j,k")); + c2_h("i,k").no_alias() = a2_h("i,j") * b2_h("j,k"); + check_close(c2, c2_h, 1.0e-12); + const double dev_n = TA::norm2(c2); + const double host_n = TA::norm2(c2_h); + GlobalFixture::world->gop.fence(); + BOOST_CHECK_CLOSE_FRACTION(dev_n, host_n, 1.0e-12); +} + +/// Dot-product variants: a product of two expressions is converted to a +/// scalar and the two scalars are compared. +BOOST_AUTO_TEST_CASE(dot_permute) { + const double dev_d = static_cast(a("a,b,c") * b("c,b,a")); + const double host_d = static_cast(a_h("a,b,c") * b_h("c,b,a")); + 
GlobalFixture::world->gop.fence(); + // Looser tolerance because permuted dot reads tiles in a different + // order, so the partial-sum accumulation order differs. + BOOST_CHECK_CLOSE_FRACTION(dev_d, host_d, 1.0e-12); +} + +BOOST_AUTO_TEST_CASE(dot_contr) { + // Dot of two contraction expressions: scalar = (a2*b2) . (a2*b2), i.e. + // the same contraction a2(i,j) * b2(j,k) appears on both sides of the dot. + const TiledRange tr2{tr.data()[0], tr.data()[1]}; + TArrayD a2(*GlobalFixture::world, tr2); + TArrayD b2(*GlobalFixture::world, tr2); + HostArray a2_h(*GlobalFixture::world, tr2); + HostArray b2_h(*GlobalFixture::world, tr2); + fill_with_seed(a2, a2_h, 173); + fill_with_seed(b2, b2_h, 179); + GlobalFixture::world->gop.fence(); + + // NOTE(review): the static_cast calls below carry no target type in this + // text (angle-bracket arguments appear to have been lost); confirm + // against the upstream file. + const double dev_d = + static_cast((a2("i,j") * b2("j,k")) * (a2("i,j") * b2("j,k"))); + const double host_d = static_cast((a2_h("i,j") * b2_h("j,k")) * + (a2_h("i,j") * b2_h("j,k"))); + GlobalFixture::world->gop.fence(); + BOOST_CHECK_CLOSE_FRACTION(dev_d, host_d, 1.0e-12); +} + +/// Archive round-trip, host/device array conversions +BOOST_AUTO_TEST_CASE(serialize_um_tensor) { + // Single-tile round-trip: build a UMTensor, write to a buffer archive, + // read into a fresh UMTensor, compare element-wise. The Store side + // forces a host prefetch so the archive sees coherent data. 
+ TileD t(TiledArray::Range{4, 4}); + for (std::size_t k = 0; k < t.size(); ++k) + t.data()[k] = static_cast(k) - 5.0; + + const std::size_t buf_size = + (t.range().volume() * sizeof(double) + + sizeof(std::size_t) * (t.range().rank() * 4 + 4)) * + 2; + std::vector buf(buf_size); + + madness::archive::BufferOutputArchive oar(buf.data(), buf.size()); + BOOST_REQUIRE_NO_THROW(oar & t); + const std::size_t nbyte = oar.size(); + oar.close(); + + TileD u; + madness::archive::BufferInputArchive iar(buf.data(), nbyte); + BOOST_REQUIRE_NO_THROW(iar & u); + iar.close(); + + BOOST_REQUIRE_EQUAL(t.range(), u.range()); + for (std::size_t k = 0; k < t.size(); ++k) + BOOST_CHECK_CLOSE(t.data()[k], u.data()[k], 1.0e-15); +} + +BOOST_AUTO_TEST_CASE(serialize_um_tensor_empty) { + // Empty-tensor round-trip: the empty branch in Store/Load. + // A default-constructed tile is written to a fixed 1024-element buffer + // and read back into a tile that starts out non-empty. + TileD t; + std::vector buf(1024); + madness::archive::BufferOutputArchive oar(buf.data(), buf.size()); + BOOST_REQUIRE_NO_THROW(oar & t); + const std::size_t nbyte = oar.size(); + oar.close(); + + TileD u(TiledArray::Range{2, 2}); // start non-empty so load has work + madness::archive::BufferInputArchive iar(buf.data(), nbyte); + BOOST_REQUIRE_NO_THROW(iar & u); + iar.close(); + // After loading, u must report empty even though it began as a 2x2 tile. + BOOST_CHECK(u.empty()); +} + +BOOST_AUTO_TEST_CASE(um_to_ta_round_trip) { + // UMTensor array -> host array -> UMTensor array, verify element-wise + // against the original both at the host and device endpoints. + // NOTE(review): the two conversion calls below show no template + // arguments in this text; confirm against the upstream file. + HostArray host_view = + TiledArray::um_tensor_to_ta_tensor(a); + GlobalFixture::world->gop.fence(); + // Compare host_view directly against a_h + check_close(host_view, a_h, tolerance); + + TArrayD device_view = + TiledArray::ta_tensor_to_um_tensor( + host_view); + GlobalFixture::world->gop.fence(); + // Round-trip: device_view should hold the same values as the original `a`; + // the check that follows compares it against a_h. 
+ check_close(device_view, a_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(um_to_ta_then_expression) { + HostArray sum_h(*GlobalFixture::world, tr); + sum_h("a,b,c") = a_h("a,b,c") + b_h("a,b,c"); + + HostArray converted = + TiledArray::um_tensor_to_ta_tensor(a); + HostArray sum_from_device(*GlobalFixture::world, tr); + sum_from_device("a,b,c") = converted("a,b,c") + b_h("a,b,c"); + + GlobalFixture::world->gop.fence(); + check_close(sum_from_device, sum_h, tolerance); +} + +BOOST_AUTO_TEST_CASE(bulk_prefetch_round_trip) { + // to_host / to_device on a DistArray should leave the stored values + // unchanged: after calling both, `a` must still match a_h. + TiledArray::to_host(a); + TiledArray::to_device(a); + GlobalFixture::world->gop.fence(); + check_close(a, a_h, tolerance); +} + +BOOST_AUTO_TEST_SUITE_END() + +#endif // TILEDARRAY_HAS_DEVICE