diff --git a/include/tensorwrapper/dsl/pairwise_parser.hpp b/include/tensorwrapper/dsl/pairwise_parser.hpp index 99849f1e..2197c01d 100644 --- a/include/tensorwrapper/dsl/pairwise_parser.hpp +++ b/include/tensorwrapper/dsl/pairwise_parser.hpp @@ -145,8 +145,8 @@ class PairwiseParser { dispatch(lA, rhs.lhs()); lhs.object().scalar_multiplication(lhs.labels(), rhs.rhs(), lA); } else { - auto pA = lhs.object().clone(); - auto pB = lhs.object().clone(); + auto pA = rhs.lhs().object().clone(); + auto pB = rhs.rhs().object().clone(); auto labels = lhs.labels(); auto lA = (*pA)(labels); auto lB = (*pB)(labels); diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/contiguous.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/contiguous.cpp index f4c954ef..ebd1e2d8 100644 --- a/tests/cxx/unit_tests/tensorwrapper/buffer/contiguous.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/buffer/contiguous.cpp @@ -592,3 +592,307 @@ TEMPLATE_LIST_TEST_CASE("make_contiguous(shape, value)", "", Contiguous contig = buffer::make_contiguous(shape, init); REQUIRE(contig == corr); } + +TEST_CASE("interval contraction") { +#ifdef ENABLE_SIGMA + using interval_type = sigma::Interval; + tensorwrapper::shape::Smooth a_shape({4, 4}); + tensorwrapper::shape::Smooth b_shape({4, 4, 4, 4}); + std::vector a_buffer{interval_type(0.16477052), interval_type(0.27613426), + interval_type(-0.21520682), interval_type(-0.15798025), + interval_type(0.27613426), interval_type(0.64734094), + interval_type(-0.10065438), interval_type(-0.21520682), + interval_type(-0.21520682), interval_type(-0.10065438), + interval_type(0.64734094), interval_type(0.27613426), + interval_type(-0.15798025), interval_type(-0.21520682), + interval_type(0.27613426), interval_type(0.16477052)}; + + std::vector b_buffer{interval_type(0.7746059442114875, 0.7746059442114875), + interval_type(0.4445903196340243, 0.4445903196340243), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.0316618451794192, 0.0316679025555972), + interval_type(0.4445903196340243, 0.4445903196340243), + interval_type(0.5699943521033807, 0.5699943521033807), + interval_type(0.2671325323063211, 0.2671325323063211), + interval_type(0.0842293132408630, 0.0842293392002486), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.2671325323063211, 0.2671325323063211), + interval_type(0.3360221878482700, 0.3360221878482700), + interval_type(0.1822547411402173, 0.1822547411402173), + interval_type(0.0316618451794192, 0.0316679025555972), + interval_type(0.0842293132408630, 0.0842293392002486), + interval_type(0.1822547411402173, 0.1822547411402173), + interval_type(0.2309891154471100, 0.2309891154471100), + interval_type(0.4445903196340242, 0.4445903196340242), + interval_type(0.2975896157617893, 0.2975896157617893), + interval_type(0.0919419129407689, 0.0919506383693423), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.2975896157617893, 0.2975896157617893), + interval_type(0.4445903196340242, 0.4445903196340242), + interval_type(0.2180688930726074, 0.2180764868443568), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.0919419129407689, 0.0919506383693423), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.2801437259785109, 0.2801437259785109), + interval_type(0.1478475949866729, 0.1478475949866729), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.1478475949866729, 0.1478475949866729), + interval_type(0.1822547411402173, 0.1822547411402173), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.0919419129407689, 0.0919506383693423), + interval_type(0.0322095731006274, 0.0322148126672940), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0919419129407689, 0.0919506383693423), + interval_type(0.1517471045021498, 0.1517471304615354), + interval_type(0.0859229666126391, 0.0859275381501771), + interval_type(0.0283749284020220, 0.0283801679686886), + interval_type(0.0322095731006274, 0.0322148126672940), + interval_type(0.0859229666126391, 0.0859275381501771), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0283749284020220, 0.0283801679686886), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.0842293132408629, 0.0842293392002485), + interval_type(0.0316618451794192, 0.0316679025555972), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0025484886593729, 0.0025550642433567), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0403894657257182, 0.0403955231018962), + interval_type(0.0248469091918667, 0.0248562278991491), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0248469091918667, 0.0248562278991491), + interval_type(0.0403894657257182, 0.0403955231018962), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0025484886593729, 0.0025550642433567), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0316618451794192, 0.0316679025555972), + interval_type(0.4445903196340242, 0.4445903196340242), + interval_type(0.2975896157617893, 0.2975896157617893), + interval_type(0.0919419129407689, 0.0919506383693423), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.2975896157617893, 0.2975896157617893), + interval_type(0.4445903196340242, 0.4445903196340242), + interval_type(0.2180688930726074, 0.2180764868443568), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.0919419129407689, 0.0919506383693423), + interval_type(0.2180688930726074, 0.2180764868443568), + interval_type(0.2801437259785109, 0.2801437259785109), + interval_type(0.1478475949866729, 0.1478475949866729), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.1478475949866729, 0.1478475949866729), + interval_type(0.1822547411402174, 0.1822547411402174), + interval_type(0.5699943521033807, 0.5699943521033807), + interval_type(0.4445903196340243, 0.4445903196340243), + interval_type(0.1517471045021498, 0.1517471304615354), + interval_type(0.0403894657257182, 0.0403955231018962), + interval_type(0.4445903196340243, 0.4445903196340243), + interval_type(0.7746059442114875, 0.7746059442114875), + interval_type(0.4063728513043064, 0.4063728513043064), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.1517471045021498, 0.1517471304615354), + interval_type(0.4063728513043064, 0.4063728513043064), + interval_type(0.5445965180667598, 0.5445965180667598), + interval_type(0.2801437259785109, 0.2801437259785109), + interval_type(0.0403894657257182, 0.0403955231018962), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.2801437259785110, 0.2801437259785110), + interval_type(0.3360221878482700, 0.3360221878482700), + interval_type(0.2671325323063211, 0.2671325323063211), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.0859229666126392, 0.0859275381501772), + interval_type(0.0248469091918667, 0.0248562278991491), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.4063728513043064, 0.4063728513043064), + interval_type(0.2545062652576217, 0.2545102403725139), + interval_type(0.0859229666126392, 0.0859275381501772), + interval_type(0.0859229666126391, 0.0859275381501771), + interval_type(0.2545062652576217, 0.2545102403725139), + interval_type(0.4063728513043064, 0.4063728513043064), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.0248469091918667, 0.0248562278991491), + interval_type(0.0859229666126392, 0.0859275381501772), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.2671325323063212, 0.2671325323063212), + interval_type(0.0842293132408629, 0.0842293392002485), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.0283749284020220, 0.0283801679686886), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.0859229666126392, 0.0859275381501772), + interval_type(0.0322095731006275, 0.0322148126672941), + interval_type(0.0283749284020220, 0.0283801679686886), + interval_type(0.0859229666126392, 0.0859275381501772), + interval_type(0.1517471045021498, 0.1517471304615354), + interval_type(0.0919419129407690, 0.0919506383693424), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0322095731006275, 0.0322148126672941), + interval_type(0.0919419129407689, 0.0919506383693423), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.0919419129407689, 0.0919506383693423), + interval_type(0.0322095731006274, 0.0322148126672940), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0919419129407689, 0.0919506383693423), + interval_type(0.1517471045021498, 0.1517471304615354), + interval_type(0.0859229666126391, 0.0859275381501771), + interval_type(0.0283749284020220, 0.0283801679686886), + interval_type(0.0322095731006274, 0.0322148126672940), + interval_type(0.0859229666126391, 0.0859275381501771), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0283749284020220, 0.0283801679686886), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.0842293132408629, 0.0842293392002485), + interval_type(0.2671325323063212, 0.2671325323063212), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.0859229666126391, 0.0859275381501771), + interval_type(0.0248469091918667, 0.0248562278991491), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.4063728513043064, 0.4063728513043064), + interval_type(0.2545062652576217, 0.2545102403725139), + interval_type(0.0859229666126392, 0.0859275381501772), + interval_type(0.0859229666126391, 0.0859275381501771), + interval_type(0.2545062652576217, 0.2545102403725139), + interval_type(0.4063728513043064, 0.4063728513043064), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.0248469091918667, 0.0248562278991491), + interval_type(0.0859229666126392, 0.0859275381501772), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.2671325323063212, 0.2671325323063212), + interval_type(0.3360221878482700, 0.3360221878482700), + interval_type(0.2801437259785109, 0.2801437259785109), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.0403894657257182, 0.0403955231018962), + interval_type(0.2801437259785109, 0.2801437259785109), + interval_type(0.5445965180667598, 0.5445965180667598), + interval_type(0.4063728513043064, 0.4063728513043064), + interval_type(0.1517471045021499, 0.1517471304615355), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.4063728513043064, 0.4063728513043064), + interval_type(0.7746059442114875, 0.7746059442114875), + interval_type(0.4445903196340244, 0.4445903196340244), + interval_type(0.0403894657257182, 0.0403955231018962), + interval_type(0.1517471045021498, 0.1517471304615354), + interval_type(0.4445903196340244, 0.4445903196340244), + interval_type(0.5699943521033808, 0.5699943521033808), + interval_type(0.1822547411402174, 0.1822547411402174), + interval_type(0.1478475949866729, 0.1478475949866729), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.1478475949866729, 0.1478475949866729), + interval_type(0.2801437259785110, 0.2801437259785110), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.0919419129407690, 0.0919506383693424), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.4445903196340245, 0.4445903196340245), + interval_type(0.2975896157617894, 0.2975896157617894), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0919419129407690, 0.0919506383693424), + interval_type(0.2975896157617894, 0.2975896157617894), + interval_type(0.4445903196340243, 0.4445903196340243), + interval_type(0.0316618451794192, 0.0316679025555972), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0025484886593729, 0.0025550642433567), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0403894657257182, 0.0403955231018962), + interval_type(0.0248469091918667, 0.0248562278991491), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0248469091918667, 0.0248562278991491), + interval_type(0.0403894657257182, 0.0403955231018962), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0025484886593729, 0.0025550642433567), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0316618451794192, 0.0316679025555972), + interval_type(0.0842293132408629, 0.0842293392002485), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.0283749284020220, 0.0283801679686886), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.0859229666126392, 0.0859275381501772), + interval_type(0.0322095731006275, 0.0322148126672941), + interval_type(0.0283749284020220, 0.0283801679686886), + interval_type(0.0859229666126392, 0.0859275381501772), + interval_type(0.1517471045021499, 0.1517471304615355), + interval_type(0.0919419129407690, 0.0919506383693424), + interval_type(0.0088380257701204, 0.0088480573431448), + interval_type(0.0322095731006275, 0.0322148126672941), + interval_type(0.0919419129407690, 0.0919506383693424), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.1822547411402173, 0.1822547411402173), + interval_type(0.1478475949866729, 0.1478475949866729), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.1478475949866730, 0.1478475949866730), + interval_type(0.2801437259785109, 0.2801437259785109), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.0919419129407690, 0.0919506383693424), + interval_type(0.0681749223473700, 0.0681836477759434), + interval_type(0.2180688930726075, 0.2180764868443569), + interval_type(0.4445903196340243, 0.4445903196340243), + interval_type(0.2975896157617894, 0.2975896157617894), + interval_type(0.0236960375615112, 0.0236997571831646), + interval_type(0.0919419129407690, 0.0919506383693424), + interval_type(0.2975896157617894, 0.2975896157617894), + interval_type(0.4445903196340245, 0.4445903196340245), + interval_type(0.2309891154471100, 0.2309891154471100), + interval_type(0.1822547411402173, 0.1822547411402173), + interval_type(0.0842293132408629, 0.0842293392002485), + interval_type(0.0316618451794192, 0.0316679025555972), + interval_type(0.1822547411402174, 0.1822547411402174), + interval_type(0.3360221878482700, 0.3360221878482700), + interval_type(0.2671325323063212, 0.2671325323063212), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.0842293132408629, 0.0842293392002485), + interval_type(0.2671325323063212, 0.2671325323063212), + interval_type(0.5699943521033808, 0.5699943521033808), + interval_type(0.4445903196340244, 0.4445903196340244), + interval_type(0.0316618451794192, 0.0316679025555972), + interval_type(0.1266680887852389, 0.1266681147446245), + interval_type(0.4445903196340244, 0.4445903196340244), + interval_type(0.7746059442114875, 0.7746059442114875)}; + buffer::Contiguous a_tensor(a_buffer, a_shape); + buffer::Contiguous b_tensor(b_buffer, b_shape); + + std::vector c_buffer(a_shape.size(), interval_type(0.0)); + buffer::Contiguous c_tensor(c_buffer, a_shape); + c_tensor("k,l") = a_tensor("i,j") * b_tensor("k,l,j,i"); + + std::vector corr{ + interval_type(0.9438257539599672, 0.9438276901980663), + interval_type(0.6981236836099820, 0.6981338986433229), + interval_type(0.2572322079970916, 0.2572504879533325), + interval_type(0.0754708135852556, 0.0754973495884852), + interval_type(0.6981236836099821, 0.6981338986433229), + interval_type(1.1891028617809587, 1.1891047980190586), + interval_type(0.7219623782574144, 0.7219784457327180), + interval_type(0.2572322079970916, 0.2572504879533324), + interval_type(0.2572322079970916, 0.2572504879533325), + interval_type(0.7219623782574145, 0.7219784457327181), + interval_type(1.1891028617809591, 1.1891047980190588), + interval_type(0.6981236836099824, 0.6981338986433231), + interval_type(0.0754708135852556, 0.0754973495884852), + interval_type(0.2572322079970917, 0.2572504879533325), + interval_type(0.6981236836099823, 0.6981338986433230), + interval_type(0.9438257539599676, 0.9438276901980664)}; + + auto c_tensor_buffer = buffer::get_raw_data(c_tensor); + REQUIRE(c_tensor_buffer.size() == corr.size()); + for(size_t i = 0; i < c_tensor_buffer.size(); ++i) { + REQUIRE(c_tensor_buffer[i].lower() == Catch::Approx(corr[i].lower())); + REQUIRE(c_tensor_buffer[i].upper() == Catch::Approx(corr[i].upper())); + } +#endif +}