Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions hip/hipBlockCopy.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,20 @@ __global__ void copyToWaveletKernelOpt(
}
}

// ---- Phase 3: Store ZPB planes ----
// ---- Phase 3: Store ZPB planes (skip planes beyond wnz) ----
if constexpr (DO_COPY) {
uint32_t dst_byte = (gx + gy * wnx) * (uint32_t)sizeof(float);

#pragma unroll
for (int dz = 0; dz < BCOPY_ZPB; ++dz) {
auto plane_rsrc = __builtin_amdgcn_make_buffer_rsrc(
d_dst + (long)(z_start + dz) * wnx * wny,
0, -1, 0x00027000);
auto vi = __builtin_bit_cast(bcopy_int4_vec, regs[dz]);
__builtin_amdgcn_raw_buffer_store_b128(
vi, plane_rsrc, dst_byte, 0, SLC);
if (z_start + dz < wnz) {
auto plane_rsrc = __builtin_amdgcn_make_buffer_rsrc(
d_dst + (long)(z_start + dz) * wnx * wny,
0, -1, 0x00027000);
auto vi = __builtin_bit_cast(bcopy_int4_vec, regs[dz]);
__builtin_amdgcn_raw_buffer_store_b128(
vi, plane_rsrc, dst_byte, 0, SLC);
}
}
}

