From 5b8697009fae8d80123566f32c840b931c1bc5ce Mon Sep 17 00:00:00 2001 From: "Ryan M. Richard" Date: Fri, 15 May 2026 11:59:48 -0500 Subject: [PATCH 1/2] fixes bug in dispatch --- include/tensorwrapper/dsl/pairwise_parser.hpp | 4 +- .../tensorwrapper/buffer/contiguous.cpp | 304 ++++++++++++++++++ 2 files changed, 306 insertions(+), 2 deletions(-) diff --git a/include/tensorwrapper/dsl/pairwise_parser.hpp b/include/tensorwrapper/dsl/pairwise_parser.hpp index 99849f1e..2197c01d 100644 --- a/include/tensorwrapper/dsl/pairwise_parser.hpp +++ b/include/tensorwrapper/dsl/pairwise_parser.hpp @@ -145,8 +145,8 @@ class PairwiseParser { dispatch(lA, rhs.lhs()); lhs.object().scalar_multiplication(lhs.labels(), rhs.rhs(), lA); } else { - auto pA = lhs.object().clone(); - auto pB = lhs.object().clone(); + auto pA = rhs.lhs().object().clone(); + auto pB = rhs.rhs().object().clone(); auto labels = lhs.labels(); auto lA = (*pA)(labels); auto lB = (*pB)(labels); diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/contiguous.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/contiguous.cpp index f4c954ef..ebd1e2d8 100644 --- a/tests/cxx/unit_tests/tensorwrapper/buffer/contiguous.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/buffer/contiguous.cpp @@ -592,3 +592,307 @@ TEMPLATE_LIST_TEST_CASE("make_contiguous(shape, value)", "", Contiguous contig = buffer::make_contiguous(shape, init); REQUIRE(contig == corr); } + +TEST_CASE("interval contraction") { +#ifdef ENABLE_SIGMA + using interval_type = sigma::Interval; + tensorwrapper::shape::Smooth a_shape({4, 4}); + tensorwrapper::shape::Smooth b_shape({4, 4, 4, 4}); + std::vector a_buffer{interval_type(0.16477052), interval_type(0.27613426), + interval_type(-0.21520682), interval_type(-0.15798025), + interval_type(0.27613426), interval_type(0.64734094), + interval_type(-0.10065438), interval_type(-0.21520682), + interval_type(-0.21520682), interval_type(-0.10065438), + interval_type(0.64734094), interval_type(0.27613426), + interval_type(-0.15798025), interval_type(-0.21520682), + interval_type(0.27613426), interval_type(0.16477052)}; + + std::vector b_buffer{interval_type(0.7746059442114875, 0.7746059442114875), + interval_type(0.4445903196340243, 0.4445903196340243), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.0316618451794192, 0.0316679025555972), + interval_type(0.4445903196340243, 0.4445903196340243), + interval_type(0.5699943521033807, 0.5699943521033807), + interval_type(0.2671325323063211, 0.2671325323063211), + interval_type(0.0842293132408630, 0.0842293392002486), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.2671325323063211, 0.2671325323063211), + interval_type(0.3360221878482700, 0.3360221878482700), + interval_type(0.1822547411402173, 0.1822547411402173), + interval_type(0.0316618451794192, 0.0316679025555972), + interval_type(0.0842293132408630, 0.0842293392002486), + interval_type(0.1822547411402173, 0.1822547411402173), + interval_type(0.2309891154471100, 0.2309891154471100), + interval_type(0.4445903196340242, 0.4445903196340242), + interval_type(0.2975896157617893, 0.2975896157617893), + interval_type(0.0919419129407689, 0.0919506383693423), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.2975896157617893, 0.2975896157617893), + interval_type(0.4445903196340242, 0.4445903196340242), + interval_type(0.2180688930726074, 0.2180764868443568), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.0919419129407689, 0.0919506383693423), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.2801437259785109, 0.2801437259785109), + interval_type(0.1478475949866729, 0.1478475949866729), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.1478475949866729, 0.1478475949866729), + interval_type(0.1822547411402173, 0.1822547411402173), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.0919419129407689, 0.0919506383693423), + interval_type(0.0322095731006274, 0.0322148126672940), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0919419129407689, 0.0919506383693423), + interval_type(0.1517471045021498, 0.1517471304615354), + interval_type(0.0859229666126391, 0.0859275381501771), + interval_type(0.0283749284020220, 0.0283801679686886), + interval_type(0.0322095731006274, 0.0322148126672940), + interval_type(0.0859229666126391, 0.0859275381501771), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0283749284020220, 0.0283801679686886), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.0842293132408629, 0.0842293392002485), + interval_type(0.0316618451794192, 0.0316679025555972), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0025484886593729, 0.0025550642433567), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0403894657257182, 0.0403955231018962), + interval_type(0.0248469091918667, 0.0248562278991491), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0248469091918667, 0.0248562278991491), + interval_type(0.0403894657257182, 0.0403955231018962), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0025484886593729, 0.0025550642433567), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0316618451794192, 0.0316679025555972), + interval_type(0.4445903196340242, 0.4445903196340242), + interval_type(0.2975896157617893, 0.2975896157617893), + interval_type(0.0919419129407689, 0.0919506383693423), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.2975896157617893, 0.2975896157617893), + interval_type(0.4445903196340242, 0.4445903196340242), + interval_type(0.2180688930726074, 0.2180764868443568), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.0919419129407689, 0.0919506383693423), + interval_type(0.2180688930726074, 0.2180764868443568), + interval_type(0.2801437259785109, 0.2801437259785109), + interval_type(0.1478475949866729, 0.1478475949866729), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.1478475949866729, 0.1478475949866729), + interval_type(0.1822547411402174, 0.1822547411402174), + interval_type(0.5699943521033807, 0.5699943521033807), + interval_type(0.4445903196340243, 0.4445903196340243), + interval_type(0.1517471045021498, 0.1517471304615354), + interval_type(0.0403894657257182, 0.0403955231018962), + interval_type(0.4445903196340243, 0.4445903196340243), + interval_type(0.7746059442114875, 0.7746059442114875), + interval_type(0.4063728513043064, 0.4063728513043064), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.1517471045021498, 0.1517471304615354), + interval_type(0.4063728513043064, 0.4063728513043064), + interval_type(0.5445965180667598, 0.5445965180667598), + interval_type(0.2801437259785109, 0.2801437259785109), + interval_type(0.0403894657257182, 0.0403955231018962), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.2801437259785110, 0.2801437259785110), + interval_type(0.3360221878482700, 0.3360221878482700), + interval_type(0.2671325323063211, 0.2671325323063211), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.0859229666126392, 0.0859275381501772), + interval_type(0.0248469091918667, 0.0248562278991491), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.4063728513043064, 0.4063728513043064), + interval_type(0.2545062652576217, 0.2545102403725139), + interval_type(0.0859229666126392, 0.0859275381501772), + interval_type(0.0859229666126391, 0.0859275381501771), + interval_type(0.2545062652576217, 0.2545102403725139), + interval_type(0.4063728513043064, 0.4063728513043064), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.0248469091918667, 0.0248562278991491), + interval_type(0.0859229666126392, 0.0859275381501772), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.2671325323063212, 0.2671325323063212), + interval_type(0.0842293132408629, 0.0842293392002485), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.0283749284020220, 0.0283801679686886), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.0859229666126392, 0.0859275381501772), + interval_type(0.0322095731006275, 0.0322148126672941), + interval_type(0.0283749284020220, 0.0283801679686886), + interval_type(0.0859229666126392, 0.0859275381501772), + interval_type(0.1517471045021498, 0.1517471304615354), + interval_type(0.0919419129407690, 0.0919506383693424), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0322095731006275, 0.0322148126672941), + interval_type(0.0919419129407689, 0.0919506383693423), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.0919419129407689, 0.0919506383693423), + interval_type(0.0322095731006274, 0.0322148126672940), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0919419129407689, 0.0919506383693423), + interval_type(0.1517471045021498, 0.1517471304615354), + interval_type(0.0859229666126391, 0.0859275381501771), + interval_type(0.0283749284020220, 0.0283801679686886), + interval_type(0.0322095731006274, 0.0322148126672940), + interval_type(0.0859229666126391, 0.0859275381501771), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0283749284020220, 0.0283801679686886), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.0842293132408629, 0.0842293392002485), + interval_type(0.2671325323063212, 0.2671325323063212), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.0859229666126391, 0.0859275381501771), + interval_type(0.0248469091918667, 0.0248562278991491), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.4063728513043064, 0.4063728513043064), + interval_type(0.2545062652576217, 0.2545102403725139), + interval_type(0.0859229666126392, 0.0859275381501772), + interval_type(0.0859229666126391, 0.0859275381501771), + interval_type(0.2545062652576217, 0.2545102403725139), + interval_type(0.4063728513043064, 0.4063728513043064), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.0248469091918667, 0.0248562278991491), + interval_type(0.0859229666126392, 0.0859275381501772), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.2671325323063212, 0.2671325323063212), + interval_type(0.3360221878482700, 0.3360221878482700), + interval_type(0.2801437259785109, 0.2801437259785109), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.0403894657257182, 0.0403955231018962), + interval_type(0.2801437259785109, 0.2801437259785109), + interval_type(0.5445965180667598, 0.5445965180667598), + interval_type(0.4063728513043064, 0.4063728513043064), + interval_type(0.1517471045021499, 0.1517471304615355), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.4063728513043064, 0.4063728513043064), + interval_type(0.7746059442114875, 0.7746059442114875), + interval_type(0.4445903196340244, 0.4445903196340244), + interval_type(0.0403894657257182, 0.0403955231018962), + interval_type(0.1517471045021498, 0.1517471304615354), + interval_type(0.4445903196340244, 0.4445903196340244), + interval_type(0.5699943521033808, 0.5699943521033808), + interval_type(0.1822547411402174, 0.1822547411402174), + interval_type(0.1478475949866729, 0.1478475949866729), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.1478475949866729, 0.1478475949866729), + interval_type(0.2801437259785110, 0.2801437259785110), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.0919419129407690, 0.0919506383693424), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.4445903196340245, 0.4445903196340245), + interval_type(0.2975896157617894, 0.2975896157617894), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0919419129407690, 0.0919506383693424), + interval_type(0.2975896157617894, 0.2975896157617894), + interval_type(0.4445903196340243, 0.4445903196340243), + interval_type(0.0316618451794192, 0.0316679025555972), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0025484886593729, 0.0025550642433567), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0403894657257182, 0.0403955231018962), + interval_type(0.0248469091918667, 0.0248562278991491), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0248469091918667, 0.0248562278991491), + interval_type(0.0403894657257182, 0.0403955231018962), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0025484886593729, 0.0025550642433567), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0316618451794192, 0.0316679025555972), + interval_type(0.0842293132408629, 0.0842293392002485), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.0283749284020220, 0.0283801679686886), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.0859229666126392, 0.0859275381501772), + interval_type(0.0322095731006275, 0.0322148126672941), + interval_type(0.0283749284020220, 0.0283801679686886), + interval_type(0.0859229666126392, 0.0859275381501772), + interval_type(0.1517471045021499, 0.1517471304615355), + interval_type(0.0919419129407690, 0.0919506383693424), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0322095731006275, 0.0322148126672941), + interval_type(0.0919419129407690, 0.0919506383693424), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.1822547411402173, 0.1822547411402173), + interval_type(0.1478475949866729, 0.1478475949866729), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.1478475949866730, 0.1478475949866730), + interval_type(0.2801437259785109, 0.2801437259785109), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.0919419129407690, 0.0919506383693424), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.4445903196340243, 0.4445903196340243), + interval_type(0.2975896157617894, 0.2975896157617894), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0919419129407690, 0.0919506383693424), + interval_type(0.2975896157617894, 0.2975896157617894), + interval_type(0.4445903196340245, 0.4445903196340245), + interval_type(0.2309891154471100, 0.2309891154471100), + interval_type(0.1822547411402173, 0.1822547411402173), + interval_type(0.0842293132408629, 0.0842293392002485), + interval_type(0.0316618451794192, 0.0316679025555972), + interval_type(0.1822547411402174, 0.1822547411402174), + interval_type(0.3360221878482700, 0.3360221878482700), + interval_type(0.2671325323063212, 0.2671325323063212), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.0842293132408629, 0.0842293392002485), + interval_type(0.2671325323063212, 0.2671325323063212), + interval_type(0.5699943521033808, 0.5699943521033808), + interval_type(0.4445903196340244, 0.4445903196340244), + interval_type(0.0316618451794192, 0.0316679025555972), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.4445903196340244, 0.4445903196340244), + interval_type(0.7746059442114875, 0.7746059442114875)}; + buffer::Contiguous a_tensor(a_buffer, a_shape); + buffer::Contiguous b_tensor(b_buffer, b_shape); + + std::vector c_buffer(a_shape.size(), interval_type(0.0)); + buffer::Contiguous c_tensor(c_buffer, a_shape); + c_tensor("k,l") = a_tensor("i,j") * b_tensor("k,l,j,i"); + + std::vector corr{ + interval_type(0.9438257539599672, 0.9438276901980663), + interval_type(0.6981236836099820, 0.6981338986433229), + interval_type(0.2572322079970916, 0.2572504879533325), + interval_type(0.0754708135852556, 0.0754973495884852), + interval_type(0.6981236836099821, 0.6981338986433229), + interval_type(1.1891028617809587, 1.1891047980190586), + interval_type(0.7219623782574144, 0.7219784457327180), + interval_type(0.2572322079970916, 0.2572504879533324), + interval_type(0.2572322079970916, 0.2572504879533325), + interval_type(0.7219623782574145, 0.7219784457327181), + interval_type(1.1891028617809591, 1.1891047980190588), + interval_type(0.6981236836099824, 0.6981338986433231), + interval_type(0.0754708135852556, 0.0754973495884852), + interval_type(0.2572322079970917, 0.2572504879533325), + interval_type(0.6981236836099823, 0.6981338986433230), + interval_type(0.9438257539599676, 0.9438276901980664)}; + + auto c_tensor_buffer = buffer::get_raw_data(c_tensor); + REQUIRE(c_tensor_buffer.size() == corr.size()); + for(size_t i = 0; i < c_tensor_buffer.size(); ++i) { + REQUIRE(c_tensor_buffer[i].lower() == Catch::Approx(corr[i].lower())); + REQUIRE(c_tensor_buffer[i].upper() == Catch::Approx(corr[i].upper())); + } +#endif +} From b75791021984a777c67348d9c121658abe5e80da Mon Sep 17 00:00:00 2001 From: "Ryan M. Richard" Date: Wed, 20 May 2026 11:57:18 -0500 Subject: [PATCH 2/2] adds element wise power and more --- .../tensorwrapper/operations/operations.hpp | 1 + include/tensorwrapper/operations/power.hpp | 24 +++++++ .../tensorwrapper/types/floating_point.hpp | 16 +++++ .../utilities/diagonal_matrix.hpp | 24 +++++++ .../tensorwrapper/utilities/make_tensor.hpp | 37 ++++++++++ include/tensorwrapper/utilities/utilities.hpp | 2 + src/tensorwrapper/operations/power.cpp | 50 +++++++++++++ .../utilities/diagonal_matrix.cpp | 48 +++++++++++++ .../tensorwrapper/operations/power.cpp | 45 ++++++++++++ .../utilities/diagonal_matrix.cpp | 33 +++++++++ .../tensorwrapper/utilities/make_tensor.cpp | 70 +++++++++++++++++++ 11 files changed, 350 insertions(+) create mode 100644 include/tensorwrapper/operations/power.hpp create mode 100644 include/tensorwrapper/utilities/diagonal_matrix.hpp create mode 100644 include/tensorwrapper/utilities/make_tensor.hpp create mode 100644 src/tensorwrapper/operations/power.cpp create mode 100644 src/tensorwrapper/utilities/diagonal_matrix.cpp create mode 100644 tests/cxx/unit_tests/tensorwrapper/operations/power.cpp create mode 100644 tests/cxx/unit_tests/tensorwrapper/utilities/diagonal_matrix.cpp create mode 100644 tests/cxx/unit_tests/tensorwrapper/utilities/make_tensor.cpp diff --git a/include/tensorwrapper/operations/operations.hpp b/include/tensorwrapper/operations/operations.hpp index b913bd97..0b9854a1 100644 --- a/include/tensorwrapper/operations/operations.hpp +++ b/include/tensorwrapper/operations/operations.hpp @@ -17,6 +17,7 @@ #pragma once #include #include +#include /// Namespace for free functions that act on tensors namespace tensorwrapper::operations {} diff --git a/include/tensorwrapper/operations/power.hpp b/include/tensorwrapper/operations/power.hpp new file mode 100644 index 00000000..f59a5721 --- /dev/null +++ b/include/tensorwrapper/operations/power.hpp @@ -0,0 +1,24 @@ +/* + * Copyright 2026 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace tensorwrapper::operations { + +/** @brief Appplies a power element-wise to a tensor */ +Tensor power(Tensor A, double pow); +} // namespace tensorwrapper::operations diff --git a/include/tensorwrapper/types/floating_point.hpp b/include/tensorwrapper/types/floating_point.hpp index 72d75637..7e2de193 100644 --- a/include/tensorwrapper/types/floating_point.hpp +++ b/include/tensorwrapper/types/floating_point.hpp @@ -78,6 +78,17 @@ T exp(T value) { } } +template +T pow(T value, double pow) { + if constexpr(is_uncertain_v) { + return sigma::pow(value, pow); + } else if constexpr(is_interval_v) { + return T(sigma::pow(value, pow)); + } else { + return std::pow(value, pow); + } +} + #define TW_APPLY_FLOATING_POINT_TYPES(MACRO_IN) \ MACRO_IN(float); \ MACRO_IN(double); \ @@ -125,6 +136,11 @@ T exp(T value) { return std::exp(value); } +template +T pow(T value, double pow) { + return std::pow(value, pow); +} + #define TW_APPLY_FLOATING_POINT_TYPES(MACRO_IN) \ MACRO_IN(float); \ MACRO_IN(double) diff --git a/include/tensorwrapper/utilities/diagonal_matrix.hpp b/include/tensorwrapper/utilities/diagonal_matrix.hpp new file mode 100644 index 00000000..d205d190 --- /dev/null +++ b/include/tensorwrapper/utilities/diagonal_matrix.hpp @@ -0,0 +1,24 @@ +/* + * Copyright 2026 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace tensorwrapper::utilities { + +Tensor diagonal_matrix(const Tensor& diagonal_elements); + +} // namespace tensorwrapper::utilities diff --git a/include/tensorwrapper/utilities/make_tensor.hpp b/include/tensorwrapper/utilities/make_tensor.hpp new file mode 100644 index 00000000..6bb84cbd --- /dev/null +++ b/include/tensorwrapper/utilities/make_tensor.hpp @@ -0,0 +1,37 @@ +/* + * Copyright 2026 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +namespace tensorwrapper::utilities { + +template +Tensor make_tensor(std::vector shape, BeginIterator begin, + EndIterator end) { + shape::Smooth smooth_shape(shape.begin(), shape.end()); + std::vector data(begin, end); + buffer::Contiguous buffer(data, smooth_shape); + return Tensor(std::move(smooth_shape), std::move(buffer)); +} + +template +Tensor make_tensor(std::initializer_list shape, + ContainerType&& value) { + return make_tensor(shape, value.begin(), value.end()); +} + +} // namespace tensorwrapper::utilities diff --git a/include/tensorwrapper/utilities/utilities.hpp b/include/tensorwrapper/utilities/utilities.hpp index 39632702..90143068 100644 --- a/include/tensorwrapper/utilities/utilities.hpp +++ b/include/tensorwrapper/utilities/utilities.hpp @@ -16,6 +16,8 @@ #pragma once #include +#include +#include #include /// Namespace for helper functions diff --git a/src/tensorwrapper/operations/power.cpp b/src/tensorwrapper/operations/power.cpp new file mode 100644 index 00000000..d5482026 --- /dev/null +++ b/src/tensorwrapper/operations/power.cpp @@ -0,0 +1,50 @@ +/* + * Copyright 2026 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace tensorwrapper::operations { +namespace { + +class PowerKernel { +public: + PowerKernel(double pow) : m_pow_(pow) {} + + template + void operator()(std::span A) const { + if constexpr(std::is_const_v) { + // This path is only for the compiler, we won't actually get to + // it. + throw std::runtime_error("Can't modify const data"); + } else { + for(auto& a : A) { a = types::pow(a, m_pow_); } + } + } + +private: + double m_pow_; +}; +} // namespace + +Tensor power(Tensor A, double pow) { + PowerKernel kernel(pow); + auto& buffer = make_contiguous(A.buffer()); + buffer::visit_contiguous_buffer(kernel, buffer); + return A; +} +} // namespace tensorwrapper::operations diff --git a/src/tensorwrapper/utilities/diagonal_matrix.cpp b/src/tensorwrapper/utilities/diagonal_matrix.cpp new file mode 100644 index 00000000..775d00f5 --- /dev/null +++ b/src/tensorwrapper/utilities/diagonal_matrix.cpp @@ -0,0 +1,48 @@ +/* + * Copyright 2026 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace tensorwrapper::utilities { +namespace { +struct Kernel { + template + auto operator()(const std::span& diagonal_elements) { + using clean_type = std::decay_t; + const auto n = diagonal_elements.size(); + shape::Smooth new_shape{n, n}; + std::vector data(n * n, 0); + for(std::size_t i = 0; i < n; ++i) { + data[i * n + i] = diagonal_elements[i]; + } + buffer::Contiguous buffer(data, new_shape); + return Tensor(std::move(new_shape), std::move(buffer)); + } +}; +} // namespace + +Tensor diagonal_matrix(const Tensor& diagonal_elements) { + if(diagonal_elements.rank() != 1) { + throw std::runtime_error("Diagonal elements must be a vector"); + } + Kernel k; + auto& buffer = make_contiguous(diagonal_elements.buffer()); + return buffer::visit_contiguous_buffer(k, buffer); +} + +} // namespace tensorwrapper::utilities diff --git a/tests/cxx/unit_tests/tensorwrapper/operations/power.cpp b/tests/cxx/unit_tests/tensorwrapper/operations/power.cpp new file mode 100644 index 00000000..d89750c8 --- /dev/null +++ b/tests/cxx/unit_tests/tensorwrapper/operations/power.cpp @@ -0,0 +1,45 @@ +/* + * Copyright 2026 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +using namespace tensorwrapper; +using namespace tensorwrapper::operations; +using namespace tensorwrapper::utilities; + +TEMPLATE_LIST_TEST_CASE("power", "", types::floating_point_types) { + SECTION("scalar") { + shape::Smooth s{}; + Tensor scalar(s, testing::eigen_scalar()); + auto rv = power(scalar, 2); + REQUIRE(approximately_equal( + rv, Tensor(s, testing::eigen_scalar(42 * 42)))); + } + + SECTION("vector") { + shape::Smooth s{5}; + Tensor vector(s, testing::eigen_vector()); + auto rv = power(vector, 0.5); + TestType sqrt2 = std::sqrt(2); + TestType sqrt3 = std::sqrt(3); + std::vector data{0, 1, sqrt2, sqrt3, 2}; + auto corr = make_tensor({5}, data.begin(), data.end()); + REQUIRE(approximately_equal(rv, corr)); + } +} diff --git a/tests/cxx/unit_tests/tensorwrapper/utilities/diagonal_matrix.cpp b/tests/cxx/unit_tests/tensorwrapper/utilities/diagonal_matrix.cpp new file mode 100644 index 00000000..69ece93f --- /dev/null +++ b/tests/cxx/unit_tests/tensorwrapper/utilities/diagonal_matrix.cpp @@ -0,0 +1,33 @@ +/* + * Copyright 2026 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +using namespace tensorwrapper; +using namespace tensorwrapper::operations; +using namespace tensorwrapper::utilities; +using namespace testing; + +TEMPLATE_LIST_TEST_CASE("diagonal_matrix", "", types::floating_point_types) { + auto diagonal_values = make_tensor({3}, std::vector{1, 2, 3}); + auto corr = + make_tensor({3, 3}, std::vector{1, 0, 0, 0, 2, 0, 0, 0, 3}); + auto result = diagonal_matrix(diagonal_values); + REQUIRE(approximately_equal(result, corr)); +} diff --git a/tests/cxx/unit_tests/tensorwrapper/utilities/make_tensor.cpp b/tests/cxx/unit_tests/tensorwrapper/utilities/make_tensor.cpp new file mode 100644 index 00000000..3d63f252 --- /dev/null +++ b/tests/cxx/unit_tests/tensorwrapper/utilities/make_tensor.cpp @@ -0,0 +1,70 @@ +/* + * Copyright 2026 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +using namespace tensorwrapper; +using namespace tensorwrapper::operations; +using namespace tensorwrapper::utilities; + +TEMPLATE_LIST_TEST_CASE("make_tensor", "", types::floating_point_types) { + SECTION("scalar") { + std::vector data{42}; + auto tensor = make_tensor({}, data.begin(), data.end()); + auto tensor2 = make_tensor({}, data); + Tensor corr(shape::Smooth{}, testing::eigen_scalar(42)); + REQUIRE(approximately_equal(tensor, corr)); + REQUIRE(approximately_equal(tensor2, corr)); + } + + SECTION("vector") { + std::vector data{0, 1, 2, 3, 4}; + auto tensor = make_tensor({5}, data.begin(), data.end()); + auto tensor2 = make_tensor({5}, data); + Tensor corr(shape::Smooth{5}, testing::eigen_vector()); + REQUIRE(approximately_equal(tensor, corr)); + REQUIRE(approximately_equal(tensor2, corr)); + } + + SECTION("matrix") { + std::vector data{1, 2, 3, 4}; + auto tensor = make_tensor({2, 2}, data.begin(), data.end()); + auto tensor2 = make_tensor({2, 2}, data); + Tensor corr(shape::Smooth{2, 2}, testing::eigen_matrix()); + REQUIRE(approximately_equal(tensor, corr)); + REQUIRE(approximately_equal(tensor2, corr)); + } + SECTION("tensor3") { + std::vector data{1, 2, 3, 4, 5, 6, 7, 8}; + auto tensor = make_tensor({2, 2, 2}, data.begin(), data.end()); + auto tensor2 = make_tensor({2, 2, 2}, data); + Tensor corr(shape::Smooth{2, 2, 2}, testing::eigen_tensor3()); + REQUIRE(approximately_equal(tensor, corr)); + REQUIRE(approximately_equal(tensor2, corr)); + } + SECTION("tensor4") { + std::vector data{1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}; + auto tensor = make_tensor({2, 2, 2, 2}, data.begin(), data.end()); + auto tensor2 = make_tensor({2, 2, 2, 2}, data); + Tensor corr(shape::Smooth{2, 2, 2, 2}, + testing::eigen_tensor4()); + REQUIRE(approximately_equal(tensor, corr)); + REQUIRE(approximately_equal(tensor2, corr)); + } +}