diff --git a/intermediate_source/torch_compile_conv_bn_fuser.py b/intermediate_source/torch_compile_conv_bn_fuser.py index e057d145..e4978a10 100644 --- a/intermediate_source/torch_compile_conv_bn_fuser.py +++ b/intermediate_source/torch_compile_conv_bn_fuser.py @@ -1,32 +1,31 @@ # -*- coding: utf-8 -*- """ -Building a Convolution/Batch Norm fuser with torch.compile +torch.compile 기반 합성곱·배치 정규화 퓨저(Convolution/Batch Norm fuser) 만들기 =========================================================== -**Author:** `Horace He `_, `Will Feng `_ +**저자:** `Horace He `_, `Will Feng `_ 번역: `심기택 `_ .. grid:: 2 - .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn + .. grid-item-card:: :octicon:`mortar-board;1em;` 배울 내용 :class-card: card-prerequisites - * How to register custom fusion patterns with torch.compile's pattern matcher + * torch.compile의 패턴 매처에 커스텀 퓨전 패턴을 등록하는 방법 - .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites + .. grid-item-card:: :octicon:`list-unordered;1em;` 전제 조건 :class-card: card-prerequisites * PyTorch v2.7.0 .. note:: - This optimization only works for models in inference mode (i.e. ``model.eval()``). - However, torch.compile's pattern matching system works for both training and inference. + 이 최적화는 추론 모드의 모델에만 적용됩니다 (예: ``model.eval()``). + 하지만 torch.compile의 패턴 매칭 시스템은 학습과 추론 모두에서 동작합니다. """ ###################################################################### -# First, let's get some imports out of the way (we will be using all -# of these later in the code). +# 먼저 이후 코드에서 사용할 모듈들을 import 하겠습니다. from typing import Type, Dict, Any, Tuple, Iterable import copy @@ -36,10 +35,10 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ###################################################################### -# For this tutorial, we are going to create a model consisting of convolutions -# and batch norms. Note that this model has some tricky components - some of -# the conv/batch norm patterns are hidden within Sequentials and one of the -# ``BatchNorms`` is wrapped in another Module. +# 이번 튜토리얼에서는 합성곱과 배치 정규화로 구성된 모델을 만들어 보겠습니다. +# 이 모델에는 몇 가지 까다로운 요소가 있다는 점에 유의하세요. +# 일부 합성곱·배치 정규화 패턴은 Sequential 내부에 숨겨져 있으며, 배치 정규화 중 하나는 또 다른 +# Module로 감싸져 있습니다. class WrappedBatchNorm(nn.Module): def __init__(self): @@ -72,42 +71,37 @@ def forward(self, x): model.eval() ###################################################################### -# Fusing Convolution with Batch Norm +# 합성곱과 배치 정규화 퓨전하기 # ----------------------------------------- -# One of the primary challenges with trying to automatically fuse convolution -# and batch norm in PyTorch is that PyTorch does not provide an easy way of -# accessing the computational graph. torch.compile resolves this problem by -# capturing the computational graph during compilation, allowing us to apply -# pattern-based optimizations across the entire model, including operations -# nested within Sequential modules or wrapped in custom modules. +# 합성곱과 배치 정규화를 자동으로 퓨전하려 할 때의 주요 어려움 중 하나는 PyTorch가 계산 +# 그래프(computational graph)에 쉽게 접근할 수 있는 방법을 제공하지 않는다는 점입니다. +# torch.compile은 컴파일 과정에서 계산 그래프를 확보함으로써 이 문제를 해결하며, +# 이를 통해 Sequential 모듈 내부에 있는 중첩된 연산이나 사용자 정의 모듈로 감싸진 연산을 포함한 +# 모델 전체에 패턴 기반 최적화를 적용할 수 있습니다. import torch._inductor.pattern_matcher as pm from torch._inductor.pattern_matcher import register_replacement ###################################################################### -# torch.compile will capture a graph representation of our model. During -# compilation, modules hidden within Sequential containers and wrapped -# modules are all inlined into the graph, making them available for -# pattern matching and optimization. +# torch.compile은 모델의 계산 그래프를 확보합니다. +# 컴파일 과정에서 Sequential 컨테이너에 숨겨진 모듈과 다른 모듈로 감싸진 모듈들은 모두 그래프에 +# 직접 포함되어 패턴 매칭과 최적화의 대상이 됩니다. -#################################### -# Fusing Convolution with Batch Norm +###################################################################### +# 합성곱과 배치 정규화 퓨전하기 # ---------------------------------- -# Unlike some other fusions, fusion of convolution with batch norm does not -# require any new operators. Instead, as batch norm during inference -# consists of a pointwise add and multiply, these operations can be "baked" -# into the preceding convolution's weights. This allows us to remove the batch -# norm entirely from our model! Read -# https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ for further details. The -# code here is copied from -# https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/utils/fusion.py -# clarity purposes. +# 다른 일부 퓨전과 달리, 합성곱과 배치 정규화의 퓨전에는 새로운 연산자가 필요하지 않습니다. +# 추론 과정에서 배치 정규화는 요소별 덧셈과 곱셈으로 이루어지므로 이러한 연산들을 앞선 합성곱의 가중치에 +# 반영할 수 있습니다. 이를 통해 모델에서 배치 정규화를 완전히 제거할 수 있습니다! +# 자세한 내용은 이 글을 참고하세요. +# https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ +# 여기서 사용한 코드는 설명의 명확성을 위해 다음의 구현을 가져온 것입니다. https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/utils/fusion.py def fuse_conv_bn_eval(conv, bn): """ - Given a conv Module `A` and an batch_norm module `B`, returns a conv - module `C` such that C(x) == B(A(x)) in inference mode. + 합성곱 모듈 A와 배치 정규화 모듈 B가 주어졌을 때, 추론 모드에서 C(x) == B(A(x))를 만족하는 + 합성곱 모듈 C를 반환합니다. """ - assert(not (conv.training or bn.training)), "Fusion only for eval!" + assert(not (conv.training or bn.training)), "추론 모드에서만 퓨전합니다!" fused_conv = copy.deepcopy(conv) fused_conv.weight, fused_conv.bias = \ @@ -131,14 +125,13 @@ def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b): return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b) -#################################### -# Pattern Matching with torch.compile +###################################################################### +# torch.compile 기반 패턴 매칭 # ------------------------------------ -# Now that we have our fusion logic, we need to register a pattern that -# torch.compile's pattern matcher will recognize and replace during -# compilation. +# 이제 퓨전 로직을 구현했으므로 컴파일 과정에서 torch.compile의 패턴 매처가 인식하고 치환할 수 있는 +# 패턴을 등록해야 합니다. -# Define the pattern we want to match: conv2d followed by batch_norm +# 매칭하려는 패턴을 정의합니다: conv2d 다음에 batch_norm이 오는 패턴입니다. def conv_bn_pattern(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias): conv_out = torch.nn.functional.conv2d(x, conv_weight, conv_bias) bn_out = torch.nn.functional.batch_norm( @@ -153,30 +146,29 @@ def conv_bn_replacement(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, b ) return torch.nn.functional.conv2d(x, fused_weight, fused_bias) -# Example inputs are needed to trace the pattern functions. -# The inputs should match the function signatures of conv_bn_pattern and conv_bn_replacement. -# These are used to trace the pattern functions to create the match template. -# IMPORTANT: The pattern matcher is shape-agnostic! The specific shapes you use here -# don't limit what shapes will be matched - any valid conv2d->batch_norm sequence -# will be matched regardless of channels, kernel size, or spatial dimensions. -# - x: input tensor (batch_size, channels, height, width) -# - conv_weight: (out_channels, in_channels, kernel_h, kernel_w) -# - conv_bias: (out_channels,) -# - bn_mean, bn_var, bn_weight, bn_bias: all have shape (num_features,) matching out_channels +# 패턴 함수들을 추적하려면 예시 입력이 필요합니다. +# 이 입력들은 conv_bn_pattern 및 conv_bn_replacement의 함수 시그니처와 일치해야 합니다. +# 이들은 패턴 함수를 추적하여 매치 템플릿을 만드는 데 사용됩니다. +# 중요: 패턴 매처는 형태에 구애받지 않습니다! 여기서 사용하는 특정 형태가 매칭될 형태를 제한하지 않습니다. +# 채널, 커널 크기, 공간 차원에 관계없이 유효한 conv2d -> batch_norm 시퀀스라면 모두 매칭됩니다. +# - x: 입력 텐서 (배치 크기, 채널, 높이, 너비) +# - conv_weight: (출력 채널, 입력 채널, 커널 높이, 커널 너비) +# - conv_bias: (출력 채널,) +# - bn_mean, bn_var, bn_weight, bn_bias: 모두 출력 채널과 일치하는 형태(num_features,)를 가짐 example_inputs = [ - torch.randn(1, 1, 4, 4).to(device), # x: input tensor - torch.randn(1, 1, 1, 1).to(device), # conv_weight: 1 output channel, 1 input channel, 1x1 kernel - torch.randn(1).to(device), # conv_bias: 1 output channel - torch.randn(1).to(device), # bn_mean: batch norm running mean - torch.randn(1).to(device), # bn_var: batch norm running variance - torch.randn(1).to(device), # bn_weight: batch norm weight (gamma) - torch.randn(1).to(device), # bn_bias: batch norm bias (beta) + torch.randn(1, 1, 4, 4).to(device), # x: 입력 텐서 + torch.randn(1, 1, 1, 1).to(device), # conv_weight: 출력 채널 1, 입력 채널 1, 1x1 커널 + torch.randn(1).to(device), # conv_bias: 출력 채널 1 + torch.randn(1).to(device), # bn_mean: 배치 정규화 이동 평균 + torch.randn(1).to(device), # bn_var: 배치 정규화 이동 분산 + torch.randn(1).to(device), # bn_weight: 배치 정규화 가중치 (감마) + torch.randn(1).to(device), # bn_bias: 배치 정규화 편향 (베타) ] from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor import config -# Create a pattern matcher pass and register our pattern +# 패턴 매처 패스를 생성하고 패턴을 등록합니다. patterns = PatternMatcherPass() register_replacement( @@ -187,48 +179,47 @@ def conv_bn_replacement(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, b patterns, ) -# Create a custom pass function that applies our patterns +# 등록된 패턴을 적용하는 커스텀 패스 함수를 생성합니다. def conv_bn_fusion_pass(graph): return patterns.apply(graph) -# Set our custom pass in the config +# 설정에 커스텀 패스를 지정합니다. config.post_grad_custom_post_pass = conv_bn_fusion_pass ###################################################################### -# .. note:: -# We make some simplifications here for demonstration purposes, such as only -# matching 2D convolutions. The pattern matcher in torch.compile -# can handle more complex patterns. +# .. 참고:: +# 설명을 돕기 위해 2D 합성곱 연산만 매칭하는 등 일부 단순화를 적용하였습니다. +# torch.compile의 패턴 매처는 이보다 훨씬 더 복잡한 패턴도 처리할 수 있습니다. ###################################################################### -# Testing out our Fusion Pass +# 퓨전 패스 테스트하기 # ----------------------------------------- -# We can now run this fusion pass on our initial toy model and verify that our -# results are identical. In addition, we can print out the code for our fused -# model and verify that there are no more batch norms. +# 앞서 만든 토이 모델에 이 퓨전 패스를 실행하여 결과가 기존과 완벽히 동일한지 확인할 수 있습니다. +# 또한, 퓨전이 완료된 모델의 코드를 직접 출력해 봄으로써 배치 정규화 연산이 정말로 모두 제거되었는지 +# 검증할 수 있습니다. from torch._dynamo.utils import counters -# Clear the counters before compilation +# 컴파일하기 전에 카운터를 초기화합니다. counters.clear() -# Ensure pattern matcher is enabled +# 패턴 매처가 활성화되어 있는지 확인합니다. config.pattern_matcher = True fused_model = torch.compile(model, backend="inductor") inp = torch.randn(5, 1, 1, 1).to(device) -# Run the model to trigger compilation and pattern matching +# 모델을 실행하여 컴파일과 패턴 매칭 과정을 동작시킵니다. with torch.no_grad(): output = fused_model(inp) expected = model(inp) torch.testing.assert_close(output, expected) -# Check how many patterns were matched -assert counters['inductor']['pattern_matcher_count'] == 3, "Expected 3 conv-bn patterns to be matched" +# 몇 개의 패턴이 매칭되었는지 확인합니다. +assert counters['inductor']['pattern_matcher_count'] == 3, "3개의 conv-bn 패턴이 매칭될 것으로 예상됩니다." -# Create a model with different shapes than our example_inputs +# 앞선 예시 입력과는 다른 형태를 가진 모델을 만듭니다. test_model_diff_shape = nn.Sequential( nn.Conv2d(3, 16, 5), nn.BatchNorm2d(16), @@ -243,15 +234,15 @@ def conv_bn_fusion_pass(graph): with torch.no_grad(): compiled_diff_shape(test_input_diff_shape) -# Check how many patterns were matched -assert counters['inductor']['pattern_matcher_count'] == 2, "Expected 2 conv-bn patterns to be matched" +# 몇 개의 패턴이 매칭되었는지 확인합니다. +assert counters['inductor']['pattern_matcher_count'] == 2, "2개의 conv-bn 패턴이 매칭될 것으로 예상됩니다." ###################################################################### -# Benchmarking our Fusion on ResNet18 +# ResNet18 모델을 사용한 퓨전 성능 측정 # ----------------------------------- -# We can test our fusion pass on a larger model like ResNet18 and see how much -# this pass improves inference performance. +# ResNet18과 같은 더 큰 모델에 우리의 퓨전 최적화 단계를 테스트하여 +# 이 단계가 추론 성능을 얼마나 향상시키는지 확인할 수 있습니다. import torchvision.models as models import time @@ -270,23 +261,22 @@ def benchmark(model, iters=20): model(inp) return str(time.time()-begin) -# Benchmark original model +# 원본 모델의 성능을 측정합니다. print("Original model time: ", benchmark(rn18)) -# Compile with our custom pattern +# 우리의 커스텀 패턴을 적용하여 컴파일합니다. compiled_with_pattern_matching = torch.compile(rn18, backend="inductor") -# Benchmark compiled model +# 컴파일된 모델의 성능을 측정합니다. print("\ntorch.compile (with conv-bn pattern matching and other fusions): ", benchmark(compiled_with_pattern_matching)) -############ -# Conclusion +###################################################################### +# 결론 # ---------- -# As we can see, torch.compile provides a powerful way to implement -# graph transformations and optimizations through pattern matching. -# By registering custom patterns, we can extend torch.compile's -# optimization capabilities to handle domain-specific transformations. +# 보시다시피 torch.compile은 패턴 매칭을 통해 그래프 변환 및 최적화를 구현하는 매우 강력한 방법을 +# 제공합니다. 커스텀 패턴을 등록함으로써 torch.compile의 최적화 기능을 더욱 확장하여 특정 도메인에 +# 특화된 변환까지 처리할 수 있습니다. # -# The conv-bn fusion demonstrated here is just one example of what's -# possible with torch.compile's pattern matching system. \ No newline at end of file +# 여기서 보여드린 conv-bn 퓨전은 torch.compile의 패턴 매칭 시스템으로 할 수 있는 +# 무궁무진한 일들 중 하나의 예시일 뿐입니다.