diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp
index 1c968e276d..110b35bade 100644
--- a/transformer_engine/pytorch/csrc/quantizer.cpp
+++ b/transformer_engine/pytorch/csrc/quantizer.cpp
@@ -155,9 +155,10 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
   py::object out_py;
   if (internal) {
     py::handle Float8TensorClass(reinterpret_cast<PyObject *>(Float8TensorStoragePythonClass));
-    out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv,
-                               "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
-                               "quantizer"_a = this->quantizer);
+    out_py =
+        Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv,
+                          "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
+                          "quantizer"_a = this->quantizer, "fake_dtype"_a = GetATenDType(dtype));
   } else {
     py::handle Float8TensorClass(reinterpret_cast<PyObject *>(Float8TensorPythonClass));
     const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
@@ -360,9 +361,10 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
   py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none();
   if (internal) {
     py::handle Float8TensorClass(reinterpret_cast<PyObject *>(Float8TensorStoragePythonClass));
-    out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor,
-                               "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
-                               "quantizer"_a = this->quantizer);
+    out_py =
+        Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor,
+                          "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
+                          "quantizer"_a = this->quantizer, "fake_dtype"_a = GetATenDType(dtype));
   } else {
     py::handle Float8TensorClass(reinterpret_cast<PyObject *>(Float8TensorPythonClass));
     const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
@@ -624,7 +626,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
         "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise,
         "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise,
         "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer,
-        "is_2D_scaled"_a = (block_scaling_dim == 2));
+        "is_2D_scaled"_a = (block_scaling_dim == 2), "fake_dtype"_a = GetATenDType(dtype));
   } else {
     py::handle Float8BlockwiseQTensorClass(
         reinterpret_cast<PyObject *>(Float8BlockwiseQTensorPythonClass));
@@ -911,7 +913,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
     py::handle MXFP8TensorClass(reinterpret_cast<PyObject *>(MXFP8TensorStoragePythonClass));
     out_py = MXFP8TensorClass(rowwise_data_py, rowwise_scale_inv_py, columnwise_data_py,
                               columnwise_scale_inv_py, this->dtype, this->quantizer,
-                              with_gemm_swizzled_scales);
+                              with_gemm_swizzled_scales, "fake_dtype"_a = GetATenDType(dtype));
   } else {
     py::handle MXFP8TensorClass(reinterpret_cast<PyObject *>(MXFP8TensorPythonClass));
     out_py = MXFP8TensorClass(
@@ -1202,7 +1204,8 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
     py::handle NVFP4TensorClass(reinterpret_cast<PyObject *>(NVFP4TensorStoragePythonClass));
     out_py = NVFP4TensorClass(rowwise_data_py, rowwise_scale_inv_py, columnwise_data_py,
                               columnwise_scale_inv_py, amax_rowwise_py, amax_columnwise_py,
-                              this->dtype, this->quantizer, with_gemm_swizzled_scales);
+                              this->dtype, this->quantizer, with_gemm_swizzled_scales,
+                              "fake_dtype"_a = GetATenDType(dtype));
   } else {
     py::handle NVFP4TensorClass(reinterpret_cast<PyObject *>(NVFP4TensorPythonClass));
     out_py = NVFP4TensorClass(
diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index f269e21b8c..f3b6716200 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -1079,7 +1079,7 @@ def _start_all_gather_fp8_blockwise(
             device = inp._columnwise_data.device
         else:
             raise ValueError("Got Float8BlockwiseQTensorStorage input tensor without any data")
-        dtype = torch.bfloat16  # Only has fp8 dtype. Guess BF16 for dequant.
+        dtype = inp._dtype
     else:
         raise ValueError(
             "Invalid type for input tensor (expected torch.Tensor or"
@@ -1317,7 +1317,7 @@ def _all_gather_nvfp4(
         if inp._columnwise_data is not None:
             in_shape_t = inp._columnwise_data.size()
             device = inp._columnwise_data.device
-            dtype = torch.bfloat16
+            dtype = inp._dtype
     else:
         raise ValueError(
             "Invalid type for input tensor (expected torch.Tensor or NVFP4TensorStorage, "
@@ -1486,7 +1486,7 @@ def _all_gather_mxfp8(
             device = inp._columnwise_data.device
         else:
             raise ValueError("Got MXFP8 input tensor without any data")
-        dtype = torch.bfloat16  # Guess high-precision dtype.
+        dtype = inp._dtype
     else:
         raise ValueError(
             "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorStorage, "
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 09b12afa21..9845a0b4b8 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -522,6 +522,7 @@ def fill_userbuffers_buffer_for_all_gather(
             data=global_tensor_data,
             fp8_scale_inv=local_tensor._scale_inv,
             fp8_dtype=local_tensor._fp8_dtype,
+            fake_dtype=local_tensor._dtype,
             quantizer=quantizer,
         )
         return global_tensor, local_tensor
@@ -596,6 +597,7 @@ def fill_userbuffers_buffer_for_all_gather(
             fp8_dtype=local_tensor._fp8_dtype,
             quantizer=quantizer,
             with_gemm_swizzled_scales=False,
+            fake_dtype=local_tensor._dtype,
         )
         return global_tensor, local_tensor

diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py
index 0a6ad61ff0..1b79256b33 100644
--- a/transformer_engine/pytorch/quantized_tensor.py
+++ b/transformer_engine/pytorch/quantized_tensor.py
@@ -37,6 +37,7 @@ class QuantizedTensorStorage:
     XTensor should only implement the functionality needed to behave like regular
     torch.Tensor (like __torch_dispatch__)."""

+    _dtype: torch.dtype
     _quantizer: Optional[Quantizer]

     def update_usage(
@@ -355,9 +356,12 @@ def __new__(
         shape: Iterable[int],
         dtype: torch.dtype,
         *,
+        fake_dtype: Optional[torch.dtype] = None,
         requires_grad: bool = False,
         device: Optional[torch.device] = None,
     ):
+        if fake_dtype is not None and fake_dtype != dtype:
+            raise ValueError(f"fake_dtype ({fake_dtype}) does not match dtype ({dtype})")
         # We are assuming only contiguous tensors
         stride = _stride_from_shape(shape)
         instance = torch.Tensor._make_wrapper_subclass(
@@ -370,6 +374,7 @@ def __new__(
             requires_grad=requires_grad,
             device=torch.cuda.current_device() if device is None else device,
         )
+        instance._dtype = dtype

         return instance

@@ -403,7 +408,7 @@ def clear(self):
         )

     def __repr__(self, *, tensor_contents=None) -> str:
-        return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})"
+        return f"{self.__class__.__name__}(data={self.dequantize()})"

     def float(self) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
@@ -506,7 +511,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):

         def maybe_unwrap(arg):
             if isinstance(arg, QuantizedTensor):
-                return arg.dequantize(dtype=arg.dtype)
+                return arg.dequantize()
             return arg

         def maybe_update_inplace(arg, new_arg, schema_arg):
diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
index ecafb6ddfc..1bd03e0e6f 100644
--- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
@@ -325,7 +325,7 @@ def __repr__(self, *, tensor_contents=None):
         return (
             f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype},"
             f" is_2D_scaled={self._is_2D_scaled},"
-            f" data={self.dequantize(dtype=self.dtype)})"
+            f" data={self.dequantize()})"
         )

     def quantize_(
diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py
index 3aeace0a77..ce88b8ff2b 100644
--- a/transformer_engine/pytorch/tensor/float8_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_tensor.py
@@ -174,6 +174,7 @@ def create_tensor_from_data(
             data=data,
             fp8_scale_inv=1 / self.scale,
             fp8_dtype=self.dtype,
+            fake_dtype=fake_dtype,
             requires_grad=requires_grad,
             data_transpose=None,
             quantizer=self,
@@ -393,6 +394,7 @@ def create_tensor_from_data(
             data=data,
             fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device),
             fp8_dtype=self.dtype,
+            fake_dtype=fake_dtype,
             requires_grad=requires_grad,
             data_transpose=None,
             quantizer=self,
@@ -480,7 +482,7 @@ def __repr__(self, *, tensor_contents=None):
         return (
             "Float8Tensor("
             f"fp8_dtype={self._fp8_dtype}, "
             f"scale_inv={self._scale_inv.item()}, "
-            f"data={self.dequantize(dtype=self.dtype)}"
+            f"data={self.dequantize()}"
             ")"
         )
diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py
index 8dd2255d89..ca6dca9fc6 100644
--- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py
+++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py
@@ -253,7 +253,7 @@ def __new__(
         )

     def __repr__(self, *, tensor_contents=None):
-        return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize(dtype=self.dtype)})"
+        return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize()})"

     def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """
diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py
index 101cf78a8f..82aa3b4302 100644
--- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py
+++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py
@@ -440,7 +440,7 @@ def __new__(
         return instance

     def __repr__(self, *, tensor_contents=None):
-        return f"NVFP4Tensor, data={self.dequantize(dtype=self.dtype)})"
+        return f"NVFP4Tensor, data={self.dequantize()})"

     def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """
diff --git a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py
index 278d7dc039..418d027371 100644
--- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py
+++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py
@@ -46,12 +46,14 @@ def __new__(
         quantizer: Quantizer,
         is_2D_scaled: bool,
         *args,
+        fake_dtype: Optional[torch.dtype] = None,
         **kwargs,
     ):
         if cls is Float8BlockwiseQTensorStorage:
             instance = object.__new__(cls)
+            instance._dtype = fake_dtype if fake_dtype is not None else torch.float32
         else:
-            instance = super().__new__(cls, *args, **kwargs)
+            instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs)
         instance._rowwise_data = rowwise_data
         instance._columnwise_data = columnwise_data
         instance._quantizer = quantizer.copy() if quantizer is not None else None
@@ -83,6 +85,7 @@ def get_metadata(self) -> Dict[str, Any]:
             "fp8_dtype": self._fp8_dtype,
             "quantizer": self._quantizer,
             "is_2D_scaled": self._is_2D_scaled,
+            "fake_dtype": self._dtype,
         }

     def prepare_for_saving(
@@ -131,7 +134,9 @@ def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.
             permute_dims.append(0)
         return torch.permute(columnwise_dq, tuple(permute_dims)).contiguous()

-    def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+    def _dequantize_vectorwise(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+        if dtype is None:
+            dtype = self._dtype
         block_len = 128
         q_M, q_K = 1, 1

@@ -193,10 +198,12 @@ def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch
             return self._transpose_dq_columnwise_output(result)
         return result

-    def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+    def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """
         Construct plain PyTorch tensor from Float8BlockwiseQTensor
         """
+        if dtype is None:
+            dtype = self._dtype
         block_len = 128
         if not self._is_2D_scaled:
             return self._dequantize_vectorwise(dtype=dtype)
diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py
index adf3ce8aea..bdeb945f0a 100644
--- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py
+++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py
@@ -75,14 +75,16 @@ def __new__(
         data: Optional[torch.Tensor],
         fp8_scale_inv: torch.Tensor,
         fp8_dtype: TE_DType,
+        fake_dtype: Optional[torch.dtype] = None,
         data_transpose: Optional[torch.Tensor] = None,
         quantizer: Optional[Quantizer] = None,
         **kwargs,
     ):
         if cls is Float8TensorStorage:
             instance = object.__new__(cls)
+            instance._dtype = fake_dtype if fake_dtype is not None else torch.float32
         else:
-            instance = super().__new__(cls, *args, **kwargs)
+            instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs)
         instance._data = data
         instance._quantizer = quantizer.copy() if quantizer is not None else None
         instance._fp8_dtype = fp8_dtype
@@ -112,6 +114,7 @@ def get_metadata(self) -> Dict[str, Any]:
             "fp8_dtype": self._fp8_dtype,
             "data_transpose": self._transpose,
             "quantizer": self._quantizer,
+            "fake_dtype": self._dtype,
         }

     def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
@@ -141,8 +144,10 @@ def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = Tr
             return self._transpose
         raise ValueError("No data to get, both rowwise_data and columnwise_data are False")

-    def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+    def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """Dequantize to a higher precision."""
+        if dtype is None:
+            dtype = self._dtype
         return _FromFloat8Func.forward(None, self, dtype)

     def size(self, *args, **kwargs):
@@ -165,6 +170,7 @@ def view(self, shape: torch.Size):
             data=out_data,
             fp8_scale_inv=self._scale_inv,
             fp8_dtype=self._fp8_dtype,
+            fake_dtype=self._dtype,
             data_transpose=out_transpose,
             quantizer=self._quantizer,
         )
diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py
index 1951731c75..0a6a074085 100644
--- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py
+++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py
@@ -84,12 +84,14 @@ def __new__(
         quantizer: Optional[Quantizer],
         with_gemm_swizzled_scales: bool,
         *args,
+        fake_dtype: Optional[torch.dtype] = None,
         **kwargs,
     ):
         if cls is MXFP8TensorStorage:
             instance = object.__new__(cls)
+            instance._dtype = fake_dtype if fake_dtype is not None else torch.float32
         else:
-            instance = super().__new__(cls, *args, **kwargs)
+            instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs)
         instance._rowwise_data = rowwise_data
         instance._columnwise_data = columnwise_data
         instance._rowwise_scale_inv = rowwise_scale_inv
@@ -121,6 +123,7 @@ def get_metadata(self) -> Dict[str, Any]:
             "fp8_dtype": self._fp8_dtype,
             "quantizer": self._quantizer,
             "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales,
+            "fake_dtype": self._dtype,
         }

     def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorStorage]:
@@ -157,8 +160,10 @@ def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = Tr
             return self._columnwise_data
         raise ValueError("No data to get, both rowwise_data and columnwise_data are False")

-    def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+    def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """Dequantize to a higher precision."""
+        if dtype is None:
+            dtype = self._dtype
         return _FromMXFP8Func.forward(None, self, dtype)

     def size(self, *args, **kwargs):
@@ -211,6 +216,7 @@ def view(self, shape: torch.Size):
             fp8_dtype=self._fp8_dtype,
             quantizer=self._quantizer,
             with_gemm_swizzled_scales=self._with_gemm_swizzled_scales,
+            fake_dtype=self._dtype,
         )

     def __repr__(self):
diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
index b064d711ce..9938beac03 100644
--- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
+++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
@@ -106,10 +106,14 @@ def __new__(
         quantizer: Optional[Quantizer],
         with_gemm_swizzled_scales: bool,
         *args,
+        fake_dtype: Optional[torch.dtype] = None,
         **kwargs,
     ):
-
-        instance = super().__new__(cls, *args, **kwargs)
+        if cls is NVFP4TensorStorage:
+            instance = object.__new__(cls)
+            instance._dtype = fake_dtype if fake_dtype is not None else torch.float32
+        else:
+            instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs)

         instance._rowwise_data = rowwise_data
         instance._columnwise_data = columnwise_data
@@ -148,6 +152,7 @@ def get_metadata(self) -> Dict[str, Any]:
             "fp4_dtype": self._fp4_dtype,
             "quantizer": self._quantizer,
             "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales,
+            "fake_dtype": self._dtype,
         }

     def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorStorage]:
@@ -184,8 +189,10 @@ def get_data_tensors(self):
         """Get this Tensor's data."""
         return self._rowwise_data, self._columnwise_data

-    def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+    def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """Dequantize to a higher precision."""
+        if dtype is None:
+            dtype = self._dtype
         return _FromNVFP4Func.forward(None, self, dtype)

     def size(self, dim: Optional[int] = None) -> Union[torch.Size, int]:
@@ -266,6 +273,7 @@ def view(self, shape: torch.Size):
             quantizer=self._quantizer,
             fp4_dtype=self._fp4_dtype,
             with_gemm_swizzled_scales=self._with_gemm_swizzled_scales,
+            fake_dtype=self._dtype,
         )

     def __repr__(self):
diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py
index 47af9fabe1..43eb6f2de2 100644
--- a/transformer_engine/pytorch/utils.py
+++ b/transformer_engine/pytorch/utils.py
@@ -242,6 +242,7 @@ def forward(
                 fp8_dtype=mixed_x_layer._fp8_dtype,
                 data=x.squeeze(split_dim) if squeeze else x,
                 shape=x.squeeze(split_dim).shape if squeeze else x.shape,
+                fake_dtype=mixed_x_layer._dtype,
                 quantizer=mixed_x_layer._quantizer,
            )
            for x in torch.split(
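
For context, a minimal sketch (not part of the diff) of the behavior the `fake_dtype` plumbing enables: the storage classes now remember the nominal high-precision dtype they stand in for, and `dequantize()` falls back to it instead of the old hard-coded `torch.float32` / `torch.bfloat16` guesses. The CUDA device, the zero-filled FP8 payload, and the `tex.DType.kFloat8E4M3` enum below are assumptions of the sketch, not something introduced by this change.

```python
# Illustrative sketch only, not part of the diff. Assumes a CUDA build of
# Transformer Engine; the zero-filled FP8 payload and the E4M3 enum are placeholders.
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import Float8TensorStorage

# Raw FP8 bytes plus a scale-inverse, the same fields the C++ create_tensor path passes in.
data = torch.zeros(128, 128, dtype=torch.uint8, device="cuda")
scale_inv = torch.ones(1, dtype=torch.float32, device="cuda")

storage = Float8TensorStorage(
    data=data,
    fp8_scale_inv=scale_inv,
    fp8_dtype=tex.DType.kFloat8E4M3,
    fake_dtype=torch.bfloat16,  # new keyword introduced by this change
)

print(storage.dequantize().dtype)                     # torch.bfloat16, follows fake_dtype
print(storage.dequantize(dtype=torch.float32).dtype)  # an explicit dtype still wins
print(storage.get_metadata()["fake_dtype"])           # torch.bfloat16
```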