189 changes: 68 additions & 121 deletions csrc/index_compute.cpp
@@ -1413,52 +1413,42 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
 Val* Index::getLinearLogicalIndex(
     TensorView* consumer_tv,
     const std::vector<kir::ForLoop*>& loops) {
-  if (!ir_utils::hasRootToLoopLinearTransformations(consumer_tv) ||
-      ir_utils::isCpAsyncBulkLoad(consumer_tv->definition()) ||
-      GpuLower::current()->idModelOptions().isTensorIndexerEnabled() ||
-      GpuLower::current()->tmemInfo().hasTMemTensor()) {
-    const TensorIndexer& indexer = GpuLower::current()->tensorIndexer();
-    auto per_dim_indices = indexer.getIndexFor(
-        consumer_tv->definition(),
-        /*as_consumer=*/true,
-        consumer_tv->getLogicalDomain(),
-        loops,
-        /*use_magic_zero=*/true);
-    Val* stride = consumer_tv->fusion()->oneVal();
-    for (const auto [i, logical_id] :
-         enumerate(consumer_tv->getLogicalDomain()) | std::views::reverse) {
-      auto per_dim_index = per_dim_indices.at(i);
-      auto per_dim_strided_index =
-          SimplifyingIrBuilder::mulExpr(per_dim_index, stride);
-      per_dim_indices.at(i) = per_dim_strided_index;
-      stride = SimplifyingIrBuilder::mulExpr(stride, logical_id->extent());
-    }
-    return sumVals(per_dim_indices);
-  } else {
-    auto guard = ir_utils::allocateToLogicalDomainGuard(consumer_tv, true);
-    return sumVals(getGlobalConsumerStridedIndices(consumer_tv, loops));
-  }
+  NVF_ERROR(
+      GpuLower::current()->idModelOptions().isTensorIndexerEnabled(),
+      "Legacy indexer no longer available");
+
+  const TensorIndexer& indexer = GpuLower::current()->tensorIndexer();
+  auto per_dim_indices = indexer.getIndexFor(
+      consumer_tv->definition(),
+      /*as_consumer=*/true,
+      consumer_tv->getLogicalDomain(),
+      loops,
+      /*use_magic_zero=*/true);
+  Val* stride = consumer_tv->fusion()->oneVal();
+  for (const auto [i, logical_id] :
+       enumerate(consumer_tv->getLogicalDomain()) | std::views::reverse) {
+    auto per_dim_index = per_dim_indices.at(i);
+    auto per_dim_strided_index =
+        SimplifyingIrBuilder::mulExpr(per_dim_index, stride);
+    per_dim_indices.at(i) = per_dim_strided_index;
+    stride = SimplifyingIrBuilder::mulExpr(stride, logical_id->extent());
+  }
+  return sumVals(per_dim_indices);
 }

 std::vector<Val*> Index::getConsumerPerDimLogicalIndex(
     TensorView* consumer_tv,
     const std::vector<kir::ForLoop*>& loops) {
-  if (!ir_utils::hasRootToLoopLinearTransformations(consumer_tv) ||
-      GpuLower::current()->idModelOptions().isTensorIndexerEnabled() ||
-      GpuLower::current()->tmemInfo().hasTMemTensor()) {
-    const TensorIndexer& indexer = GpuLower::current()->tensorIndexer();
-    return indexer.getIndexFor(
-        consumer_tv->definition(),
-        /*as_consumer=*/true,
-        consumer_tv->getLogicalDomain(),
-        loops);
-  } else {
-    auto guard = ir_utils::allocateToLogicalDomainGuard(consumer_tv, false);
-    IndexFromIdGraph index_from_id_graph =
-        getTensorIndexFromIdGraph(loops, consumer_tv);
-    return getConsumerAllocationIndices(
-        consumer_tv, loops, index_from_id_graph);
-  }
+  NVF_ERROR(
+      GpuLower::current()->idModelOptions().isTensorIndexerEnabled(),
+      "Legacy indexer no longer available");
+
+  const TensorIndexer& indexer = GpuLower::current()->tensorIndexer();
+  return indexer.getIndexFor(
+      consumer_tv->definition(),
+      /*as_consumer=*/true,
+      consumer_tv->getLogicalDomain(),
+      loops);
 }

 std::vector<Val*> Index::getProducerPerDimLogicalIndex(
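(Editorial reading aid, not part of the change: the sketch below restates the reversed stride-accumulation loop in the new `getLinearLogicalIndex` in plain C++, with `int64_t` standing in for nvFuser's symbolic `Val*`; the function name is illustrative.)

```cpp
#include <cstdint>
#include <vector>

// Row-major linearization as performed by the loop in getLinearLogicalIndex:
// walk the logical domain right to left, scale each per-dim index by the
// running stride, and grow the stride by that dimension's extent.
int64_t linearizeLogicalIndex(
    const std::vector<int64_t>& per_dim_indices,  // indexer.getIndexFor(...)
    const std::vector<int64_t>& extents) {        // logical_id->extent()
  int64_t stride = 1;  // consumer_tv->fusion()->oneVal()
  int64_t linear = 0;  // plays the role of sumVals(per_dim_indices)
  for (auto i = static_cast<int64_t>(per_dim_indices.size()) - 1; i >= 0; --i) {
    linear += per_dim_indices[i] * stride;  // SimplifyingIrBuilder::mulExpr
    stride *= extents[i];
  }
  return linear;
}
```

For example, per-dim indices {1, 2} in a {3, 4} logical domain linearize to 1 * 4 + 2 = 6.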
@@ -1913,38 +1903,6 @@ Val* Index::getProducerStridedIndices(
   }
 }

-namespace {
-
-bool shouldUseTensorIndexer(
-    const TensorView* producer,
-    const TensorView* consumer) {
-  // Check if TensorIndexer is definitely required
-  auto is_tensor_indexer_required = [&]() -> bool {
-    bool is_producer_tma_op = producer->definition() != nullptr &&
-        producer->definition()->isA<LoadStoreOp>() &&
-        ir_utils::isCpAsyncBulkLoad(producer->definition());
-    bool is_consumer_tma_op = consumer->definition() != nullptr &&
-        consumer->definition()->isA<LoadStoreOp>() &&
-        ir_utils::isCpAsyncBulkLoad(consumer->definition());
-
-    return !ir_utils::hasRootToLoopLinearTransformations(producer) ||
-        (consumer->definition()->isA<MmaOp>() &&
-         isHopper(consumer->definition()->as<MmaOp>()->macro())) ||
-        is_producer_tma_op || is_consumer_tma_op ||
-        GpuLower::current()->tmemInfo().hasTMemTensor();
-  };
-
-  // TensorIndexer is always used when required or if not disabled.
-  // Note: Previously, ldmatrix and stmatrix were first introduced
-  // with Ampere, their indexing were only implemented in the legacy
-  // indexer in a rather manual way. The current implementation uses
-  // the alternate loop domain to enable TensorIndexer-based indexing.
-  return is_tensor_indexer_required() ||
-      GpuLower::current()->idModelOptions().isTensorIndexerEnabled();
-}
-
-} // namespace
-
 // Producer is the inputs of an expression
 kir::TensorIndex* Index::getProducerIndex(
     TensorView* producer,
@@ -1954,28 +1912,25 @@ kir::TensorIndex* Index::getProducerIndex(
     bool generate_pointer,
     DataType as_type,
     bool ld_st_matrix) {
-  Val* index = nullptr;
-
-  if (shouldUseTensorIndexer(producer, consumer)) {
-    index = GpuLower::current()->tensorIndexer().getLinearIndex(
-        producer, consumer->definition(), loops, override_index, ld_st_matrix);
-    if (generate_pointer) {
-      auto address_offset = index;
-      if (producer->getMemoryType() == MemoryType::Shared) {
-        auto producer_dt = producer->getDataType();
-        NVF_ERROR(producer_dt.has_value());
-        auto index_dt = index->getDataType();
-        NVF_ERROR(index_dt.has_value());
-        address_offset = SimplifyingIrBuilder::mulExpr(
-            address_offset,
-            IrBuilder::create<Val>(dataTypeSizeByte(*producer_dt), *index_dt));
-      }
-      index = SimplifyingIrBuilder::addExpr(
-          IrBuilder::baseAddressExpr(producer), address_offset);
-    }
-  } else {
-    index = getProducerStridedIndices(
-        producer, consumer, loops, override_index, generate_pointer);
+  NVF_ERROR(
+      GpuLower::current()->idModelOptions().isTensorIndexerEnabled(),
+      "Legacy indexer no longer available");
+
+  Val* index = GpuLower::current()->tensorIndexer().getLinearIndex(
+      producer, consumer->definition(), loops, override_index, ld_st_matrix);
+  if (generate_pointer) {
+    auto address_offset = index;
+    if (producer->getMemoryType() == MemoryType::Shared) {
+      auto producer_dt = producer->getDataType();
+      NVF_ERROR(producer_dt.has_value());
+      auto index_dt = index->getDataType();
+      NVF_ERROR(index_dt.has_value());
+      address_offset = SimplifyingIrBuilder::mulExpr(
+          address_offset,
+          IrBuilder::create<Val>(dataTypeSizeByte(*producer_dt), *index_dt));
+    }
+    index = SimplifyingIrBuilder::addExpr(
+        IrBuilder::baseAddressExpr(producer), address_offset);
   }

   index = GpuLower::current()->commonScalarMap().hoistScalar(index, loops);
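(Editorial reading aid: under `generate_pointer`, the producer index becomes an address rather than an element offset. A minimal sketch of the arithmetic above, again with plain integers in place of IR nodes; the function name and parameters are illustrative.)

```cpp
#include <cstdint>

// Pointer generation as in getProducerIndex: for shared-memory tensors the
// linear element index is scaled to a byte offset before being added to the
// tensor's base address; otherwise the offset is used as-is.
int64_t producerAddress(
    int64_t base_address,   // IrBuilder::baseAddressExpr(producer)
    int64_t linear_index,   // tensorIndexer().getLinearIndex(...)
    int64_t element_bytes,  // dataTypeSizeByte(*producer_dt)
    bool is_shared_memory) {
  int64_t address_offset = linear_index;
  if (is_shared_memory) {
    address_offset *= element_bytes;  // SimplifyingIrBuilder::mulExpr
  }
  return base_address + address_offset;  // SimplifyingIrBuilder::addExpr
}
```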
@@ -2050,33 +2005,25 @@ kir::TensorIndex* Index::getConsumerIndex(
     bool generate_pointer,
     DataType as_type,
     bool ld_st_matrix) {
-  Val* index = nullptr;
-  if (!ir_utils::hasRootToLoopLinearTransformations(consumer) ||
-      ir_utils::isCpAsyncBulkLoad(consumer->definition()) ||
-      GpuLower::current()->idModelOptions().isTensorIndexerEnabled() ||
-      GpuLower::current()->tmemInfo().hasTMemTensor()) {
-    index = GpuLower::current()->tensorIndexer().getLinearIndex(
-        consumer, consumer->definition(), loops, override_index, ld_st_matrix);
-    if (generate_pointer) {
-      auto address_offset = index;
-      if (consumer->getMemoryType() == MemoryType::Shared) {
-        auto consumer_dt = consumer->getDataType();
-        NVF_ERROR(consumer_dt.has_value());
-        auto index_dt = index->getDataType();
-        NVF_ERROR(index_dt.has_value());
-        address_offset = SimplifyingIrBuilder::mulExpr(
-            index,
-            IrBuilder::create<Val>(dataTypeSizeByte(*consumer_dt), *index_dt));
-      }
-      index = SimplifyingIrBuilder::addExpr(
-          IrBuilder::baseAddressExpr(consumer), address_offset);
+  NVF_ERROR(
+      GpuLower::current()->idModelOptions().isTensorIndexerEnabled(),
+      "Legacy indexer no longer available");
+
+  Val* index = GpuLower::current()->tensorIndexer().getLinearIndex(
+      consumer, consumer->definition(), loops, override_index, ld_st_matrix);
+  if (generate_pointer) {
+    auto address_offset = index;
+    if (consumer->getMemoryType() == MemoryType::Shared) {
+      auto consumer_dt = consumer->getDataType();
+      NVF_ERROR(consumer_dt.has_value());
+      auto index_dt = index->getDataType();
+      NVF_ERROR(index_dt.has_value());
+      address_offset = SimplifyingIrBuilder::mulExpr(
+          index,
+          IrBuilder::create<Val>(dataTypeSizeByte(*consumer_dt), *index_dt));
     }
-  } else {
-    NVF_ERROR(
-        override_index.empty(),
-        "Overriding of consumer indexing with the legacy indexer is not "
-        "supported");
-    index = getConsumerStridedIndices(consumer, loops, generate_pointer);
+    index = SimplifyingIrBuilder::addExpr(
+        IrBuilder::baseAddressExpr(consumer), address_offset);
   }

   index = GpuLower::current()->commonScalarMap().hoistScalar(index, loops);
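(Editorial note: the consumer path mirrors the producer path. One cosmetic asymmetry in the diff: the consumer branch passes `index` to `mulExpr` where the producer branch passes `address_offset`; the two are interchangeable at that point, since `address_offset` was just initialized from `index`. A tiny worked example, assuming a hypothetical half-precision (2-byte) shared-memory tensor:)

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const int64_t base = 0x1000;  // IrBuilder::baseAddressExpr(consumer)
  const int64_t index = 37;     // linear element index from the TensorIndexer
  const int64_t bytes = 2;      // dataTypeSizeByte(DataType::Half)
  // address_offset = mulExpr(index, element size); final = addExpr(base, offset)
  const int64_t address_offset = index * bytes;
  assert(base + address_offset == 0x104A);  // 0x1000 + 74
  return 0;
}
```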
1 change: 1 addition & 0 deletions csrc/index_compute.h
@@ -480,6 +480,7 @@ class Index {
       const std::unordered_map<IterDomain*, Val*>& override_index = {},
       bool generate_pointer = false);

+  // TODO: Remove
   //! Returns a vector of strided indices mapped onto the
   //! allocation domain of a consumer tensor. The size of the returned
   //! vector is guaranteed to be equal to the number of axes of the