Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion SeQuant/core/eval/backends/tiledarray/result.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,14 @@ class ResultTensorOfTensorTA final : public Result {

using _inner_tensor_type = typename ArrayT::value_type::value_type;

// "Regular" (non-nested) companion array for ToT * T einsum. The OUTER tile
// type must be a TA::Tensor — inner tile types like btas::Tensor are only
// valid as the *innermost* tile (they don't support permute/reshape/batch
// and so can't drive einsum's outer kernel). So we wrap the inner's numeric
// type in TA::Tensor here, rather than re-using the inner tile type as the
// outer tile.
using compatible_regular_distarray_type =
TA::DistArray<_inner_tensor_type, typename ArrayT::policy_type>;
TA::DistArray<TA::Tensor<numeric_type>, typename ArrayT::policy_type>;

// Only @c that_type type is allowed for ToT * T computation
using that_type = ResultTensorTA<compatible_regular_distarray_type>;
Expand Down
20 changes: 18 additions & 2 deletions SeQuant/core/eval/cache_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ class CacheManager {
std::unordered_map<TreeNode, entry, hasher_type, comparator_type> cache_map_;

public:
template <typename Iterable1>
explicit CacheManager(Iterable1&& decaying) noexcept {
/// Builds the cache registry from a range of two-element items.
///
/// Each item of @p decaying is decomposed by a structured binding into
/// `(k, c)`; `k` becomes the key and `entry{c}` the mapped value, inserted
/// with `try_emplace`. Because `try_emplace` does not overwrite, a key that
/// repeats in @p decaying keeps the entry built from its first occurrence.
///
/// The `requires` clause removes this forwarding constructor from overload
/// resolution when the argument is itself a CacheManager, so it cannot be
/// selected in place of the copy/move constructors.
///
/// \tparam Iterable a range whose elements support a two-name structured
///         binding.
/// \param decaying the items to register; the second element of each item
///        is presumably the entry's maximum life count — confirm against
///        `entry`'s constructor.
///
/// NOTE(review): declared `noexcept` although `try_emplace` may allocate;
/// an exception escaping here would call std::terminate.
template <typename Iterable>
requires(!std::same_as<std::remove_cvref_t<Iterable>, CacheManager>)
explicit CacheManager(Iterable&& decaying) noexcept {
for (auto&& [k, c] : decaying) cache_map_.try_emplace(k, entry{c});
}

Expand Down Expand Up @@ -145,6 +146,21 @@ class CacheManager {
return iter == end ? -1 : static_cast<int>(iter->second.max_life_count());
}

/// Reports whether @p key currently has live data in the cache.
///
/// \param key the cache key to look up.
/// \return false when @p key is not registered for caching; otherwise the
///         registered entry's own alive() state, i.e. true while data has
///         been stored and not yet drained by its final access.
[[nodiscard]] bool alive(key_type const& key) const noexcept {
  if (auto const pos = cache_map_.find(key); pos != cache_map_.end())
    return pos->second.alive();
  return false;
}

/// Size of the data currently held for a single cache key.
///
/// \param key the cache key to look up.
/// \return 0 when @p key is not registered; otherwise the byte size
///         reported by the registered entry (which is 0 when no data is
///         currently stored for it).
[[nodiscard]] size_t entry_size_in_bytes(key_type const& key) const noexcept {
  auto const pos = cache_map_.find(key);
  if (pos == cache_map_.end()) return 0;
  return pos->second.size_in_bytes();
}

///
/// \return The number of entries with life_count greater than zero.
///
Expand Down
179 changes: 142 additions & 37 deletions SeQuant/core/eval/eval.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <any>
#include <iostream>
#include <optional>
#include <stdexcept>
#include <type_traits>

Expand All @@ -33,15 +34,15 @@ struct Bytes {
size_t value;
};

template <typename... T>
requires((std::same_as<ResultPtr, T> && ...))
[[nodiscard]] auto bytes(T const&... args) {
return Bytes{(args->size_in_bytes() + ...)};
}

template <typename N, bool F>
[[nodiscard]] inline auto bytes(CacheManager<N, F> const& cman) {
return cman.size_in_bytes();
/// Sums the byte sizes of one or more objects and wraps the total in Bytes.
///
/// Each argument is queried through `->size_in_bytes()` when that expression
/// is well-formed for it (pointer-like handles), and through
/// `.size_in_bytes()` otherwise.
///
/// \param arg the first object to measure (at least one is required).
/// \param args any further objects to add to the total.
/// \return a Bytes value holding the sum of all reported sizes.
template <typename T, typename... Ts>
[[nodiscard]] inline auto bytes(T const& arg, Ts const&... args) {
  auto const size_of = [](auto const& item) -> size_t {
    if constexpr (requires { item->size_in_bytes(); }) {
      return item->size_in_bytes();
    } else {
      return item.size_in_bytes();
    }
  };
  size_t total = size_of(arg);
  ((total += size_of(args)), ...);
  return Bytes{total};
}

[[nodiscard]] inline auto to_string(Bytes bs) noexcept {
Expand Down Expand Up @@ -107,18 +108,54 @@ enum struct TermMode { Begin, End };
return (mode == TermMode::Begin) ? "Begin" : "End";
}

/// One log record per eval op. Line format:
///
// clang-format off
/// Eval | <mode> | <time> | [left=L | right=R |] result=X | alloc=A | hw=H | <label>
// clang-format on
///
/// Which fields are set depends on the op's arity:
///
/// mode | left/right | alloc
/// ----------------------------------------------+------------+--------
/// Constant / Variable / Tensor (leaf) | — | result
/// Permute / MultByPhase / | — | result
/// Symmetrize / Antisymmetrize | |
/// SumInplace | — | 0B
/// Sum / Product | set | result
///
/// Only Sum and Product set left/right, since their operand sizes can
/// differ from the result. Other modes omit those fields rather than
/// zeroing them, so a logged 0B always means an empty buffer.
///
/// mem_result is the size of the buffer the op produces; for SumInplace
/// it's the size of the accumulator after the add. mem_alloc is what the
/// op allocated — equal to mem_result everywhere except SumInplace,
/// which writes into the accumulator and allocates nothing. mem_hwmark
/// is the live working set during the op:
///
/// bytes(cache) + bytes(result) + bytes of each operand not aliased
/// to a cache entry
///
/// Aliasing is evaluated at each call site using cache.alive,
/// canon_phase, and the requested layout.
// Memory/time record for a single eval op, consumed by log::eval.
struct EvalStat {
// Kind of op being logged; printed via to_string.
EvalMode mode;
// Timing reported for the op by the call site.
Duration time;
// NOTE(review): legacy aggregate field. Nothing in the visible logging code
// reads it any more and it looks like the line this change removes —
// confirm it is gone in the real header.
Bytes memory;
// Size of the buffer the op produces (for SumInplace: the accumulator after
// the add).
Bytes mem_result{};
// Bytes the op itself allocated; call sites set it equal to mem_result
// except SumInplace, which sets 0.
Bytes mem_alloc{};
// Live working set during the op: cache bytes + result bytes + bytes of
// every operand that is not aliased to a cache entry.
Bytes mem_hwmark{};
// Operand sizes. Left unset (rather than zero) for ops that do not report
// operands; log::eval asserts mem_right is set whenever mem_left is.
std::optional<Bytes> mem_left;
std::optional<Bytes> mem_right;
};

// Record for a single cache event, consumed by log::cache.
struct CacheStat {
// Kind of cache event; printed via to_string.
CacheMode mode;
// Identifier of the cache key, logged as "key=<value>".
size_t key;
// Current and maximum life counts of the entry, logged as "life=<cur>/<max>".
int curr_life, max_life;
// Number of alive entries, filled from CacheManager::alive_count().
size_t num_alive;
// NOTE(review): legacy aggregate field. It appears to be superseded by
// entry_memory/total_memory and looks like the line this change removes —
// confirm it is gone in the real header.
Bytes memory;
// Bytes held for this key alone, filled from
// CacheManager::entry_size_in_bytes().
Bytes entry_memory;
// Bytes held by the whole cache, filled from bytes(cache manager).
Bytes total_memory;
};

template <typename Arg, typename... Args>
Expand All @@ -129,21 +166,35 @@ void log(Arg const& arg, Args const&... args) {

template <typename... Args>
auto eval(EvalStat const& stat, Args const&... args) {
log("Eval", //
to_string(stat.mode), //
stat.time, //
to_string(stat.memory), //
args...);
auto const result_s = std::format("result={}", to_string(stat.mem_result));
auto const alloc_s = std::format("alloc={}", to_string(stat.mem_alloc));
auto const hw_s = std::format("hw={}", to_string(stat.mem_hwmark));
if (stat.mem_left) {
SEQUANT_ASSERT(stat.mem_right);
log("Eval", //
to_string(stat.mode), //
stat.time, //
std::format("left={}", to_string(*stat.mem_left)), //
std::format("right={}", to_string(*stat.mem_right)), //
result_s, alloc_s, hw_s, //
args...);
} else {
log("Eval", //
to_string(stat.mode), //
stat.time, //
result_s, alloc_s, hw_s, args...);
}
}

template <typename... Args>
auto cache(CacheStat const& stat, Args const&... args) {
log("Cache", //
to_string(stat.mode), //
stat.key, //
std::format("{}/{}", stat.curr_life, stat.max_life), //
stat.num_alive, //
to_string(stat.memory), //
log("Cache", //
to_string(stat.mode), //
std::format("key={}", stat.key), //
std::format("life={}/{}", stat.curr_life, stat.max_life), //
std::format("alive={}", stat.num_alive), //
std::format("entry={}", to_string(stat.entry_memory)), //
std::format("total={}", to_string(stat.total_memory)), //
args...);
}

Expand All @@ -164,7 +215,8 @@ auto cache(N const& node, CacheManager<N, F>& cm, Args const&... args) {
.curr_life = cur_l,
.max_life = max_l,
.num_alive = cm.alive_count(),
.memory = {bytes(cm)}},
.entry_memory = {cm.entry_size_in_bytes(node)},
.total_memory = {bytes(cm)}},
args...);
}

Expand Down Expand Up @@ -263,7 +315,7 @@ ResultPtr evaluate(Node const& node, //
CacheManager<N, FHC>& cache) {
if constexpr (Cache == CacheCheck::Checked) { // return from cache if found

auto mult_by_phase = [&node](ResultPtr res) {
auto mult_by_phase = [&node, &cache](ResultPtr res) {
auto phase = node->canon_phase();
if (phase == 1) return res;

Expand All @@ -272,23 +324,27 @@ ResultPtr evaluate(Node const& node, //
timed_eval_inplace([&]() { post = res->mult_by_phase(phase); });

if constexpr (trace(EvalTrace)) {
size_t hwmark = log::bytes(cache, post).value;
if (!cache.alive(node)) hwmark += log::bytes(res).value;
auto stat = log::EvalStat{.mode = log::EvalMode::MultByPhase,
.time = time,
.memory = log::bytes(res, post)};
.mem_result = log::bytes(post),
.mem_alloc = log::bytes(post),
.mem_hwmark = {hwmark}};
log::eval(stat, std::format("{} * {}", phase, node->label()));
}
return post;
};

if (auto ptr = cache.access(node); ptr) {
if constexpr (trace(EvalTrace)) log::cache(node, cache);
if constexpr (trace(EvalTrace)) log::cache(node, cache, log::label(node));

return mult_by_phase(ptr);
} else if (cache.exists(node)) {
auto ptr = cache.store(
node, mult_by_phase(evaluate<EvalTrace, CacheCheck::Unchecked>(
node, le, cache)));
if constexpr (trace(EvalTrace)) log::cache(node, cache);
if constexpr (trace(EvalTrace)) log::cache(node, cache, log::label(node));

return mult_by_phase(ptr);
} else {
Expand Down Expand Up @@ -329,12 +385,32 @@ ResultPtr evaluate(Node const& node, //

// logging
if constexpr (trace(EvalTrace)) {
auto stat =
log::EvalStat{.mode = log::eval_mode(node),
.time = time,
.memory = node.leaf() ? log::bytes(result)
: log::bytes(left, right, result)};
log::eval(stat, log::label(node));
if (node.leaf()) {
log::eval(log::EvalStat{.mode = log::eval_mode(node),
.time = time,
.mem_result = log::bytes(result),
.mem_alloc = log::bytes(result),
.mem_hwmark = log::bytes(cache, result)},
log::label(node));
} else {
// A cached child is *distinct* from the local left/right when its
// canon_phase != 1, because mult_by_phase allocates a fresh buffer
// while the cache still holds the pre-phase data. So only skip the
// local's bytes when the cache aliases the same buffer (phase == 1).
size_t hwmark = log::bytes(cache, result).value;
if (!cache.alive(node.left()) || node.left()->canon_phase() != 1)
hwmark += log::bytes(left).value;
if (!cache.alive(node.right()) || node.right()->canon_phase() != 1)
hwmark += log::bytes(right).value;
log::eval(log::EvalStat{.mode = log::eval_mode(node),
.time = time,
.mem_result = log::bytes(result),
.mem_alloc = log::bytes(result),
.mem_hwmark = {hwmark},
.mem_left = log::bytes(left),
.mem_right = log::bytes(right)},
log::label(node));
}
}

return result;
Expand Down Expand Up @@ -386,9 +462,17 @@ ResultPtr evaluate(Node const& node, //
// logging
if constexpr (trace(EvalTrace)) {
if (perm) {
// result.pre aliases the cache only when the inner evaluate returned
// the cached buffer unchanged — i.e. the node is cached AND no
// mult_by_phase fresh allocation happened (phase == 1).
size_t hwmark = log::bytes(cache, result.post).value;
if (!cache.alive(node) || node->canon_phase() != 1)
hwmark += log::bytes(result.pre).value;
auto stat = log::EvalStat{.mode = log::EvalMode::Permute,
.time = time,
.memory = log::bytes(result.pre, result.post)};
.mem_result = log::bytes(result.post),
.mem_alloc = log::bytes(result.post),
.mem_hwmark = {hwmark}};
log::eval(stat, node->label());
}
log::term(log::TermMode::End, xpr);
Expand Down Expand Up @@ -422,6 +506,11 @@ ResultPtr evaluate(Nodes const& nodes, //
F const& le, CacheManager<N, FHC>& cache) {
ResultPtr result;

// pre comes back from the permute-wrapping evaluate; it aliases the
// cache only when the inner evaluate returned the cached buffer
// unchanged — i.e. node cached, phase == 1, AND no permute happened.
bool const layout_is_default = (layout == decltype(layout){});

for (auto&& n : nodes) {
if (!result) {
result = evaluate<EvalTrace>(n, layout, le, cache);
Expand All @@ -433,9 +522,17 @@ ResultPtr evaluate(Nodes const& nodes, //

// logging
if constexpr (trace(EvalTrace)) {
// SumInplace allocates nothing: it writes into the accumulator.
// hwmark counts the cache plus both operands live at this moment;
// skip pre's bytes only when pre is the cached buffer itself.
size_t hwmark = log::bytes(cache, result).value;
if (!cache.alive(n) || n->canon_phase() != 1 || !layout_is_default)
hwmark += log::bytes(pre).value;
auto stat = log::EvalStat{.mode = log::EvalMode::SumInplace,
.time = time,
.memory = log::bytes(result, pre)};
.mem_result = log::bytes(result),
.mem_alloc = {0},
.mem_hwmark = {hwmark}};
log::eval(stat, n->label());
}
}
Expand Down Expand Up @@ -510,9 +607,14 @@ ResultPtr evaluate_symm(Args&&... args) {

// logging
if constexpr (trace(EvalTrace)) {
// cache is owned by the inner evaluate call and out of scope here;
// hwmark reflects only the local working set (pre + freshly allocated
// result both live during the symmetrize op).
auto stat = log::EvalStat{.mode = log::EvalMode::Symmetrize,
.time = time,
.memory = log::bytes(pre, result)};
.mem_result = log::bytes(result),
.mem_alloc = log::bytes(result),
.mem_hwmark = log::bytes(pre, result)};
log::eval(stat, node0(arg0(std::forward<Args>(args)...))->label());
}

Expand Down Expand Up @@ -542,9 +644,12 @@ ResultPtr evaluate_antisymm(Args&&... args) {

// logging
if constexpr (trace(EvalTrace)) {
// See Symmetrize for the rationale on hwmark.
auto stat = log::EvalStat{.mode = log::EvalMode::Antisymmetrize,
.time = time,
.memory = log::bytes(pre, result)};
.mem_result = log::bytes(result),
.mem_alloc = log::bytes(result),
.mem_hwmark = log::bytes(pre, result)};
log::eval(stat, n0->label());
}
return result;
Expand Down
Loading