diff --git a/quest/include/paulis.h b/quest/include/paulis.h index 1d08169f..e4645923 100644 --- a/quest/include/paulis.h +++ b/quest/include/paulis.h @@ -99,6 +99,9 @@ typedef struct { * * @defgroup paulis_reporters Reporters * @brief Functions for printing Pauli data structures. + * + * @defgroup paulis_setters Setters + * @brief Functions for overwriting the elements of Pauli data structures. */ @@ -420,6 +423,91 @@ extern "C" { +/* + * SETTERS + */ + +// enable invocation by both C and C++ binaries +#ifdef __cplusplus +extern "C" { +#endif + + + /** @ingroup paulis_setters + * + * Reorders the terms within a @p sum of weighted Pauli strings to sort Pauli + * strings into lexicographic (dictionary) ordering. + * + * @formulae + * Let @f$ H = @f$ @p sum, which can be represented as + * @f[ + H = \sum\limits_j c_j \, \hat{\sigma}_j + * @f] + * where @f$ c_j @f$ is the coefficient of the @f$ j @f$-th PauliStr @f$ \hat{\sigma}_j @f$. + * + * This function constructs and applies the permutation @f$ \pi @f$ to @f$ H @f$ + * @f[ + H = \sum\limits_j c_{\pi(j)} \, \hat{\sigma}_{\pi(j)} + * @f] + * such that + * @f[ + * \hat{\sigma}_{\pi(i)} <_{lex} \hat{\sigma}_{\pi(j)} \ \forall \ \pi(i) < \pi(j). + * @f] + * + * + * @param[in,out] sum a weighted sum of Pauli strings to reorder. + * + * @throws @validationerror + * - if @p sum is not initialised. + * + * @see + * - sortPauliStrSumMagnitude() + * @author Vasco Ferreira + */ + void sortPauliStrSumLexicographic(PauliStrSum sum); + + + /** @ingroup paulis_setters + * + * Reorders the terms within a @p sum of weighted Pauli strings to sort Pauli + * strings into decreasing magnitude weights. + * + * @formulae + * Let @f$ H = @f$ @p sum, represented as the weighted sum + * @f[ + H = \sum\limits_j c_j \, \hat{\sigma}_j + * @f] + * where @f$ c_j @f$ is the coefficient of the @f$ j @f$-th PauliStr @f$ \hat{\sigma}_j @f$. + * + * This function constructs and applies the permutation @f$ \pi @f$ to @f$ H @f$ + * @f[ + H = \sum\limits_j c_{\pi(j)} \, \hat{\sigma}_{\pi(j)} + * @f] + * such that + * @f[ + * |c_{\pi(i)}| > |c_{\pi(j)}| \, \forall \, \pi(i) < \pi(j). + * @f] + * + * @param[in,out] sum a weighted sum of Pauli strings to reorder. + * + * @throws @validationerror + * - if @p sum is not initialised. + * + * @see + * - sortPauliStrSumLexicographic() + * + * @author Vasco Ferreira + */ + void sortPauliStrSumMagnitude(PauliStrSum sum); + + +// end de-mangler +#ifdef __cplusplus +} +#endif + + + #endif // PAULIS_H /** @} */ // (end file-wide doxygen defgroup) diff --git a/quest/src/api/paulis.cpp b/quest/src/api/paulis.cpp index 5a2a122c..c53bf996 100644 --- a/quest/src/api/paulis.cpp +++ b/quest/src/api/paulis.cpp @@ -291,3 +291,32 @@ extern "C" void reportPauliStrSum(PauliStrSum sum) { // exclude mandatory newline above print_oneFewerNewlines(); } + + + +/* + * SETTERS + */ + +extern "C" void sortPauliStrSumLexicographic(PauliStrSum sum) { + validate_pauliStrSumFields(sum, __func__); + + auto lexSort = [&](qindex i, qindex j) { + PauliStr strI = sum.strings[i]; + PauliStr strJ = sum.strings[j]; + return std::tie(strI.highPaulis, strI.lowPaulis) < std::tie(strJ.highPaulis, strJ.lowPaulis); + }; + + paulis_sortGeneric(sum, lexSort); +} + +extern "C" void sortPauliStrSumMagnitude(PauliStrSum sum) { + validate_pauliStrSumFields(sum, __func__); + + auto magSort = [&](qindex i, qindex j) { + return std::norm(sum.coeffs[i]) > std::norm(sum.coeffs[j]); + }; + + paulis_sortGeneric(sum, magSort); +} + diff --git a/quest/src/core/paulilogic.cpp b/quest/src/core/paulilogic.cpp index 1a90dce2..4ac748ab 100644 --- a/quest/src/core/paulilogic.cpp +++ b/quest/src/core/paulilogic.cpp @@ -13,6 +13,7 @@ #include "quest/src/core/bitwise.hpp" #include "quest/src/core/errors.hpp" +#include #include #include #include @@ -307,6 +308,32 @@ qindex paulis_getTargetBitMask(PauliStrSum sum) { } +void paulis_applyPermutation(PauliStrSum sum, vector scatterPermutation) { + // permutation passed by value since we modify it + + // scatterPermutation[i] = destination index for element originally at i + for (qindex i = 0; i < sum.numTerms; i++) { + while (scatterPermutation[i] != i) { + qindex j = scatterPermutation[i]; + std::swap(sum.strings[i], sum.strings[j]); + std::swap(sum.coeffs[i], sum.coeffs[j]); + std::swap(scatterPermutation[i], scatterPermutation[j]); + } + } +} + +void paulis_sortGeneric(PauliStrSum sum, std::function comparator) { + + // gatherPermutation[j] = source index of element placed at j + vector gatherPermutation(sum.numTerms); + std::iota(gatherPermutation.begin(), gatherPermutation.end(), 0); + std::stable_sort(gatherPermutation.begin(), gatherPermutation.end(), comparator); + + // invert permutation and apply + vector scatterPermutation = util_invertPermutation(gatherPermutation); + paulis_applyPermutation(sum, scatterPermutation); +} + void paulis_setPauliStrSumToScaledTensorProdOfConjWithSelf(PauliStrSum out, qreal factor, PauliStrSum in, int numQubits) { // sets out = factor * conj(in) (x) in, where in has dim of numQubits diff --git a/quest/src/core/paulilogic.hpp b/quest/src/core/paulilogic.hpp index f3748b5e..bc8b725b 100644 --- a/quest/src/core/paulilogic.hpp +++ b/quest/src/core/paulilogic.hpp @@ -76,6 +76,10 @@ int paulis_getIndOfLefmostNonIdentityPauli(PauliStrSum sum); qindex paulis_getTargetBitMask(PauliStrSum sum); +void paulis_applyPermutation(PauliStrSum sum, vector permutation); + +void paulis_sortGeneric(PauliStrSum sum, std::function comparator); + // below are used exclusively by Trotterisation @@ -88,4 +92,4 @@ void paulis_setPauliStrSumToScaledProdOfAdjointWithSelf(PauliStrSum out, qreal f void paulis_setPauliStrSumToShiftedConj(PauliStrSum out, PauliStrSum in, int numQubits); -#endif // PAULILOGIC_HPP \ No newline at end of file +#endif // PAULILOGIC_HPP diff --git a/quest/src/core/utilities.cpp b/quest/src/core/utilities.cpp index 999bd9a7..861ccc73 100644 --- a/quest/src/core/utilities.cpp +++ b/quest/src/core/utilities.cpp @@ -1229,3 +1229,21 @@ void util_tryAllocMatrix(vector> &matr, qindex numRows, qindex num errFunc(); } } + + + +/* + * OTHER + */ + +vector util_invertPermutation(const vector& permutation) { + qindex numTerms = permutation.size(); + vector out(numTerms); + + // invert permutation + for (qindex i = 0; i < numTerms; i++) { + out[permutation[i]] = i; + } + + return out; +} diff --git a/quest/src/core/utilities.hpp b/quest/src/core/utilities.hpp index 4b7fb5db..25c59940 100644 --- a/quest/src/core/utilities.hpp +++ b/quest/src/core/utilities.hpp @@ -430,4 +430,10 @@ void util_tryAllocMatrix(vector> &vec, qindex numRows, qindex numC -#endif // UTILITIES_HPP \ No newline at end of file +/* + * OTHER + */ + +vector util_invertPermutation(const vector& permutation); + +#endif // UTILITIES_HPP diff --git a/tests/unit/paulis.cpp b/tests/unit/paulis.cpp index 8b6fc4f9..e3339100 100644 --- a/tests/unit/paulis.cpp +++ b/tests/unit/paulis.cpp @@ -587,6 +587,61 @@ TEST_CASE( "destroyPauliStrSum", TEST_CATEGORY ) { } } +TEST_CASE( "sortPauliStrSumLexicographic", TEST_CATEGORY ) { + + SECTION( LABEL_CORRECTNESS ) { + + vector coeffs = {0.1_i, 2+1_i, 5, 3+4_i}; + vector strings = { + getPauliStr("XY", {31,32}), + getPauliStr("YX", {0,1}), + getPauliStr("II", {0,1}), + getPauliStr("YY", {31,32}) + }; + + PauliStrSum sum = createPauliStrSum(strings, coeffs); + sortPauliStrSumLexicographic(sum); + + REQUIRE(sum.coeffs[0] == 5+0_i); + REQUIRE(sum.coeffs[1] == 2+1_i); + REQUIRE(sum.coeffs[3] == 3+4_i); + + REQUIRE(sum.strings[0].lowPaulis == 0); + REQUIRE(sum.strings[1].lowPaulis == 2 + 1*4); + REQUIRE(sum.strings[3].highPaulis == 2); + REQUIRE(sum.strings[3].lowPaulis == 2*std::pow(4, 31)); + + destroyPauliStrSum(sum); + } +} + +TEST_CASE( "sortPauliStrSumMagnitude", TEST_CATEGORY ) { + + SECTION( LABEL_CORRECTNESS ) { + + vector coeffs = {0.1_i, 2+1_i, 5, 3+4_i}; + vector strings = { + getPauliStr("XY", {0,1}), + getPauliStr("ZX", {0,1}), + getPauliStr("II", {0,1}), + getPauliStr("YZ", {0,1}) + }; + + PauliStrSum sum = createPauliStrSum(strings, coeffs); + sortPauliStrSumMagnitude(sum); + + REQUIRE(sum.coeffs[0] == 5+0_i); + REQUIRE(sum.coeffs[1] == 3+4_i); + REQUIRE(sum.coeffs[3] == 0+0.1_i); + + REQUIRE(sum.strings[0].lowPaulis == 0); + REQUIRE(sum.strings[1].lowPaulis == 2 + 3*4); + REQUIRE(sum.strings[3].lowPaulis == 1 + 2*4); + + destroyPauliStrSum(sum); + } +} + /** @} (end defgroup) */