Dockerfile.base (2 changes: 0 additions & 2 deletions)
@@ -67,9 +67,7 @@ RUN curl -L -H "Accept: application/octet-stream" https://api.github.com/repos/P

 # Store RISC-V LLVM for TorchSim
 ENV TORCHSIM_LLVM_PATH=/riscv-llvm/bin
-ENV TORCHSIM_LLVM_INCLUDE_PATH=/riscv-llvm/include
 ENV TORCHSIM_DIR=/workspace/PyTorchSim
-ENV LLVM_DIR=/riscv-llvm
 
 # Download Spike simulator
 RUN curl -L -H "Accept: application/octet-stream" https://api.github.com/repos/PSAL-POSTECH/riscv-isa-sim/releases/assets/${SPIKE_ASSET_ID} -o /tmp/spike-release.tar.gz && \
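A minimal sketch of how a build step might consume the variables this image still sets (the lookup code and its fallbacks are assumptions for illustration, not the repo's logic):

import os

# Hypothetical lookup of the RISC-V LLVM toolchain; the defaults mirror
# the values the Dockerfile sets above.
llvm_bin = os.environ.get("TORCHSIM_LLVM_PATH", "/riscv-llvm/bin")
torchsim_dir = os.environ.get("TORCHSIM_DIR", "/workspace/PyTorchSim")
clang = os.path.join(llvm_bin, "clang")
print(f"kernels compiled with {clang}, sources under {torchsim_dir}")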
PyTorchSimFrontend/mlir/mlir_codegen_backend.py (10 changes: 5 additions & 5 deletions)
@@ -438,12 +438,12 @@ def store(self, name: str, index: sympy.Expr, value, mode=None, *args, **kwargs)

         # Handle scatter store
         if "tmp" in str(index):
-            if mode == "atomic_add":
-                # Convert the output buffer type to the inplace buffer
-                arg_name = V.graph.scheduler.mutation_real_name.get(name, name)
-                if arg_name not in self.kernel_group.args.inplace_buffers:
-                    self.kernel_group.args.make_inplace(arg_name, arg_name)
+            # Convert the output buffer type to the inplace buffer
+            arg_name = V.graph.scheduler.mutation_real_name.get(name, name)
+            if arg_name not in self.kernel_group.args.inplace_buffers:
+                self.kernel_group.args.make_inplace(arg_name, arg_name)
 
+            if mode == "atomic_add":
                 loaded_value = ops.load(name, index)
                 value = ops.add(loaded_value, value)
             index, _ = self.convert_indirect_indexing(index)
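For orientation, an eager-mode reference for the two store modes this hunk separates; the function below is a hypothetical illustration of the semantics the generated kernel must match, not the backend API:

import torch

# Both modes mutate buf in place, which is why the inplace-buffer
# conversion above now runs for every scatter store, not only atomic_add.
def scatter_store_reference(buf, idx, value, mode=None):
    if mode == "atomic_add":
        # Read-modify-write, mirroring ops.load + ops.add in the hunk;
        # eager PyTorch would use index_put_(..., accumulate=True) to
        # also handle duplicate indices.
        buf[idx] = buf[idx] + value
    else:
        buf[idx] = value
    return buf

buf = torch.zeros(8)
scatter_store_reference(buf, torch.tensor([1, 3]), torch.tensor([2.0, 5.0]),
                        mode="atomic_add")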
PyTorchSimFrontend/mlir/mlir_common.py (20 changes: 10 additions & 10 deletions)
@@ -332,8 +332,8 @@ def _adjust_one(dim_size, tile_size):
         remain = candidate_tile_size[axis] % stride
 
         if remain:
-            candidate_tile_size[axis] += stride - remain
-            self.tile_constraint[axis].must_divide_dim = False
+            # #201: relax vlane_stride constraints
+            self.vmap.vlane_stride = 1
         return candidate_tile_size
 
     def scale_tile_dim(self, axis, dim_sz, scale_factor=2):
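The deleted lines padded the tile so the vector-lane stride divided it exactly; the replacement (#201) keeps the tile and relaxes the stride to 1 instead. The trade-off in isolation (a sketch; these free functions stand in for the method's locals):

# Old strategy: round the candidate tile up to the next multiple of the
# stride, at the cost of padding.
def round_up_to_stride(tile, stride):
    remain = tile % stride
    return tile + (stride - remain) if remain else tile

# New strategy (#201): keep the tile and fall back to vlane_stride = 1,
# which divides any tile size.
def relax_stride(tile, stride):
    return tile, 1

assert round_up_to_stride(6, 4) == 8   # 6 padded up to 8
assert relax_stride(6, 4) == (6, 1)    # 6 kept, stride relaxed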
@@ -488,7 +488,7 @@ def __init__(self, tile_size, vector_lane, vlane_split_axis=None, vlane_stride=N
         self.name = ""
         self._tile_size = list(tile_size)
         self._tile_stride = None
-        self.tile_constraint = [TileConstraint(vlane_stride) for _ in tile_size]
+        self.tile_constraint = [TileConstraint(vlane_stride if idx == vlane_split_axis else 1) for idx, _ in enumerate(tile_size)]
         self.tile_axis_order = list(range(len(tile_size)))
         self.update_tile_stride()
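The one-line constructor fix is easiest to see with toy values: previously every axis inherited the vector-lane stride constraint, now only the split axis does. A sketch with a stand-in TileConstraint, reduced to the single field at issue:

class TileConstraint:
    def __init__(self, vlane_stride):
        self.vlane_stride = vlane_stride

tile_size, vlane_split_axis, vlane_stride = [64, 128, 32], 1, 8

# Before: the stride constraint was applied to every axis.
old = [TileConstraint(vlane_stride) for _ in tile_size]
# After: only the axis split across vector lanes keeps it.
new = [TileConstraint(vlane_stride if idx == vlane_split_axis else 1)
       for idx, _ in enumerate(tile_size)]

assert [c.vlane_stride for c in old] == [8, 8, 8]
assert [c.vlane_stride for c in new] == [1, 8, 1]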

@@ -718,13 +718,13 @@ def compute_tile_size(self, nodes, vars, reduction_vars):
         init_tile_desc.nr_rdim = len(reduction_vars)
         self.kernel_group.set_tile_info(init_tile_desc)
 
-        # Handle edge case
-        if len(self.ranges)==1 and self.ranges[0] == 1: # Scalar case 2
-            self.kernel_group.tile_desc.vmap.vlane_stride = 1
-            self.kernel_group.tile_desc.vmap.vlane_split_axis = 0
-        elif vlane_split_axis == -1: # Reduction only case
-            self.kernel_group.tile_desc.vmap.vlane_split_axis = 0
-            self.kernel_group.tile_desc.vmap.vlane_stride = self.kernel_group.tile_desc.get_tile_size()[0]
+        # Handle edge case
+        if len(self.ranges)==1 and self.ranges[0] == 1: # Scalar case 2
+            self.kernel_group.tile_desc.vmap.vlane_stride = 1
+            self.kernel_group.tile_desc.vmap.vlane_split_axis = 0
+        elif vlane_split_axis == -1: # Reduction only case
+            self.kernel_group.tile_desc.vmap.vlane_split_axis = 0
+            self.kernel_group.tile_desc.vmap.vlane_stride = self.kernel_group.tile_desc.get_tile_size()[0]
 
         # Handle implict dims. Input operand could be high dimension tensor.
         # Note: https://github.com/PSAL-POSTECH/PyTorchSim/issues/173
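Restated standalone, the two edge cases the block handles (assumed semantics, for orientation only):

# Returns the (vlane_split_axis, vlane_stride) the edge cases pick.
def vlane_edge_case(ranges, vlane_split_axis, tile_size):
    if len(ranges) == 1 and ranges[0] == 1:
        # "Scalar case 2": a one-element kernel has nothing to split.
        return 0, 1
    if vlane_split_axis == -1:
        # Reduction-only kernel: split axis 0, striding by its tile extent.
        return 0, tile_size[0]
    return vlane_split_axis, None  # normal path: stride left untouched

assert vlane_edge_case([1], 0, [16]) == (0, 1)
assert vlane_edge_case([128], -1, [16]) == (0, 16)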
tests/test_indirect_access.py (6 changes: 4 additions & 2 deletions)
@@ -70,11 +70,12 @@ def vectoradd(a, idx, b):
         a[idx, :] = b
         return a
     x = torch.randn(size, dtype=torch.float32).to(device=device)
+    x_cpu = x.clone().cpu()
     idx = torch.randint(0,128, [128]).to(device=device)
-    y = torch.randn(128, dtype=torch.float32).to(device=device)
+    y = torch.randn(size[1], dtype=torch.float32).to(device=device)
     opt_fn = torch.compile(dynamic=False)(vectoradd)
     res = opt_fn(x, idx, y)
-    out = vectoradd(x.cpu(), idx.cpu(), y.cpu())
+    out = vectoradd(x_cpu, idx.cpu(), y.cpu())
     test_result("Indirect VectorAdd", res, out)
 
 if __name__ == "__main__":
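The new x_cpu = x.clone().cpu() line matters because vectoradd mutates its first argument, and the compiled call runs before the eager reference is built; computing the reference from x afterwards would compare mutated data against mutated data. The ordering hazard, shown with plain CPU tensors:

import torch

def vectoradd(a, idx, b):
    a[idx, :] = b
    return a

x = torch.zeros(2, 3)
x_cpu = x.clone()                 # snapshot taken BEFORE the mutating call
vectoradd(x, torch.tensor([0]), torch.ones(3))   # stands in for opt_fn(...)
assert x[0, 0] == 1.0 and x_cpu[0, 0] == 0.0     # reference data preserved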
@@ -86,6 +87,7 @@ def vectoradd(a, idx, b):
     module = PyTorchSimRunner.setup_device()
     device = module.custom_device()
     test_scatter_full(device)
+    test_scatter_full(device, size=(2048, 2048))
     test_scatter_add(device)
     test_indirect_vectoradd(device)
     #test_embedding(device, 1024, 2048)
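The added test_scatter_full call covers a larger (2048, 2048) case, and in the indirect vector-add test torch.randn(size[1], ...) replaces a hard-coded 128 so y tracks the row width for non-default sizes. test_result is the suite's comparison helper; a plausible stand-in for local reading (an assumption, not the repo's implementation):

import torch

def test_result(name, res, out, rtol=1e-4, atol=1e-5):
    # Compare the compiled device result against the eager CPU reference.
    ok = torch.allclose(res.cpu(), out, rtol=rtol, atol=atol)
    print(f"{name}: {'PASS' if ok else 'FAIL'}")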
tutorial/jupyterhub/Dockerfile.ksc2025 (2 changes: 0 additions & 2 deletions)
@@ -52,9 +52,7 @@ RUN cd llvm-project && mkdir build && cd build && \

 # Store RISC-V LLVM for TorchSim
 ENV TORCHSIM_LLVM_PATH=/riscv-llvm/bin
-ENV TORCHSIM_LLVM_INCLUDE_PATH=/riscv-llvm/include
 ENV TORCHSIM_DIR=/workspace/PyTorchSim
-ENV LLVM_DIR=/riscv-llvm
 
 # Download RISC-V tool chain
 RUN apt install -y wget && \