Skip to content

Optimize reshape idx lmbda lowering#656

Merged
kaushikcfd merged 3 commits intomainfrom
optimize_reshape_idx_lmbda_lowering
Apr 14, 2026
Merged

Optimize reshape idx lmbda lowering#656
kaushikcfd merged 3 commits intomainfrom
optimize_reshape_idx_lmbda_lowering

Conversation

@kaushikcfd
Copy link
Copy Markdown
Collaborator

On main, reshape lowering generates sub-optimal indexing:

(py314_env) kaushikggg@line:~/temp$ cat hello.py 
import pytato as pt


x = pt.make_placeholder("x", (10, 4))
print(pt.generate_loopy(x.reshape(-1)).program)
(py314_env) kaushikggg@line:~/temp$ python hello.py 
---------------------------------------------------------------------------
KERNEL: _pt_kernel
---------------------------------------------------------------------------
ARGUMENTS:
x: type: np:dtype('float64'), shape: (10, 4), dim_tags: (N1:stride:4, N0:stride:1), offset: <class 'loopy.typing.auto'> in aspace: global
_pt_out: type: np:dtype('float64'), shape: (40), dim_tags: (N0:stride:1) out aspace: global
---------------------------------------------------------------------------
DOMAINS:
{  :  }
{ [_pt_out_dim0] : 0 <= _pt_out_dim0 <= 39 }
---------------------------------------------------------------------------
INAME TAGS:
_pt_out_dim0: None
---------------------------------------------------------------------------
INSTRUCTIONS:
for _pt_out_dim0
  _pt_out[_pt_out_dim0] = x[((0 + _pt_out_dim0*1) % 40) // 4, ((0 + _pt_out_dim0*1) % 4) // 1]  {id=_pt_out_store}
end _pt_out_dim0
---------------------------------------------------------------------------

After this patch, the assignment becomes:

for _pt_out_dim0
  _pt_out[_pt_out_dim0] = x[_pt_out_dim0 // 4, _pt_out_dim0 % 4]  {id=_pt_out_store}
end _pt_out_dim0

Notice the reduction in the redundant product with one and modulo with the size operations.

@kaushikcfd kaushikcfd force-pushed the optimize_reshape_idx_lmbda_lowering branch from d318c3e to c8e06ab Compare March 28, 2026 20:32
@kaushikcfd kaushikcfd requested a review from inducer March 28, 2026 20:33
@kaushikcfd kaushikcfd force-pushed the optimize_reshape_idx_lmbda_lowering branch from c8e06ab to 9ab995b Compare March 28, 2026 23:02
@kaushikcfd
Copy link
Copy Markdown
Collaborator Author

Soft ping.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Optimizes reshape-to-IndexLambda lowering to avoid generating redundant arithmetic (e.g., * 1, % total_size, // 1) in the computed index expressions, and adds a regression test to lock in the improved form.

Changes:

  • Simplify generated flattened-index arithmetic to avoid multiplying by 1 when forming the linearized index.
  • Drop redundant modulo-by-total-size and floor-div-by-1 in reshape index reconstruction where safe.
  • Add a regression test ensuring reshape(-1) lowers to (_0 // 4, _0 % 4) for a (10, 4) input.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.

File Description
pytato/transform/lower_to_index_lambda.py Adds simplification logic in _generate_index_expressions to reduce redundant ops in reshape lowering.
test/test_pytato.py Adds a regression test asserting the simplified IndexLambda index tuple for flatten-reshape.
.basedpyright/baseline.json Removes baseline entries corresponding to pyright issues addressed by the updated lowering code.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread test/test_pytato.py
Comment thread pytato/transform/lower_to_index_lambda.py
Comment thread pytato/transform/lower_to_index_lambda.py
Comment thread pytato/transform/lower_to_index_lambda.py Outdated
Comment thread pytato/transform/lower_to_index_lambda.py
@inducer inducer force-pushed the optimize_reshape_idx_lmbda_lowering branch from b32942d to 20852cc Compare April 14, 2026 22:10
@inducer
Copy link
Copy Markdown
Owner

inducer commented Apr 14, 2026

Thanks! Pushed a counterproposal that (I think) is easier on the eyes even if it's a tad bit more expensive. LMK if that works for you.

Copy link
Copy Markdown
Collaborator Author

@kaushikcfd kaushikcfd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@kaushikcfd kaushikcfd merged commit ee4a8a4 into main Apr 14, 2026
10 checks passed
@kaushikcfd kaushikcfd deleted the optimize_reshape_idx_lmbda_lowering branch April 14, 2026 22:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants