diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 2b4e144f8..ab688871a 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -4029,6 +4029,9 @@ "tag" ] }, + "torch.distributed.ProcessGroup": { + "Matcher": "ChangePrefixMatcher" + }, "torch.distributed.ReduceOp": { "Matcher": "ChangePrefixMatcher" }, @@ -4059,14 +4062,7 @@ } }, "torch.distributed.all_gather_object": { - "Matcher": "AllGatherObjectMatcher", - "paddle_api": "paddle.distributed.all_gather_object", - "args_list": [ - "object_list", - "obj", - "group" - ], - "min_input_args": 2 + "Matcher": "ChangePrefixMatcher" }, "torch.distributed.all_reduce": { "Matcher": "ReverseAsyncOpMatcher", @@ -4180,29 +4176,7 @@ "Matcher": "ChangePrefixMatcher" }, "torch.distributed.init_process_group": { - "Matcher": "GenericMatcher", - "paddle_api": "paddle.distributed.init_parallel_env", - "args_list": [ - "backend", - "init_method", - "timeout", - "world_size", - "rank", - "store", - "group_name", - "pg_options" - ], - "kwargs_change": { - "backend": "", - "init_method": "", - "timeout": "", - "world_size": "", - "rank": "", - "store": "", - "group_name": "", - "pg_options": "" - }, - "min_input_args": 0 + "Matcher": "ChangePrefixMatcher" }, "torch.distributed.irecv": { "Matcher": "GenericMatcher", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index ee78d4282..2ce44c284 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -4717,27 +4717,6 @@ def generate_code(self, kwargs): return code -class AllGatherObjectMatcher(BaseMatcher): - def generate_code(self, kwargs): - if "group" not in kwargs: - kwargs["group"] = None - - API_TEMPLATE = textwrap.dedent( - """ - {}=[] - {}(object_list={}, obj={}, group={}) - """ - ) - return API_TEMPLATE.format( - kwargs["object_list"], - self.get_paddle_api(), - kwargs["object_list"], - kwargs["obj"], - kwargs["group"], - self.kwargs_to_str(kwargs), - ) - - class SetUpMatcher(BaseMatcher): def generate_code(self, kwargs): is_torch_cpp_extension = False diff --git a/paconvert/attribute_mapping.json b/paconvert/attribute_mapping.json index 96d4603b8..9c5e3f5b6 100644 --- a/paconvert/attribute_mapping.json +++ b/paconvert/attribute_mapping.json @@ -127,6 +127,9 @@ "torch.distributed.ReduceOp.SUM": { "Matcher": "ChangePrefixMatcher" }, + "torch.distributed.group.WORLD": { + "Matcher": "ChangePrefixMatcher" + }, "torch.distributions.Distribution.batch_shape": {}, "torch.distributions.Distribution.event_shape": {}, "torch.distributions.Distribution.mean": {}, diff --git a/tests/test_distributed_ProcessGroup.py b/tests/test_distributed_ProcessGroup.py new file mode 100644 index 000000000..1ac799a72 --- /dev/null +++ b/tests/test_distributed_ProcessGroup.py @@ -0,0 +1,59 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.distributed.ProcessGroup") + + +def test_case_1(): + """Use as a type annotation.""" + pytorch_code = textwrap.dedent( + """ + import torch + from torch.distributed import ProcessGroup + def f(pg: ProcessGroup): + return pg + """ + ) + expect = textwrap.dedent( + """ + import paddle + + + def f(pg: paddle.distributed.ProcessGroup): + return pg + """ + ) + obj.run(pytorch_code, expect_paddle_code=expect) + + +def test_case_2(): + """Fully qualified attribute access.""" + pytorch_code = textwrap.dedent( + """ + import torch + cls = torch.distributed.ProcessGroup + """ + ) + expect = textwrap.dedent( + """ + import paddle + + cls = paddle.distributed.ProcessGroup + """ + ) + obj.run(pytorch_code, expect_paddle_code=expect) diff --git a/tests/test_distributed_group_WORLD.py b/tests/test_distributed_group_WORLD.py new file mode 100644 index 000000000..57684cb54 --- /dev/null +++ b/tests/test_distributed_group_WORLD.py @@ -0,0 +1,55 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.distributed.group.WORLD") + + +def test_case_1(): + """Module-style attribute access.""" + pytorch_code = textwrap.dedent( + """ + import torch + g = torch.distributed.group.WORLD + """ + ) + expect = textwrap.dedent( + """ + import paddle + + g = paddle.distributed.group.WORLD + """ + ) + obj.run(pytorch_code, expect_paddle_code=expect) + + +def test_case_2(): + """Aliased import.""" + pytorch_code = textwrap.dedent( + """ + import torch.distributed as dist + g = dist.group.WORLD + """ + ) + expect = textwrap.dedent( + """ + import paddle + + g = paddle.distributed.group.WORLD + """ + ) + obj.run(pytorch_code, expect_paddle_code=expect) diff --git a/tests/test_distributed_init_process_group.py b/tests/test_distributed_init_process_group.py new file mode 100644 index 000000000..558635b1f --- /dev/null +++ b/tests/test_distributed_init_process_group.py @@ -0,0 +1,80 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.distributed.init_process_group") + + +def test_case_1(): + """Bare call — kwargs forwarded as-is via ChangePrefixMatcher.""" + pytorch_code = textwrap.dedent( + """ + import torch + torch.distributed.init_process_group(backend='nccl') + """ + ) + expect = textwrap.dedent( + """ + import paddle + + paddle.distributed.init_process_group(backend="nccl") + """ + ) + obj.run(pytorch_code, expect_paddle_code=expect) + + +def test_case_2(): + """All torch kwargs preserved (paddle accepts and ignores the unused ones).""" + pytorch_code = textwrap.dedent( + """ + import torch + torch.distributed.init_process_group( + backend='nccl', + init_method='tcp://127.0.0.1:23456', + world_size=4, + rank=0, + ) + """ + ) + expect = textwrap.dedent( + """ + import paddle + + paddle.distributed.init_process_group( + backend="nccl", init_method="tcp://127.0.0.1:23456", world_size=4, rank=0 + ) + """ + ) + obj.run(pytorch_code, expect_paddle_code=expect) + + +def test_case_3(): + """Module-style import.""" + pytorch_code = textwrap.dedent( + """ + import torch.distributed as dist + dist.init_process_group(backend='gloo') + """ + ) + expect = textwrap.dedent( + """ + import paddle + + paddle.distributed.init_process_group(backend="gloo") + """ + ) + obj.run(pytorch_code, expect_paddle_code=expect)