diff --git a/Dockerfile.base b/Dockerfile.base
index 6a21760b..f961859e 100644
--- a/Dockerfile.base
+++ b/Dockerfile.base
@@ -67,9 +67,7 @@ RUN curl -L -H "Accept: application/octet-stream" https://api.github.com/repos/P
 
 # Store RISC-V LLVM for TorchSim
 ENV TORCHSIM_LLVM_PATH=/riscv-llvm/bin
-ENV TORCHSIM_LLVM_INCLUDE_PATH=/riscv-llvm/include
 ENV TORCHSIM_DIR=/workspace/PyTorchSim
-ENV LLVM_DIR=/riscv-llvm
 
 # Download Spike simulator
 RUN curl -L -H "Accept: application/octet-stream" https://api.github.com/repos/PSAL-POSTECH/riscv-isa-sim/releases/assets/${SPIKE_ASSET_ID} -o /tmp/spike-release.tar.gz && \
diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py
index 266d884b..297ea162 100644
--- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py
+++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py
@@ -438,12 +438,12 @@ def store(self, name: str, index: sympy.Expr, value, mode=None, *args, **kwargs)
 
     # Handle scatter store
     if "tmp" in str(index):
-        if mode == "atomic_add":
-            # Convert the output buffer type to the inplace buffer
-            arg_name = V.graph.scheduler.mutation_real_name.get(name, name)
-            if arg_name not in self.kernel_group.args.inplace_buffers:
-                self.kernel_group.args.make_inplace(arg_name, arg_name)
+        # Convert the output buffer type to the inplace buffer
+        arg_name = V.graph.scheduler.mutation_real_name.get(name, name)
+        if arg_name not in self.kernel_group.args.inplace_buffers:
+            self.kernel_group.args.make_inplace(arg_name, arg_name)
+        if mode == "atomic_add":
             loaded_value = ops.load(name, index)
             value = ops.add(loaded_value, value)
         index, _ = self.convert_indirect_indexing(index)
 
diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py
index 15408c0d..b86607ea 100644
--- a/PyTorchSimFrontend/mlir/mlir_common.py
+++ b/PyTorchSimFrontend/mlir/mlir_common.py
@@ -332,8 +332,8 @@ def _adjust_one(dim_size, tile_size):
 
         remain = candidate_tile_size[axis] % stride
         if remain:
-            candidate_tile_size[axis] += stride - remain
-            self.tile_constraint[axis].must_divide_dim = False
+            # #201: relax vlane_stride constraints
+            self.vmap.vlane_stride = 1
         return candidate_tile_size
 
     def scale_tile_dim(self, axis, dim_sz, scale_factor=2):
@@ -488,7 +488,7 @@ def __init__(self, tile_size, vector_lane, vlane_split_axis=None, vlane_stride=N
         self.name = ""
         self._tile_size = list(tile_size)
         self._tile_stride = None
-        self.tile_constraint = [TileConstraint(vlane_stride) for _ in tile_size]
+        self.tile_constraint = [TileConstraint(vlane_stride if idx == vlane_split_axis else 1) for idx, _ in enumerate(tile_size)]
         self.tile_axis_order = list(range(len(tile_size)))
         self.update_tile_stride()
 
@@ -718,13 +718,13 @@ def compute_tile_size(self, nodes, vars, reduction_vars):
         init_tile_desc.nr_rdim = len(reduction_vars)
         self.kernel_group.set_tile_info(init_tile_desc)
 
-        # Handle edge case
-        if len(self.ranges)==1 and self.ranges[0] == 1: # Scalar case 2
-            self.kernel_group.tile_desc.vmap.vlane_stride = 1
-            self.kernel_group.tile_desc.vmap.vlane_split_axis = 0
-        elif vlane_split_axis == -1: # Reduction only case
-            self.kernel_group.tile_desc.vmap.vlane_split_axis = 0
-            self.kernel_group.tile_desc.vmap.vlane_stride = self.kernel_group.tile_desc.get_tile_size()[0]
+        # Handle edge case
+        if len(self.ranges)==1 and self.ranges[0] == 1: # Scalar case 2
+            self.kernel_group.tile_desc.vmap.vlane_stride = 1
+            self.kernel_group.tile_desc.vmap.vlane_split_axis = 0
+        elif vlane_split_axis == -1: # Reduction only case
+            self.kernel_group.tile_desc.vmap.vlane_split_axis = 0
+            self.kernel_group.tile_desc.vmap.vlane_stride = self.kernel_group.tile_desc.get_tile_size()[0]
 
         # Handle implict dims. Input operand could be high dimension tensor.
         # Note: https://github.com/PSAL-POSTECH/PyTorchSim/issues/173
diff --git a/tests/test_indirect_access.py b/tests/test_indirect_access.py
index 6cfa7b58..d103ee1b 100644
--- a/tests/test_indirect_access.py
+++ b/tests/test_indirect_access.py
@@ -70,11 +70,12 @@ def vectoradd(a, idx, b):
         a[idx, :] = b
         return a
 
     x = torch.randn(size, dtype=torch.float32).to(device=device)
+    x_cpu = x.clone().cpu()
     idx = torch.randint(0,128, [128]).to(device=device)
-    y = torch.randn(128, dtype=torch.float32).to(device=device)
+    y = torch.randn(size[1], dtype=torch.float32).to(device=device)
     opt_fn = torch.compile(dynamic=False)(vectoradd)
     res = opt_fn(x, idx, y)
-    out = vectoradd(x.cpu(), idx.cpu(), y.cpu())
+    out = vectoradd(x_cpu, idx.cpu(), y.cpu())
     test_result("Indirect VectorAdd", res, out)
 
@@ -86,6 +87,7 @@ def vectoradd(a, idx, b):
     module = PyTorchSimRunner.setup_device()
     device = module.custom_device()
     test_scatter_full(device)
+    test_scatter_full(device, size=(2048, 2048))
     test_scatter_add(device)
     test_indirect_vectoradd(device)
     #test_embedding(device, 1024, 2048)
\ No newline at end of file
diff --git a/tutorial/jupyterhub/Dockerfile.ksc2025 b/tutorial/jupyterhub/Dockerfile.ksc2025
index 9eaec15a..7633c048 100644
--- a/tutorial/jupyterhub/Dockerfile.ksc2025
+++ b/tutorial/jupyterhub/Dockerfile.ksc2025
@@ -52,9 +52,7 @@ RUN cd llvm-project && mkdir build && cd build && \
 
 # Store RISC-V LLVM for TorchSim
 ENV TORCHSIM_LLVM_PATH=/riscv-llvm/bin
-ENV TORCHSIM_LLVM_INCLUDE_PATH=/riscv-llvm/include
 ENV TORCHSIM_DIR=/workspace/PyTorchSim
-ENV LLVM_DIR=/riscv-llvm
 
 # Download RISC-V tool chain
 RUN apt install -y wget && \