Expand Down Expand Up @@ -174,15 +176,20 @@ __global__ void copyFromWaveletKernelOpt(
uint32_t src_byte = (gx + gy * wnx) * (uint32_t)sizeof(float);
long wav_plane = (long)wnx * wny;

constexpr bcopy_float4_vec zero_vec_from = {0.0f, 0.0f, 0.0f, 0.0f};
bcopy_float4_vec regs[BCOPY_ZPB];
#pragma unroll
for (int dz = 0; dz < BCOPY_ZPB; ++dz) {
auto plane_rsrc = __builtin_amdgcn_make_buffer_rsrc(
const_cast<float*>(d_src + (long)(z_start + dz) * wav_plane),
0, -1, 0x00027000);
regs[dz] = __builtin_bit_cast(bcopy_float4_vec,
__builtin_amdgcn_raw_buffer_load_b128(
plane_rsrc, src_byte, 0, SLC));
if (z_start + dz < wnz) {
auto plane_rsrc = __builtin_amdgcn_make_buffer_rsrc(
const_cast<float*>(d_src + (long)(z_start + dz) * wav_plane),
0, -1, 0x00027000);
regs[dz] = __builtin_bit_cast(bcopy_float4_vec,
__builtin_amdgcn_raw_buffer_load_b128(
plane_rsrc, src_byte, 0, SLC));
} else {
regs[dz] = zero_vec_from;
}
}

uint32_t dst_byte = ((y0 + gy) * ldimx + x0 + gx) * (uint32_t)sizeof(float);
Expand Down
129 changes: 93 additions & 36 deletions hip/hipCompress.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#define DS79_INCLUDE_REG32
#include "hipWaveletRLE.h"
#include "hipWaveletRLEInverse.h"
#include "hipWaveletRLE2D.h"
#include "hipBlockCopy.h"

#include <cmath>
Expand Down Expand Up @@ -67,20 +68,37 @@ hipError_t hipCompressCreatePlan(hipCompressPlan** plan, int nx, int ny, int nz,

if (nx <= 0 || ny <= 0 || nz <= 0)
PLAN_ERROR(p, HIP_COMPRESS_ERROR_INVALID_DIMENSIONS, hipErrorInvalidValue);
if (nx % 32 != 0 || ny % 32 != 0 || nz % 32 != 0)
PLAN_ERROR(p, HIP_COMPRESS_ERROR_NOT_MULTIPLE_OF_32, hipErrorInvalidValue);

bool is_2d = (nz == 1);

if (is_2d) {
if (nx % 32 != 0 || ny % 32 != 0)
PLAN_ERROR(p, HIP_COMPRESS_ERROR_NOT_MULTIPLE_OF_32, hipErrorInvalidValue);
} else {
if (nx % 32 != 0 || ny % 32 != 0 || nz % 32 != 0)
PLAN_ERROR(p, HIP_COMPRESS_ERROR_NOT_MULTIPLE_OF_32, hipErrorInvalidValue);
}

if ((long)nx * (long)ny * (long)sizeof(float) > (1L << 32))
PLAN_ERROR(p, HIP_COMPRESS_ERROR_PLANE_TOO_LARGE, hipErrorInvalidValue);

p->kernel = kernel;
p->nx = nx; p->ny = ny; p->nz = nz;
p->num_blocks = (nx / 32) * (ny / 32) * (nz / 32);
p->is_2d = is_2d;
p->aux_stream = aux_stream;
p->compress_pending = false;
p->last_error = HIP_COMPRESS_ERROR_HIP_RUNTIME;

if (is_2d) {
p->num_blocks = (nx / 32) * (ny / 32);
p->scratch_slot_stride = WRLE2D_SLOT_BYTES;
} else {
p->num_blocks = (nx / 32) * (ny / 32) * (nz / 32);
p->scratch_slot_stride = 4L * WRLE_LDS_BYTES;
}

int nb = p->num_blocks;
long scratch_size = (long)nb * 4L * WRLE_LDS_BYTES;
long scratch_size = (long)nb * (long)p->scratch_slot_stride;

HIPCHECK_PLAN(p, hipMalloc(&p->d_scratch, scratch_size));
HIPCHECK_PLAN(p, hipMalloc(&p->d_mulfac, sizeof(float)));
Expand All @@ -95,13 +113,15 @@ hipError_t hipCompressCreatePlan(hipCompressPlan** plan, int nx, int ny, int nz,

HIPCHECK_PLAN(p, hipMalloc(&p->d_rms, sizeof(double)));

p->max_copy_blocks = (nx / 32) * (ny / 32) * (nz / BCOPY_ZPB);
int nz_for_copy = is_2d ? 1 : nz;
p->max_copy_blocks = (nx / 32) * (ny / 32)
* ((nz_for_copy + BCOPY_ZPB - 1) / BCOPY_ZPB);
HIPCHECK_PLAN(p, hipMalloc(&p->d_partial_sums, p->max_copy_blocks * sizeof(double)));

HIPCHECK_PLAN(p, hipEventCreateWithFlags(&p->ready_event, hipEventDisableTiming));
HIPCHECK_PLAN(p, hipHostMalloc(&p->h_staging, 2 * sizeof(size_t)));

// JIT warmup: exclusive_scan on aux_stream (no user stream available yet)
// JIT warmup: exclusive_scan on aux_stream
err = rocprim::exclusive_scan(
p->d_scan_temp, p->scan_temp_bytes,
p->d_block_sizes, p->d_block_offsets, (size_t)0, (size_t)1,
Expand Down Expand Up @@ -149,25 +169,35 @@ hipError_t hipCompress(
PLAN_ERROR(plan, HIP_COMPRESS_ERROR_INVALID_SCALE, hipErrorInvalidValue);

const int nx = plan->nx, ny = plan->ny, nz = plan->nz;
const int ldimx = nx, ldimxy = nx * ny;
const int ldimx = nx;
const int nb = plan->num_blocks;
const int num_mulfacs = 1;
const int hdr_size = hipCompressHeaderSize(nb, num_mulfacs);
hipStream_t s = user_stream;
hipStream_t aux = plan->aux_stream;

// 1. Fused wavelet + quantize + RLE → scratch (user_stream)
dim3 grid((nx + 31) / 32, (ny + 31) / 32, (nz + 31) / 32);
if (plan->kernel == HIP_COMPRESS_KERNEL_SEGRLE) {
waveletSegRLEFusedKernel<<<grid, dim3(256), 0, s>>>(
if (plan->is_2d) {
int nbx = nx / 32, nby = ny / 32;
dim3 grid((nbx + WRLE2D_TILES_PER_WG - 1) / WRLE2D_TILES_PER_WG, nby);
waveletRLE2DFusedKernel<<<grid, dim3(256), 0, s>>>(
d_input, plan->d_scratch, plan->d_block_sizes,
scale, ldimx, ldimxy,
scale, ldimx, nbx,
d_rms, plan->d_mulfac);
} else {
waveletRLEFusedKernel<<<grid, dim3(256), 0, s>>>(
d_input, plan->d_scratch, plan->d_block_sizes,
scale, ldimx, ldimxy,
d_rms, plan->d_mulfac);
const int ldimxy = nx * ny;
dim3 grid((nx + 31) / 32, (ny + 31) / 32, (nz + 31) / 32);
if (plan->kernel == HIP_COMPRESS_KERNEL_SEGRLE) {
waveletSegRLEFusedKernel<<<grid, dim3(256), 0, s>>>(
d_input, plan->d_scratch, plan->d_block_sizes,
scale, ldimx, ldimxy,
d_rms, plan->d_mulfac);
} else {
waveletRLEFusedKernel<<<grid, dim3(256), 0, s>>>(
d_input, plan->d_scratch, plan->d_block_sizes,
scale, ldimx, ldimxy,
d_rms, plan->d_mulfac);
}
}

// 2. Exclusive scan for compaction offsets (user_stream)
Expand All @@ -181,10 +211,18 @@ hipError_t hipCompress(
HIPCHECK_PLAN(plan, hipStreamWaitEvent(aux, plan->ready_event, 0));

// 4. Compact + write header on aux_stream
wrleCompactKernel<<<nb, 256, 0, aux>>>(
plan->d_scratch, d_output + hdr_size,
plan->d_block_sizes, plan->d_block_offsets,
d_output, nb, num_mulfacs, plan->d_mulfac);
if (plan->is_2d) {
wrle2DCompactKernel<<<nb, 256, 0, aux>>>(
plan->d_scratch, d_output + hdr_size,
plan->d_block_sizes, plan->d_block_offsets,
d_output, nb, num_mulfacs, plan->d_mulfac,
plan->scratch_slot_stride);
} else {
wrleCompactKernel<<<nb, 256, 0, aux>>>(
plan->d_scratch, d_output + hdr_size,
plan->d_block_sizes, plan->d_block_offsets,
d_output, nb, num_mulfacs, plan->d_mulfac);
}

// 5. Async readback on aux_stream
HIPCHECK_PLAN(plan, hipMemcpyAsync(&plan->h_staging[0], plan->d_block_offsets + (nb - 1),
Expand Down Expand Up @@ -241,14 +279,19 @@ hipError_t hipCopyToWaveletLayout(
PLAN_ERROR(plan, HIP_COMPRESS_ERROR_NULL_INPUT, hipErrorInvalidValue);
if (!d_dst && !d_rms_out)
PLAN_ERROR(plan, HIP_COMPRESS_ERROR_BOTH_OUTPUTS_NULL, hipErrorInvalidValue);
if (ex < 32 || ey < 32 || ez < 32)
PLAN_ERROR(plan, HIP_COMPRESS_ERROR_WINDOW_TOO_SMALL, hipErrorInvalidValue);
if (plan->is_2d) {
if (ex < 32 || ey < 32 || ez != 1)
PLAN_ERROR(plan, HIP_COMPRESS_ERROR_WINDOW_TOO_SMALL, hipErrorInvalidValue);
} else {
if (ex < 32 || ey < 32 || ez < 32)
PLAN_ERROR(plan, HIP_COMPRESS_ERROR_WINDOW_TOO_SMALL, hipErrorInvalidValue);
}
if ((long)ldimxy * (long)sizeof(float) > (1L << 32))
PLAN_ERROR(plan, HIP_COMPRESS_ERROR_PLANE_TOO_LARGE, hipErrorInvalidValue);

int wnx = hipCompressWaveletDim(ex);
int wny = hipCompressWaveletDim(ey);
int wnz = hipCompressWaveletDim(ez);
int wnz = plan->is_2d ? 1 : hipCompressWaveletDim(ez);

if (wnx != plan->nx || wny != plan->ny || wnz != plan->nz)
PLAN_ERROR(plan, HIP_COMPRESS_ERROR_EXTRACTION_DIMS_MISMATCH, hipErrorInvalidValue);
Expand All @@ -258,7 +301,7 @@ hipError_t hipCopyToWaveletLayout(
bool do_rms = (d_rms_out != nullptr);
long total_samples = (long)ex * ey * ez;

dim3 grid(wnx / 32, wny / 32, wnz / BCOPY_ZPB);
dim3 grid(wnx / 32, wny / 32, (wnz + BCOPY_ZPB - 1) / BCOPY_ZPB);
int total_blocks = grid.x * grid.y * grid.z;

if (do_copy && do_rms) {
Expand Down Expand Up @@ -307,19 +350,24 @@ hipError_t hipCopyFromWaveletLayout(
PLAN_ERROR(plan, HIP_COMPRESS_ERROR_NULL_INPUT, hipErrorInvalidValue);
if (!d_dst)
PLAN_ERROR(plan, HIP_COMPRESS_ERROR_NULL_OUTPUT, hipErrorInvalidValue);
if (ex < 32 || ey < 32 || ez < 32)
PLAN_ERROR(plan, HIP_COMPRESS_ERROR_WINDOW_TOO_SMALL, hipErrorInvalidValue);
if (plan->is_2d) {
if (ex < 32 || ey < 32 || ez != 1)
PLAN_ERROR(plan, HIP_COMPRESS_ERROR_WINDOW_TOO_SMALL, hipErrorInvalidValue);
} else {
if (ex < 32 || ey < 32 || ez < 32)
PLAN_ERROR(plan, HIP_COMPRESS_ERROR_WINDOW_TOO_SMALL, hipErrorInvalidValue);
}
if ((long)ldimxy * (long)sizeof(float) > (1L << 32))
PLAN_ERROR(plan, HIP_COMPRESS_ERROR_PLANE_TOO_LARGE, hipErrorInvalidValue);

int wnx = hipCompressWaveletDim(ex);
int wny = hipCompressWaveletDim(ey);
int wnz = hipCompressWaveletDim(ez);
int wnz = plan->is_2d ? 1 : hipCompressWaveletDim(ez);

if (wnx != plan->nx || wny != plan->ny || wnz != plan->nz)
PLAN_ERROR(plan, HIP_COMPRESS_ERROR_EXTRACTION_DIMS_MISMATCH, hipErrorInvalidValue);

dim3 grid(wnx / 32, wny / 32, wnz / BCOPY_ZPB);
dim3 grid(wnx / 32, wny / 32, (wnz + BCOPY_ZPB - 1) / BCOPY_ZPB);
copyFromWaveletKernelOpt<<<grid, 256, 0, user_stream>>>(
d_src, wnx, wny, wnz,
d_dst, ldimx, ldimxy, x0, y0, z0, ex, ey, ez);
Expand All @@ -344,17 +392,26 @@ hipError_t hipDecompress(
PLAN_ERROR(plan, HIP_COMPRESS_ERROR_NULL_OUTPUT, hipErrorInvalidValue);

const int nx = plan->nx, ny = plan->ny, nz = plan->nz;
const int ldimx = nx, ldimxy = nx * ny;
const int ldimx = nx;

dim3 grid((nx + 31) / 32, (ny + 31) / 32, (nz + 31) / 32);
if (plan->kernel == HIP_COMPRESS_KERNEL_SEGRLE) {
waveletSegRLEInverseFusedKernel<<<grid, dim3(256), 0, user_stream>>>(
if (plan->is_2d) {
int nbx = nx / 32, nby = ny / 32;
dim3 grid((nbx + WRLE2D_TILES_PER_WG - 1) / WRLE2D_TILES_PER_WG, nby);
waveletRLE2DInverseFusedKernel<<<grid, dim3(256), 0, user_stream>>>(
d_input, nullptr, nullptr,
d_output, 0.0f, ldimx, ldimxy, 1);
d_output, 0.0f, ldimx, nbx, 1);
} else {
waveletRLEInverseFusedKernel<<<grid, dim3(256), 0, user_stream>>>(
d_input, nullptr, nullptr,
d_output, 0.0f, ldimx, ldimxy, 1);
const int ldimxy = nx * ny;
dim3 grid((nx + 31) / 32, (ny + 31) / 32, (nz + 31) / 32);
if (plan->kernel == HIP_COMPRESS_KERNEL_SEGRLE) {
waveletSegRLEInverseFusedKernel<<<grid, dim3(256), 0, user_stream>>>(
d_input, nullptr, nullptr,
d_output, 0.0f, ldimx, ldimxy, 1);
} else {
waveletRLEInverseFusedKernel<<<grid, dim3(256), 0, user_stream>>>(
d_input, nullptr, nullptr,
d_output, 0.0f, ldimx, ldimxy, 1);
}
}

hipError_t launch_err = hipGetLastError();
Expand All @@ -372,6 +429,6 @@ hipError_t hipCompressMaxOutputSize(const hipCompressPlan* plan, size_t* size)
}
plan->last_error = HIP_COMPRESS_SUCCESS;
int hdr_size = hipCompressHeaderSize(plan->num_blocks, 1);
*size = (size_t)hdr_size + (size_t)plan->num_blocks * 4 * WRLE_LDS_BYTES;
*size = (size_t)hdr_size + (size_t)plan->num_blocks * plan->scratch_slot_stride;
return hipSuccess;
}
5 changes: 4 additions & 1 deletion hip/hipCompress.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ struct hipCompressPlan {
hipCompressKernel kernel;
int nx, ny, nz;
int num_blocks;
bool is_2d; // true when nz == 1
size_t scratch_slot_stride; // per-block scratch slot size

unsigned char* d_scratch;
float* d_mulfac;
Expand Down Expand Up @@ -89,7 +91,8 @@ hipError_t hipCompressDestroyPlan(hipCompressPlan* plan);
// Zero-fills the padding band (ex..wnx-1, etc.).
// RMS is computed over the ex*ey*ez extraction window only.
//
// ex,ey,ez >= 32. Wavelet dims are derived internally: wnx = round32(ex), etc.
// 3D: ex,ey,ez >= 32. Wavelet dims: wnx = round32(ex), etc.
// 2D (nz=1): ex,ey >= 32, ez must be 1, wnz = 1.
// The derived wavelet dims (wnx, wny, wnz) must equal the plan dimensions (nx, ny, nz) exactly.
//
// d_dst may be NULL (RMS-only mode). d_rms_out may be NULL (copy-only mode).
Expand Down
Loading