From d4428f243f9112db4e15a66ebc179336b7dc116c Mon Sep 17 00:00:00 2001 From: Paarth Gupta Date: Thu, 8 Jan 2026 21:56:14 +0530 Subject: [PATCH 1/6] testing library ref --- fern/docs.yml | 10 +++++++++- fern/fern.config.json | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/fern/docs.yml b/fern/docs.yml index 6104e06..ad84cf9 100644 --- a/fern/docs.yml +++ b/fern/docs.yml @@ -1,7 +1,7 @@ # yaml-language-server: $schema=https://schema.buildwithfern.dev/docs-yml.json instances: - - url: plantstore.docs.buildwithfern.com # update this to {yourorg}.docs.buildwithfern.com + - url: library-docs-local.docs.buildwithfern.com # update this to {yourorg}.docs.buildwithfern.com # custom-domain: plantstore.dev # specify your custom domain when you are ready to go live ai-search: @@ -22,6 +22,9 @@ tabs: API Reference: display-name: API Reference icon: puzzle + Library Reference: + display-name: Library Reference + icon: book navigation: - tab: home @@ -66,6 +69,11 @@ navigation: referenced-packages: - user contents: [] + - tab: Library Reference + layout: + - library-docs: https://github.com/NVIDIA/NeMo-RL + title: NeMo-RL Reference + slug: nemo-rl navbar-links: - type: minimal diff --git a/fern/fern.config.json b/fern/fern.config.json index ba9b584..6b62e1a 100644 --- a/fern/fern.config.json +++ b/fern/fern.config.json @@ -1,4 +1,4 @@ { - "organization": "plantstore", + "organization": "library-docs-local", "version": "3.10.0" } From 26f3ce7b46d7e3e1803542b9aadb05dd99ec050b Mon Sep 17 00:00:00 2001 From: Paarth Gupta Date: Tue, 10 Feb 2026 18:03:04 -0500 Subject: [PATCH 2/6] library docs v2 --- fern/ai_examples_override.yml | 13 + fern/docs.yml | 12 +- fern/fern.config.json | 4 +- fern/static/nemo-rl-docs/_navigation.yml | 1017 +++++++++ fern/static/nemo-rl-docs/nemo-rl/nemo_rl.mdx | 149 ++ .../nemo-rl/nemo_rl/algorithms.mdx | 19 + .../algorithms/advantage_estimator.mdx | 196 ++ .../nemo_rl/algorithms/async_utils.mdx | 572 +++++ .../nemo_rl/algorithms/distillation.mdx | 326 +++ .../nemo-rl/nemo_rl/algorithms/dpo.mdx | 378 ++++ .../nemo-rl/nemo_rl/algorithms/grpo.mdx | 864 ++++++++ .../nemo-rl/nemo_rl/algorithms/interfaces.mdx | 123 ++ .../nemo_rl/algorithms/loss_functions.mdx | 875 ++++++++ .../nemo_rl/algorithms/reward_functions.mdx | 102 + .../nemo-rl/nemo_rl/algorithms/rm.mdx | 320 +++ .../nemo-rl/nemo_rl/algorithms/sft.mdx | 258 +++ .../nemo-rl/nemo_rl/algorithms/utils.mdx | 379 ++++ .../nemo-rl-docs/nemo-rl/nemo_rl/data.mdx | 466 +++++ .../nemo-rl/nemo_rl/data/chat_templates.mdx | 35 + .../nemo-rl/nemo_rl/data/collate_fn.mdx | 166 ++ .../nemo-rl/nemo_rl/data/datasets.mdx | 37 + .../nemo_rl/data/datasets/eval_datasets.mdx | 60 + .../data/datasets/eval_datasets/aime.mdx | 64 + .../data/datasets/eval_datasets/gpqa.mdx | 64 + .../eval_datasets/local_math_dataset.mdx | 65 + .../data/datasets/eval_datasets/math.mdx | 61 + .../data/datasets/eval_datasets/mmlu.mdx | 61 + .../data/datasets/eval_datasets/mmlu_pro.mdx | 60 + .../data/datasets/preference_datasets.mdx | 72 + .../binary_preference_dataset.mdx | 102 + .../preference_datasets/helpsteer3.mdx | 66 + .../preference_dataset.mdx | 77 + .../datasets/preference_datasets/tulu3.mdx | 59 + .../data/datasets/processed_dataset.mdx | 135 ++ .../nemo_rl/data/datasets/raw_dataset.mdx | 94 + .../data/datasets/response_datasets.mdx | 82 + .../datasets/response_datasets/aime24.mdx | 66 + .../data/datasets/response_datasets/clevr.mdx | 97 + .../datasets/response_datasets/dapo_math.mdx | 84 + .../datasets/response_datasets/deepscaler.mdx | 59 + .../datasets/response_datasets/geometry3k.mdx | 76 + .../datasets/response_datasets/helpsteer3.mdx | 66 + .../response_datasets/nemogym_dataset.mdx | 54 + .../response_datasets/oai_format_dataset.mdx | 214 ++ .../data/datasets/response_datasets/oasst.mdx | 127 ++ .../response_datasets/openmathinstruct2.mdx | 84 + .../datasets/response_datasets/refcoco.mdx | 160 ++ .../response_datasets/response_dataset.mdx | 104 + .../data/datasets/response_datasets/squad.mdx | 66 + .../data/datasets/response_datasets/tulu3.mdx | 76 + .../nemo-rl/nemo_rl/data/datasets/utils.mdx | 191 ++ .../nemo-rl/nemo_rl/data/interfaces.mdx | 284 +++ .../nemo_rl/data/llm_message_utils.mdx | 548 +++++ .../nemo-rl/nemo_rl/data/multimodal_utils.mdx | 298 +++ .../nemo-rl/nemo_rl/data/packing.mdx | 30 + .../nemo_rl/data/packing/algorithms.mdx | 791 +++++++ .../nemo-rl/nemo_rl/data/packing/metrics.mdx | 177 ++ .../nemo-rl/nemo_rl/data/processors.mdx | 353 ++++ .../nemo-rl/nemo_rl/data/utils.mdx | 104 + .../nemo-rl/nemo_rl/distributed.mdx | 17 + .../nemo_rl/distributed/batched_data_dict.mdx | 671 ++++++ .../nemo_rl/distributed/collectives.mdx | 108 + .../nemo_rl/distributed/model_utils.mdx | 851 ++++++++ .../nemo_rl/distributed/named_sharding.mdx | 236 +++ .../ray_actor_environment_registry.mdx | 105 + .../distributed/stateless_process_group.mdx | 73 + .../nemo_rl/distributed/virtual_cluster.mdx | 514 +++++ .../distributed/worker_group_utils.mdx | 81 + .../nemo_rl/distributed/worker_groups.mdx | 603 ++++++ .../nemo-rl/nemo_rl/environments.mdx | 19 + .../nemo_rl/environments/code_environment.mdx | 290 +++ .../environments/code_jaccard_environment.mdx | 268 +++ .../environments/dapo_math_verifier.mdx | 316 +++ .../nemo_rl/environments/interfaces.mdx | 151 ++ .../nemo_rl/environments/math_environment.mdx | 356 ++++ .../nemo-rl/nemo_rl/environments/metrics.mdx | 42 + .../nemo-rl/nemo_rl/environments/nemo_gym.mdx | 210 ++ .../environments/reward_model_environment.mdx | 276 +++ .../nemo-rl/nemo_rl/environments/rewards.mdx | 180 ++ .../nemo-rl/nemo_rl/environments/utils.mdx | 152 ++ .../nemo_rl/environments/vlm_environment.mdx | 243 +++ .../nemo-rl-docs/nemo-rl/nemo_rl/evals.mdx | 10 + .../nemo-rl/nemo_rl/evals/answer_parsing.mdx | 86 + .../nemo-rl/nemo_rl/evals/eval.mdx | 399 ++++ .../nemo-rl/nemo_rl/experience.mdx | 9 + .../nemo-rl/nemo_rl/experience/rollouts.mdx | 469 +++++ .../nemo-rl-docs/nemo-rl/nemo_rl/models.mdx | 14 + .../nemo-rl/nemo_rl/models/automodel.mdx | 12 + .../nemo_rl/models/automodel/config.mdx | 125 ++ .../nemo-rl/nemo_rl/models/automodel/data.mdx | 374 ++++ .../nemo_rl/models/automodel/setup.mdx | 229 ++ .../nemo_rl/models/automodel/train.mdx | 841 ++++++++ .../nemo-rl/nemo_rl/models/dtensor.mdx | 9 + .../nemo_rl/models/dtensor/parallelize.mdx | 454 ++++ .../nemo-rl/nemo_rl/models/generation.mdx | 62 + .../nemo_rl/models/generation/interfaces.mdx | 569 +++++ .../nemo_rl/models/generation/sglang.mdx | 33 + .../models/generation/sglang/config.mdx | 299 +++ .../generation/sglang/sglang_copied_utils.mdx | 307 +++ .../generation/sglang/sglang_generation.mdx | 369 ++++ .../generation/sglang/sglang_worker.mdx | 529 +++++ .../models/generation/sglang/utils.mdx | 109 + .../nemo_rl/models/generation/vllm.mdx | 34 + .../nemo_rl/models/generation/vllm/config.mdx | 111 + .../nemo_rl/models/generation/vllm/utils.mdx | 113 + .../models/generation/vllm/vllm_backend.mdx | 236 +++ .../generation/vllm/vllm_generation.mdx | 656 ++++++ .../models/generation/vllm/vllm_worker.mdx | 545 +++++ .../generation/vllm/vllm_worker_async.mdx | 485 +++++ .../nemo-rl/nemo_rl/models/huggingface.mdx | 9 + .../nemo_rl/models/huggingface/common.mdx | 303 +++ .../nemo-rl/nemo_rl/models/megatron.mdx | 13 + .../nemo_rl/models/megatron/common.mdx | 212 ++ .../models/megatron/community_import.mdx | 76 + .../nemo_rl/models/megatron/config.mdx | 146 ++ .../nemo-rl/nemo_rl/models/megatron/data.mdx | 471 +++++ .../nemo-rl/nemo_rl/models/megatron/setup.mdx | 535 +++++ .../nemo-rl/nemo_rl/models/policy.mdx | 948 +++++++++ .../nemo_rl/models/policy/interfaces.mdx | 574 +++++ .../nemo_rl/models/policy/lm_policy.mdx | 609 ++++++ .../nemo-rl/nemo_rl/models/policy/utils.mdx | 624 ++++++ .../nemo-rl/nemo_rl/models/policy/workers.mdx | 13 + .../policy/workers/base_policy_worker.mdx | 309 +++ .../policy/workers/dtensor_policy_worker.mdx | 693 ++++++ .../workers/dtensor_policy_worker_v2.mdx | 714 +++++++ .../policy/workers/megatron_policy_worker.mdx | 682 ++++++ .../nemo_rl/models/policy/workers/patches.mdx | 85 + .../nemo-rl/nemo_rl/package_info.mdx | 235 +++ .../nemo-rl-docs/nemo-rl/nemo_rl/utils.mdx | 22 + .../nemo_rl/utils/automodel_checkpoint.mdx | 436 ++++ .../nemo-rl/nemo_rl/utils/checkpoint.mdx | 411 ++++ .../nemo-rl/nemo_rl/utils/config.mdx | 266 +++ .../nemo-rl/nemo_rl/utils/flops_formulas.mdx | 501 +++++ .../nemo-rl/nemo_rl/utils/flops_tracker.mdx | 215 ++ .../nemo-rl/nemo_rl/utils/logger.mdx | 1856 +++++++++++++++++ .../nemo-rl/nemo_rl/utils/memory_tracker.mdx | 122 ++ .../nemo_rl/utils/native_checkpoint.mdx | 351 ++++ .../nemo-rl/nemo_rl/utils/nsys.mdx | 138 ++ .../nemo-rl/nemo_rl/utils/nvml.mdx | 100 + .../nemo-rl/nemo_rl/utils/packed_tensor.mdx | 140 ++ .../nemo-rl/nemo_rl/utils/prefetch_venvs.mdx | 108 + .../nemo-rl/nemo_rl/utils/timer.mdx | 441 ++++ .../nemo-rl/nemo_rl/utils/venvs.mdx | 177 ++ 143 files changed, 37322 insertions(+), 5 deletions(-) create mode 100644 fern/ai_examples_override.yml create mode 100644 fern/static/nemo-rl-docs/_navigation.yml create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/advantage_estimator.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/async_utils.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/distillation.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/dpo.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/grpo.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/interfaces.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/loss_functions.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/reward_functions.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/rm.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/sft.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/utils.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/chat_templates.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/collate_fn.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/aime.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/math.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/processed_dataset.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/raw_dataset.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/aime24.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/clevr.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oasst.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/squad.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/utils.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/interfaces.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/llm_message_utils.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/multimodal_utils.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/algorithms.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/metrics.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/processors.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/utils.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/batched_data_dict.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/collectives.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/model_utils.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/named_sharding.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/ray_actor_environment_registry.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/stateless_process_group.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/virtual_cluster.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_group_utils.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_groups.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_environment.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_jaccard_environment.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/dapo_math_verifier.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/interfaces.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/math_environment.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/metrics.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/nemo_gym.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/reward_model_environment.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/rewards.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/utils.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/vlm_environment.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals/answer_parsing.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals/eval.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/experience.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/experience/rollouts.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/config.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/data.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/setup.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/train.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor/parallelize.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/interfaces.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/config.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_generation.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_worker.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/utils.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/config.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/utils.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_backend.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_generation.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface/common.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/common.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/community_import.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/config.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/data.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/setup.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/interfaces.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/lm_policy.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/utils.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/patches.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/package_info.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/automodel_checkpoint.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/checkpoint.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/config.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_formulas.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_tracker.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/logger.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/memory_tracker.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/native_checkpoint.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/nsys.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/nvml.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/packed_tensor.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/prefetch_venvs.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/timer.mdx create mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/venvs.mdx diff --git a/fern/ai_examples_override.yml b/fern/ai_examples_override.yml new file mode 100644 index 0000000..bce79c3 --- /dev/null +++ b/fern/ai_examples_override.yml @@ -0,0 +1,13 @@ +paths: + /user/username: + get: + x-fern-examples: + - path-parameters: + username: username + request: + body: {} + response: + body: + id: 42 + username: plantlover99 + email: plantlover99@example.com diff --git a/fern/docs.yml b/fern/docs.yml index ad84cf9..14db967 100644 --- a/fern/docs.yml +++ b/fern/docs.yml @@ -26,6 +26,14 @@ tabs: display-name: Library Reference icon: book +libraries: + nemo-rl: + input: + git: https://github.com/NVIDIA-NeMo/RL + output: + path: ./static/nemo-rl-docs + lang: python + navigation: - tab: home layout: @@ -71,9 +79,7 @@ navigation: contents: [] - tab: Library Reference layout: - - library-docs: https://github.com/NVIDIA/NeMo-RL - title: NeMo-RL Reference - slug: nemo-rl + - library: nemo-rl navbar-links: - type: minimal diff --git a/fern/fern.config.json b/fern/fern.config.json index 6b62e1a..208d5a3 100644 --- a/fern/fern.config.json +++ b/fern/fern.config.json @@ -1,4 +1,4 @@ { - "organization": "library-docs-local", - "version": "3.10.0" + "organization": "fern", + "version": "3.63.0" } diff --git a/fern/static/nemo-rl-docs/_navigation.yml b/fern/static/nemo-rl-docs/_navigation.yml new file mode 100644 index 0000000..2d1a93b --- /dev/null +++ b/fern/static/nemo-rl-docs/_navigation.yml @@ -0,0 +1,1017 @@ +# AUTO-GENERATED by `fern docs md generate` — DO NOT EDIT +- type: section + title: algorithms + slug: nemo-rl/nemo_rl/algorithms + children: + - type: section + title: advantage_estimator + slug: nemo-rl/nemo_rl/algorithms/advantage_estimator + children: + - type: page + title: advantage_estimator + slug: nemo-rl/nemo_rl/algorithms/advantage_estimator + pageId: nemo-rl/nemo_rl/algorithms/advantage_estimator.mdx + - type: section + title: async_utils + slug: nemo-rl/nemo_rl/algorithms/async_utils + children: + - type: page + title: async_utils + slug: nemo-rl/nemo_rl/algorithms/async_utils + pageId: nemo-rl/nemo_rl/algorithms/async_utils.mdx + - type: section + title: distillation + slug: nemo-rl/nemo_rl/algorithms/distillation + children: + - type: page + title: distillation + slug: nemo-rl/nemo_rl/algorithms/distillation + pageId: nemo-rl/nemo_rl/algorithms/distillation.mdx + - type: section + title: dpo + slug: nemo-rl/nemo_rl/algorithms/dpo + children: + - type: page + title: dpo + slug: nemo-rl/nemo_rl/algorithms/dpo + pageId: nemo-rl/nemo_rl/algorithms/dpo.mdx + - type: section + title: grpo + slug: nemo-rl/nemo_rl/algorithms/grpo + children: + - type: page + title: grpo + slug: nemo-rl/nemo_rl/algorithms/grpo + pageId: nemo-rl/nemo_rl/algorithms/grpo.mdx + - type: section + title: interfaces + slug: nemo-rl/nemo_rl/algorithms/interfaces + children: + - type: page + title: interfaces + slug: nemo-rl/nemo_rl/algorithms/interfaces + pageId: nemo-rl/nemo_rl/algorithms/interfaces.mdx + - type: section + title: loss_functions + slug: nemo-rl/nemo_rl/algorithms/loss_functions + children: + - type: page + title: loss_functions + slug: nemo-rl/nemo_rl/algorithms/loss_functions + pageId: nemo-rl/nemo_rl/algorithms/loss_functions.mdx + - type: section + title: reward_functions + slug: nemo-rl/nemo_rl/algorithms/reward_functions + children: + - type: page + title: reward_functions + slug: nemo-rl/nemo_rl/algorithms/reward_functions + pageId: nemo-rl/nemo_rl/algorithms/reward_functions.mdx + - type: section + title: rm + slug: nemo-rl/nemo_rl/algorithms/rm + children: + - type: page + title: rm + slug: nemo-rl/nemo_rl/algorithms/rm + pageId: nemo-rl/nemo_rl/algorithms/rm.mdx + - type: section + title: sft + slug: nemo-rl/nemo_rl/algorithms/sft + children: + - type: page + title: sft + slug: nemo-rl/nemo_rl/algorithms/sft + pageId: nemo-rl/nemo_rl/algorithms/sft.mdx + - type: section + title: utils + slug: nemo-rl/nemo_rl/algorithms/utils + children: + - type: page + title: utils + slug: nemo-rl/nemo_rl/algorithms/utils + pageId: nemo-rl/nemo_rl/algorithms/utils.mdx +- type: section + title: data + slug: nemo-rl/nemo_rl/data + children: + - type: section + title: chat_templates + slug: nemo-rl/nemo_rl/data/chat_templates + children: + - type: page + title: chat_templates + slug: nemo-rl/nemo_rl/data/chat_templates + pageId: nemo-rl/nemo_rl/data/chat_templates.mdx + - type: section + title: collate_fn + slug: nemo-rl/nemo_rl/data/collate_fn + children: + - type: page + title: collate_fn + slug: nemo-rl/nemo_rl/data/collate_fn + pageId: nemo-rl/nemo_rl/data/collate_fn.mdx + - type: section + title: datasets + slug: nemo-rl/nemo_rl/data/datasets + children: + - type: section + title: eval_datasets + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets + children: + - type: section + title: aime + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/aime + children: + - type: page + title: aime + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/aime + pageId: nemo-rl/nemo_rl/data/datasets/eval_datasets/aime.mdx + - type: section + title: gpqa + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa + children: + - type: page + title: gpqa + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa + pageId: nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa.mdx + - type: section + title: local_math_dataset + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset + children: + - type: page + title: local_math_dataset + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset + pageId: nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset.mdx + - type: section + title: math + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/math + children: + - type: page + title: math + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/math + pageId: nemo-rl/nemo_rl/data/datasets/eval_datasets/math.mdx + - type: section + title: mmlu + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu + children: + - type: page + title: mmlu + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu + pageId: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu.mdx + - type: section + title: mmlu_pro + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro + children: + - type: page + title: mmlu_pro + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro + pageId: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro.mdx + - type: section + title: preference_datasets + slug: nemo-rl/nemo_rl/data/datasets/preference_datasets + children: + - type: section + title: binary_preference_dataset + slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset + children: + - type: page + title: binary_preference_dataset + slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset + pageId: nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.mdx + - type: section + title: helpsteer3 + slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3 + children: + - type: page + title: helpsteer3 + slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3 + pageId: nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3.mdx + - type: section + title: preference_dataset + slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset + children: + - type: page + title: preference_dataset + slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset + pageId: nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset.mdx + - type: section + title: tulu3 + slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3 + children: + - type: page + title: tulu3 + slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3 + pageId: nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3.mdx + - type: section + title: processed_dataset + slug: nemo-rl/nemo_rl/data/datasets/processed_dataset + children: + - type: page + title: processed_dataset + slug: nemo-rl/nemo_rl/data/datasets/processed_dataset + pageId: nemo-rl/nemo_rl/data/datasets/processed_dataset.mdx + - type: section + title: raw_dataset + slug: nemo-rl/nemo_rl/data/datasets/raw_dataset + children: + - type: page + title: raw_dataset + slug: nemo-rl/nemo_rl/data/datasets/raw_dataset + pageId: nemo-rl/nemo_rl/data/datasets/raw_dataset.mdx + - type: section + title: response_datasets + slug: nemo-rl/nemo_rl/data/datasets/response_datasets + children: + - type: section + title: aime24 + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/aime24 + children: + - type: page + title: aime24 + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/aime24 + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/aime24.mdx + - type: section + title: clevr + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/clevr + children: + - type: page + title: clevr + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/clevr + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/clevr.mdx + - type: section + title: dapo_math + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math + children: + - type: page + title: dapo_math + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math.mdx + - type: section + title: deepscaler + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler + children: + - type: page + title: deepscaler + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler.mdx + - type: section + title: geometry3k + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k + children: + - type: page + title: geometry3k + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k.mdx + - type: section + title: helpsteer3 + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3 + children: + - type: page + title: helpsteer3 + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3 + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3.mdx + - type: section + title: nemogym_dataset + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset + children: + - type: page + title: nemogym_dataset + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset.mdx + - type: section + title: oai_format_dataset + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset + children: + - type: page + title: oai_format_dataset + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset.mdx + - type: section + title: oasst + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/oasst + children: + - type: page + title: oasst + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/oasst + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/oasst.mdx + - type: section + title: openmathinstruct2 + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2 + children: + - type: page + title: openmathinstruct2 + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2 + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2.mdx + - type: section + title: refcoco + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco + children: + - type: page + title: refcoco + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco.mdx + - type: section + title: response_dataset + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset + children: + - type: page + title: response_dataset + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset.mdx + - type: section + title: squad + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/squad + children: + - type: page + title: squad + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/squad + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/squad.mdx + - type: section + title: tulu3 + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3 + children: + - type: page + title: tulu3 + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3 + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3.mdx + - type: section + title: utils + slug: nemo-rl/nemo_rl/data/datasets/utils + children: + - type: page + title: utils + slug: nemo-rl/nemo_rl/data/datasets/utils + pageId: nemo-rl/nemo_rl/data/datasets/utils.mdx + - type: section + title: interfaces + slug: nemo-rl/nemo_rl/data/interfaces + children: + - type: page + title: interfaces + slug: nemo-rl/nemo_rl/data/interfaces + pageId: nemo-rl/nemo_rl/data/interfaces.mdx + - type: section + title: llm_message_utils + slug: nemo-rl/nemo_rl/data/llm_message_utils + children: + - type: page + title: llm_message_utils + slug: nemo-rl/nemo_rl/data/llm_message_utils + pageId: nemo-rl/nemo_rl/data/llm_message_utils.mdx + - type: section + title: multimodal_utils + slug: nemo-rl/nemo_rl/data/multimodal_utils + children: + - type: page + title: multimodal_utils + slug: nemo-rl/nemo_rl/data/multimodal_utils + pageId: nemo-rl/nemo_rl/data/multimodal_utils.mdx + - type: section + title: packing + slug: nemo-rl/nemo_rl/data/packing + children: + - type: section + title: algorithms + slug: nemo-rl/nemo_rl/data/packing/algorithms + children: + - type: page + title: algorithms + slug: nemo-rl/nemo_rl/data/packing/algorithms + pageId: nemo-rl/nemo_rl/data/packing/algorithms.mdx + - type: section + title: metrics + slug: nemo-rl/nemo_rl/data/packing/metrics + children: + - type: page + title: metrics + slug: nemo-rl/nemo_rl/data/packing/metrics + pageId: nemo-rl/nemo_rl/data/packing/metrics.mdx + - type: section + title: processors + slug: nemo-rl/nemo_rl/data/processors + children: + - type: page + title: processors + slug: nemo-rl/nemo_rl/data/processors + pageId: nemo-rl/nemo_rl/data/processors.mdx + - type: section + title: utils + slug: nemo-rl/nemo_rl/data/utils + children: + - type: page + title: utils + slug: nemo-rl/nemo_rl/data/utils + pageId: nemo-rl/nemo_rl/data/utils.mdx +- type: section + title: distributed + slug: nemo-rl/nemo_rl/distributed + children: + - type: section + title: batched_data_dict + slug: nemo-rl/nemo_rl/distributed/batched_data_dict + children: + - type: page + title: batched_data_dict + slug: nemo-rl/nemo_rl/distributed/batched_data_dict + pageId: nemo-rl/nemo_rl/distributed/batched_data_dict.mdx + - type: section + title: collectives + slug: nemo-rl/nemo_rl/distributed/collectives + children: + - type: page + title: collectives + slug: nemo-rl/nemo_rl/distributed/collectives + pageId: nemo-rl/nemo_rl/distributed/collectives.mdx + - type: section + title: model_utils + slug: nemo-rl/nemo_rl/distributed/model_utils + children: + - type: page + title: model_utils + slug: nemo-rl/nemo_rl/distributed/model_utils + pageId: nemo-rl/nemo_rl/distributed/model_utils.mdx + - type: section + title: named_sharding + slug: nemo-rl/nemo_rl/distributed/named_sharding + children: + - type: page + title: named_sharding + slug: nemo-rl/nemo_rl/distributed/named_sharding + pageId: nemo-rl/nemo_rl/distributed/named_sharding.mdx + - type: section + title: ray_actor_environment_registry + slug: nemo-rl/nemo_rl/distributed/ray_actor_environment_registry + children: + - type: page + title: ray_actor_environment_registry + slug: nemo-rl/nemo_rl/distributed/ray_actor_environment_registry + pageId: nemo-rl/nemo_rl/distributed/ray_actor_environment_registry.mdx + - type: section + title: stateless_process_group + slug: nemo-rl/nemo_rl/distributed/stateless_process_group + children: + - type: page + title: stateless_process_group + slug: nemo-rl/nemo_rl/distributed/stateless_process_group + pageId: nemo-rl/nemo_rl/distributed/stateless_process_group.mdx + - type: section + title: virtual_cluster + slug: nemo-rl/nemo_rl/distributed/virtual_cluster + children: + - type: page + title: virtual_cluster + slug: nemo-rl/nemo_rl/distributed/virtual_cluster + pageId: nemo-rl/nemo_rl/distributed/virtual_cluster.mdx + - type: section + title: worker_group_utils + slug: nemo-rl/nemo_rl/distributed/worker_group_utils + children: + - type: page + title: worker_group_utils + slug: nemo-rl/nemo_rl/distributed/worker_group_utils + pageId: nemo-rl/nemo_rl/distributed/worker_group_utils.mdx + - type: section + title: worker_groups + slug: nemo-rl/nemo_rl/distributed/worker_groups + children: + - type: page + title: worker_groups + slug: nemo-rl/nemo_rl/distributed/worker_groups + pageId: nemo-rl/nemo_rl/distributed/worker_groups.mdx +- type: section + title: environments + slug: nemo-rl/nemo_rl/environments + children: + - type: section + title: code_environment + slug: nemo-rl/nemo_rl/environments/code_environment + children: + - type: page + title: code_environment + slug: nemo-rl/nemo_rl/environments/code_environment + pageId: nemo-rl/nemo_rl/environments/code_environment.mdx + - type: section + title: code_jaccard_environment + slug: nemo-rl/nemo_rl/environments/code_jaccard_environment + children: + - type: page + title: code_jaccard_environment + slug: nemo-rl/nemo_rl/environments/code_jaccard_environment + pageId: nemo-rl/nemo_rl/environments/code_jaccard_environment.mdx + - type: section + title: dapo_math_verifier + slug: nemo-rl/nemo_rl/environments/dapo_math_verifier + children: + - type: page + title: dapo_math_verifier + slug: nemo-rl/nemo_rl/environments/dapo_math_verifier + pageId: nemo-rl/nemo_rl/environments/dapo_math_verifier.mdx + - type: section + title: interfaces + slug: nemo-rl/nemo_rl/environments/interfaces + children: + - type: page + title: interfaces + slug: nemo-rl/nemo_rl/environments/interfaces + pageId: nemo-rl/nemo_rl/environments/interfaces.mdx + - type: section + title: math_environment + slug: nemo-rl/nemo_rl/environments/math_environment + children: + - type: page + title: math_environment + slug: nemo-rl/nemo_rl/environments/math_environment + pageId: nemo-rl/nemo_rl/environments/math_environment.mdx + - type: section + title: metrics + slug: nemo-rl/nemo_rl/environments/metrics + children: + - type: page + title: metrics + slug: nemo-rl/nemo_rl/environments/metrics + pageId: nemo-rl/nemo_rl/environments/metrics.mdx + - type: section + title: nemo_gym + slug: nemo-rl/nemo_rl/environments/nemo_gym + children: + - type: page + title: nemo_gym + slug: nemo-rl/nemo_rl/environments/nemo_gym + pageId: nemo-rl/nemo_rl/environments/nemo_gym.mdx + - type: section + title: reward_model_environment + slug: nemo-rl/nemo_rl/environments/reward_model_environment + children: + - type: page + title: reward_model_environment + slug: nemo-rl/nemo_rl/environments/reward_model_environment + pageId: nemo-rl/nemo_rl/environments/reward_model_environment.mdx + - type: section + title: rewards + slug: nemo-rl/nemo_rl/environments/rewards + children: + - type: page + title: rewards + slug: nemo-rl/nemo_rl/environments/rewards + pageId: nemo-rl/nemo_rl/environments/rewards.mdx + - type: section + title: utils + slug: nemo-rl/nemo_rl/environments/utils + children: + - type: page + title: utils + slug: nemo-rl/nemo_rl/environments/utils + pageId: nemo-rl/nemo_rl/environments/utils.mdx + - type: section + title: vlm_environment + slug: nemo-rl/nemo_rl/environments/vlm_environment + children: + - type: page + title: vlm_environment + slug: nemo-rl/nemo_rl/environments/vlm_environment + pageId: nemo-rl/nemo_rl/environments/vlm_environment.mdx +- type: section + title: evals + slug: nemo-rl/nemo_rl/evals + children: + - type: section + title: answer_parsing + slug: nemo-rl/nemo_rl/evals/answer_parsing + children: + - type: page + title: answer_parsing + slug: nemo-rl/nemo_rl/evals/answer_parsing + pageId: nemo-rl/nemo_rl/evals/answer_parsing.mdx + - type: section + title: eval + slug: nemo-rl/nemo_rl/evals/eval + children: + - type: page + title: eval + slug: nemo-rl/nemo_rl/evals/eval + pageId: nemo-rl/nemo_rl/evals/eval.mdx +- type: section + title: experience + slug: nemo-rl/nemo_rl/experience + children: + - type: section + title: rollouts + slug: nemo-rl/nemo_rl/experience/rollouts + children: + - type: page + title: rollouts + slug: nemo-rl/nemo_rl/experience/rollouts + pageId: nemo-rl/nemo_rl/experience/rollouts.mdx +- type: section + title: models + slug: nemo-rl/nemo_rl/models + children: + - type: section + title: automodel + slug: nemo-rl/nemo_rl/models/automodel + children: + - type: section + title: config + slug: nemo-rl/nemo_rl/models/automodel/config + children: + - type: page + title: config + slug: nemo-rl/nemo_rl/models/automodel/config + pageId: nemo-rl/nemo_rl/models/automodel/config.mdx + - type: section + title: data + slug: nemo-rl/nemo_rl/models/automodel/data + children: + - type: page + title: data + slug: nemo-rl/nemo_rl/models/automodel/data + pageId: nemo-rl/nemo_rl/models/automodel/data.mdx + - type: section + title: setup + slug: nemo-rl/nemo_rl/models/automodel/setup + children: + - type: page + title: setup + slug: nemo-rl/nemo_rl/models/automodel/setup + pageId: nemo-rl/nemo_rl/models/automodel/setup.mdx + - type: section + title: train + slug: nemo-rl/nemo_rl/models/automodel/train + children: + - type: page + title: train + slug: nemo-rl/nemo_rl/models/automodel/train + pageId: nemo-rl/nemo_rl/models/automodel/train.mdx + - type: section + title: dtensor + slug: nemo-rl/nemo_rl/models/dtensor + children: + - type: section + title: parallelize + slug: nemo-rl/nemo_rl/models/dtensor/parallelize + children: + - type: page + title: parallelize + slug: nemo-rl/nemo_rl/models/dtensor/parallelize + pageId: nemo-rl/nemo_rl/models/dtensor/parallelize.mdx + - type: section + title: generation + slug: nemo-rl/nemo_rl/models/generation + children: + - type: section + title: interfaces + slug: nemo-rl/nemo_rl/models/generation/interfaces + children: + - type: page + title: interfaces + slug: nemo-rl/nemo_rl/models/generation/interfaces + pageId: nemo-rl/nemo_rl/models/generation/interfaces.mdx + - type: section + title: sglang + slug: nemo-rl/nemo_rl/models/generation/sglang + children: + - type: section + title: config + slug: nemo-rl/nemo_rl/models/generation/sglang/config + children: + - type: page + title: config + slug: nemo-rl/nemo_rl/models/generation/sglang/config + pageId: nemo-rl/nemo_rl/models/generation/sglang/config.mdx + - type: section + title: sglang_copied_utils + slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils + children: + - type: page + title: sglang_copied_utils + slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils + pageId: nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils.mdx + - type: section + title: sglang_generation + slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_generation + children: + - type: page + title: sglang_generation + slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_generation + pageId: nemo-rl/nemo_rl/models/generation/sglang/sglang_generation.mdx + - type: section + title: sglang_worker + slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_worker + children: + - type: page + title: sglang_worker + slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_worker + pageId: nemo-rl/nemo_rl/models/generation/sglang/sglang_worker.mdx + - type: section + title: utils + slug: nemo-rl/nemo_rl/models/generation/sglang/utils + children: + - type: page + title: utils + slug: nemo-rl/nemo_rl/models/generation/sglang/utils + pageId: nemo-rl/nemo_rl/models/generation/sglang/utils.mdx + - type: section + title: vllm + slug: nemo-rl/nemo_rl/models/generation/vllm + children: + - type: section + title: config + slug: nemo-rl/nemo_rl/models/generation/vllm/config + children: + - type: page + title: config + slug: nemo-rl/nemo_rl/models/generation/vllm/config + pageId: nemo-rl/nemo_rl/models/generation/vllm/config.mdx + - type: section + title: utils + slug: nemo-rl/nemo_rl/models/generation/vllm/utils + children: + - type: page + title: utils + slug: nemo-rl/nemo_rl/models/generation/vllm/utils + pageId: nemo-rl/nemo_rl/models/generation/vllm/utils.mdx + - type: section + title: vllm_backend + slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_backend + children: + - type: page + title: vllm_backend + slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_backend + pageId: nemo-rl/nemo_rl/models/generation/vllm/vllm_backend.mdx + - type: section + title: vllm_generation + slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_generation + children: + - type: page + title: vllm_generation + slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_generation + pageId: nemo-rl/nemo_rl/models/generation/vllm/vllm_generation.mdx + - type: section + title: vllm_worker + slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker + children: + - type: page + title: vllm_worker + slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker + pageId: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker.mdx + - type: section + title: vllm_worker_async + slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async + children: + - type: page + title: vllm_worker_async + slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async + pageId: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async.mdx + - type: section + title: huggingface + slug: nemo-rl/nemo_rl/models/huggingface + children: + - type: section + title: common + slug: nemo-rl/nemo_rl/models/huggingface/common + children: + - type: page + title: common + slug: nemo-rl/nemo_rl/models/huggingface/common + pageId: nemo-rl/nemo_rl/models/huggingface/common.mdx + - type: section + title: megatron + slug: nemo-rl/nemo_rl/models/megatron + children: + - type: section + title: common + slug: nemo-rl/nemo_rl/models/megatron/common + children: + - type: page + title: common + slug: nemo-rl/nemo_rl/models/megatron/common + pageId: nemo-rl/nemo_rl/models/megatron/common.mdx + - type: section + title: community_import + slug: nemo-rl/nemo_rl/models/megatron/community_import + children: + - type: page + title: community_import + slug: nemo-rl/nemo_rl/models/megatron/community_import + pageId: nemo-rl/nemo_rl/models/megatron/community_import.mdx + - type: section + title: config + slug: nemo-rl/nemo_rl/models/megatron/config + children: + - type: page + title: config + slug: nemo-rl/nemo_rl/models/megatron/config + pageId: nemo-rl/nemo_rl/models/megatron/config.mdx + - type: section + title: data + slug: nemo-rl/nemo_rl/models/megatron/data + children: + - type: page + title: data + slug: nemo-rl/nemo_rl/models/megatron/data + pageId: nemo-rl/nemo_rl/models/megatron/data.mdx + - type: section + title: setup + slug: nemo-rl/nemo_rl/models/megatron/setup + children: + - type: page + title: setup + slug: nemo-rl/nemo_rl/models/megatron/setup + pageId: nemo-rl/nemo_rl/models/megatron/setup.mdx + - type: section + title: policy + slug: nemo-rl/nemo_rl/models/policy + children: + - type: section + title: interfaces + slug: nemo-rl/nemo_rl/models/policy/interfaces + children: + - type: page + title: interfaces + slug: nemo-rl/nemo_rl/models/policy/interfaces + pageId: nemo-rl/nemo_rl/models/policy/interfaces.mdx + - type: section + title: lm_policy + slug: nemo-rl/nemo_rl/models/policy/lm_policy + children: + - type: page + title: lm_policy + slug: nemo-rl/nemo_rl/models/policy/lm_policy + pageId: nemo-rl/nemo_rl/models/policy/lm_policy.mdx + - type: section + title: utils + slug: nemo-rl/nemo_rl/models/policy/utils + children: + - type: page + title: utils + slug: nemo-rl/nemo_rl/models/policy/utils + pageId: nemo-rl/nemo_rl/models/policy/utils.mdx + - type: section + title: workers + slug: nemo-rl/nemo_rl/models/policy/workers + children: + - type: section + title: base_policy_worker + slug: nemo-rl/nemo_rl/models/policy/workers/base_policy_worker + children: + - type: page + title: base_policy_worker + slug: nemo-rl/nemo_rl/models/policy/workers/base_policy_worker + pageId: nemo-rl/nemo_rl/models/policy/workers/base_policy_worker.mdx + - type: section + title: dtensor_policy_worker + slug: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker + children: + - type: page + title: dtensor_policy_worker + slug: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker + pageId: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker.mdx + - type: section + title: dtensor_policy_worker_v2 + slug: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2 + children: + - type: page + title: dtensor_policy_worker_v2 + slug: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2 + pageId: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.mdx + - type: section + title: megatron_policy_worker + slug: nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker + children: + - type: page + title: megatron_policy_worker + slug: nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker + pageId: nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker.mdx + - type: section + title: patches + slug: nemo-rl/nemo_rl/models/policy/workers/patches + children: + - type: page + title: patches + slug: nemo-rl/nemo_rl/models/policy/workers/patches + pageId: nemo-rl/nemo_rl/models/policy/workers/patches.mdx +- type: section + title: package_info + slug: nemo-rl/nemo_rl/package_info + children: + - type: page + title: package_info + slug: nemo-rl/nemo_rl/package_info + pageId: nemo-rl/nemo_rl/package_info.mdx +- type: section + title: utils + slug: nemo-rl/nemo_rl/utils + children: + - type: section + title: automodel_checkpoint + slug: nemo-rl/nemo_rl/utils/automodel_checkpoint + children: + - type: page + title: automodel_checkpoint + slug: nemo-rl/nemo_rl/utils/automodel_checkpoint + pageId: nemo-rl/nemo_rl/utils/automodel_checkpoint.mdx + - type: section + title: checkpoint + slug: nemo-rl/nemo_rl/utils/checkpoint + children: + - type: page + title: checkpoint + slug: nemo-rl/nemo_rl/utils/checkpoint + pageId: nemo-rl/nemo_rl/utils/checkpoint.mdx + - type: section + title: config + slug: nemo-rl/nemo_rl/utils/config + children: + - type: page + title: config + slug: nemo-rl/nemo_rl/utils/config + pageId: nemo-rl/nemo_rl/utils/config.mdx + - type: section + title: flops_formulas + slug: nemo-rl/nemo_rl/utils/flops_formulas + children: + - type: page + title: flops_formulas + slug: nemo-rl/nemo_rl/utils/flops_formulas + pageId: nemo-rl/nemo_rl/utils/flops_formulas.mdx + - type: section + title: flops_tracker + slug: nemo-rl/nemo_rl/utils/flops_tracker + children: + - type: page + title: flops_tracker + slug: nemo-rl/nemo_rl/utils/flops_tracker + pageId: nemo-rl/nemo_rl/utils/flops_tracker.mdx + - type: section + title: logger + slug: nemo-rl/nemo_rl/utils/logger + children: + - type: page + title: logger + slug: nemo-rl/nemo_rl/utils/logger + pageId: nemo-rl/nemo_rl/utils/logger.mdx + - type: section + title: memory_tracker + slug: nemo-rl/nemo_rl/utils/memory_tracker + children: + - type: page + title: memory_tracker + slug: nemo-rl/nemo_rl/utils/memory_tracker + pageId: nemo-rl/nemo_rl/utils/memory_tracker.mdx + - type: section + title: native_checkpoint + slug: nemo-rl/nemo_rl/utils/native_checkpoint + children: + - type: page + title: native_checkpoint + slug: nemo-rl/nemo_rl/utils/native_checkpoint + pageId: nemo-rl/nemo_rl/utils/native_checkpoint.mdx + - type: section + title: nsys + slug: nemo-rl/nemo_rl/utils/nsys + children: + - type: page + title: nsys + slug: nemo-rl/nemo_rl/utils/nsys + pageId: nemo-rl/nemo_rl/utils/nsys.mdx + - type: section + title: nvml + slug: nemo-rl/nemo_rl/utils/nvml + children: + - type: page + title: nvml + slug: nemo-rl/nemo_rl/utils/nvml + pageId: nemo-rl/nemo_rl/utils/nvml.mdx + - type: section + title: packed_tensor + slug: nemo-rl/nemo_rl/utils/packed_tensor + children: + - type: page + title: packed_tensor + slug: nemo-rl/nemo_rl/utils/packed_tensor + pageId: nemo-rl/nemo_rl/utils/packed_tensor.mdx + - type: section + title: prefetch_venvs + slug: nemo-rl/nemo_rl/utils/prefetch_venvs + children: + - type: page + title: prefetch_venvs + slug: nemo-rl/nemo_rl/utils/prefetch_venvs + pageId: nemo-rl/nemo_rl/utils/prefetch_venvs.mdx + - type: section + title: timer + slug: nemo-rl/nemo_rl/utils/timer + children: + - type: page + title: timer + slug: nemo-rl/nemo_rl/utils/timer + pageId: nemo-rl/nemo_rl/utils/timer.mdx + - type: section + title: venvs + slug: nemo-rl/nemo_rl/utils/venvs + children: + - type: page + title: venvs + slug: nemo-rl/nemo_rl/utils/venvs + pageId: nemo-rl/nemo_rl/utils/venvs.mdx diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl.mdx new file mode 100644 index 0000000..002c19d --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl.mdx @@ -0,0 +1,149 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl +title: nemo_rl +--- + +## Subpackages + +- **[`nemo_rl.algorithms`](/nemo-rl/nemo_rl/algorithms)** +- **[`nemo_rl.data`](/nemo-rl/nemo_rl/data)** +- **[`nemo_rl.distributed`](/nemo-rl/nemo_rl/distributed)** +- **[`nemo_rl.environments`](/nemo-rl/nemo_rl/environments)** +- **[`nemo_rl.evals`](/nemo-rl/nemo_rl/evals)** +- **[`nemo_rl.experience`](/nemo-rl/nemo_rl/experience)** +- **[`nemo_rl.models`](/nemo-rl/nemo_rl/models)** +- **[`nemo_rl.utils`](/nemo-rl/nemo_rl/utils)** + +## Submodules + +- **[`nemo_rl.package_info`](/nemo-rl/nemo_rl/package_info)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_check_container_fingerprint`](#nemo_rl-_check_container_fingerprint) | Check if container dependencies match the current code (container-only). | +| [`_is_build_isolation`](#nemo_rl-_is_build_isolation) | Detect if we're running in a uv build isolation environment. | +| [`_patch_nsight_file`](#nemo_rl-_patch_nsight_file) | Patch the nsight.py file to fix the context.py_executable assignment. | +| [`patch_transformers_module_dir`](#nemo_rl-patch_transformers_module_dir) | - | + +### Data + +[`megatron_path`](#nemo_rl-megatron_path) + +### API + + + + + +```python +nemo_rl._check_container_fingerprint() +``` + + + + + + +Check if container dependencies match the current code (container-only). + +This check only runs when NRL_CONTAINER=1 is set (inside containers). +It compares the container's fingerprint (computed at build time) with +the current code's fingerprint to detect dependency drift. + +This check is also skipped entirely if NRL_FORCE_REBUILD_VENVS=true is set, +since environment rebuilding will ensure dependencies are consistent regardless +of a mismatch. + +If there's a mismatch, raises RuntimeError unless NRL_IGNORE_VERSION_MISMATCH is set. + + + + + + + + +```python +nemo_rl._is_build_isolation() +``` + + + + + + +Detect if we're running in a uv build isolation environment. + +When running uv lock/sync, uv creates a temporary isolated environment +in ~/.cache/uv/builds-v*/ to build packages and introspect metadata. +We skip the fingerprint check in this context since the user is updating dependencies. + +Returns True if in build isolation, False otherwise. + + + + + + + + +```python +nemo_rl._patch_nsight_file() +``` + + + + + + +Patch the nsight.py file to fix the context.py_executable assignment. + +Until this fix is upstreamed, we will maintain this patch here. This patching +logic is only applied if the user intends to use nsys profiling which they enable with +NRL_NSYS_WORKER_PATTERNS. + +If enabled, will effectively apply the following patch in an idempotent manner: + +https://github.com/ray-project/ray/compare/master...terrykong:ray:tk/nsight-py-exeutable-fix?expand=1 + +This hack works b/c the nsight plugin is not called from the main driver process, so +as soon as nemo_rl is imported, the patch is applied and the source of the nsight.py module +is up to date before the nsight.py is actually needed. + + + + + + + + +```python +nemo_rl.patch_transformers_module_dir( + env_vars: dict[str, str] +) +``` + + + + + + + + + + + + + +```python +nemo_rl.megatron_path = Path(__file__).parent.parent / '3rdparty' / 'Megatron-LM-workspace' / 'Megatron-... +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms.mdx new file mode 100644 index 0000000..7f03746 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms.mdx @@ -0,0 +1,19 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms +title: nemo_rl.algorithms +--- + +## Submodules + +- **[`nemo_rl.algorithms.advantage_estimator`](/nemo-rl/nemo_rl/algorithms/advantage_estimator)** +- **[`nemo_rl.algorithms.async_utils`](/nemo-rl/nemo_rl/algorithms/async_utils)** +- **[`nemo_rl.algorithms.distillation`](/nemo-rl/nemo_rl/algorithms/distillation)** +- **[`nemo_rl.algorithms.dpo`](/nemo-rl/nemo_rl/algorithms/dpo)** +- **[`nemo_rl.algorithms.grpo`](/nemo-rl/nemo_rl/algorithms/grpo)** +- **[`nemo_rl.algorithms.interfaces`](/nemo-rl/nemo_rl/algorithms/interfaces)** +- **[`nemo_rl.algorithms.loss_functions`](/nemo-rl/nemo_rl/algorithms/loss_functions)** +- **[`nemo_rl.algorithms.reward_functions`](/nemo-rl/nemo_rl/algorithms/reward_functions)** +- **[`nemo_rl.algorithms.rm`](/nemo-rl/nemo_rl/algorithms/rm)** +- **[`nemo_rl.algorithms.sft`](/nemo-rl/nemo_rl/algorithms/sft)** +- **[`nemo_rl.algorithms.utils`](/nemo-rl/nemo_rl/algorithms/utils)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/advantage_estimator.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/advantage_estimator.mdx new file mode 100644 index 0000000..0841909 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/advantage_estimator.mdx @@ -0,0 +1,196 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/advantage_estimator +title: nemo_rl.algorithms.advantage_estimator +--- + +Advantage Estimators for RL algorithms. + +This module provides different advantage estimation strategies: +- GRPOAdvantageEstimator: Standard GRPO advantage with leave-one-out baseline +- ReinforcePlusPlusAdvantageEstimator: Reinforce++ with optional baseline subtraction (minus_baseline) and KL penalty in reward +Reference papers: +- ProRLv2: https://developer.nvidia.com/blog/scaling-llm-reinforcement-learning-with-prolonged-training-using-prorl-v2/ +- Reinforce++: https://arxiv.org/abs/2501.03262 + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`GRPOAdvantageEstimator`](#nemo_rl-algorithms-advantage_estimator-GRPOAdvantageEstimator) | GRPO-style advantage estimator with leave-one-out baseline. | +| [`ReinforcePlusPlusAdvantageEstimator`](#nemo_rl-algorithms-advantage_estimator-ReinforcePlusPlusAdvantageEstimator) | Reinforce++ advantage estimator with optional baseline subtraction and KL penalty in reward. | + +### API + + + + + +```python +class nemo_rl.algorithms.advantage_estimator.GRPOAdvantageEstimator( + estimator_config: dict, + loss_config: dict +) +``` + + + + + + +GRPO-style advantage estimator with leave-one-out baseline. + +Note: GRPO computes advantages over all responses for each prompt. + + + + + + + + + + + +```python +nemo_rl.algorithms.advantage_estimator.GRPOAdvantageEstimator.compute_advantage( + prompt_ids, + rewards, + mask, + kwargs = {} +) +``` + + + + + + +Compute GRPO advantages. + +**Parameters:** + + +Tensor of shape [batch_size] identifying which prompt each sample belongs to. + + + +Tensor of shape [batch_size] containing reward for each sample. + + + +Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding. + Used only for expanding advantages to token-level shape. + + + +Additional arguments (unused). + + +**Returns:** + +Advantages tensor of shape [batch_size, seq_len]. + + + + + + + + + +```python +class nemo_rl.algorithms.advantage_estimator.ReinforcePlusPlusAdvantageEstimator( + estimator_config: dict, + loss_config: dict +) +``` + + + + + + +Reinforce++ advantage estimator with optional baseline subtraction and KL penalty in reward. + +**Parameters:** + + +If True, subtract per-prompt mean baseline from rewards. + + + +If True, add KL penalty to reward instead of loss. + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.advantage_estimator.ReinforcePlusPlusAdvantageEstimator.compute_advantage( + prompt_ids, + rewards, + mask, + logprobs_policy = None, + logprobs_reference = None, + kwargs = {} +) +``` + + + + + + +Compute Reinforce++ advantages with optional KL penalty. + +**Parameters:** + + +Tensor of shape [batch_size] identifying which prompt each sample belongs to. + + + +Tensor of shape [batch_size] containing reward for each sample. + + + +Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding. + Used for: (1) expanding advantages to token-level shape, (2) global normalization + that only considers valid tokens. + + + +Policy log probabilities of shape [batch_size, seq_len], required if use_kl_in_reward. + + + +Reference policy log probabilities of shape [batch_size, seq_len], required if use_kl_in_reward. + + + +Additional arguments (unused). + + +**Returns:** + +Advantages tensor of shape [batch_size, seq_len], globally normalized across valid tokens. + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/async_utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/async_utils.mdx new file mode 100644 index 0000000..f9ad506 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/async_utils.mdx @@ -0,0 +1,572 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/async_utils +title: nemo_rl.algorithms.async_utils +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AsyncTrajectoryCollector`](#nemo_rl-algorithms-async_utils-AsyncTrajectoryCollector) | Collects trajectories asynchronously and adds them to replay buffer. | +| [`ReplayBuffer`](#nemo_rl-algorithms-async_utils-ReplayBuffer) | Replay buffer storing per-prompt groups. | + +### Data + +[`TokenizerType`](#nemo_rl-algorithms-async_utils-TokenizerType) + +### API + + + + + +```python +class nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector( + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + tokenizer: nemo_rl.algorithms.async_utils.TokenizerType, + task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], + master_config: nemo_rl.algorithms.grpo.MasterConfig, + replay_buffer: typing.Any, + start_step: int = 0 +) +``` + + + + + + +Collects trajectories asynchronously and adds them to replay buffer. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._calculate_target_weights( + generation_weight_version: int +) -> list[int] +``` + + + + + + +Calculate target weight versions for given generation weight version. + +The list of versions returned enumerate the possible version a generation +server can target. These versions are looped over to see what training +step they can target. If all target versions are exhausted, this generation +server will remain idle until the next weight update. + +Example: +generation_weight_version = 10 +max_trajectory_age_steps = 4 + +**Returns:** `list[int]` + +[11, 12, 13, 14] # Meaning this generation server can create trajectories for training step 11, 12, 13, 14 + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._cleanup_finished_threads() -> None +``` + + + + + + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._collection_loop() +``` + + + + + + +Run the collection loop in background thread. + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._get_next_target_for_generation( + generation_weight_version: int +) -> typing.Optional[int] +``` + + + + + + +Get the next target weight that needs generation (if any). + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._process_batch( + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec] +) -> None +``` + + + + + + +Process a single batch and generate for one target weight. + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._run_prompt_group_worker( + repeated_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], + generation_weight_version: int, + target_weight_version: int, + prompt_idx: int +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._should_pause_for_generation_limits() -> bool +``` + + + + + + +Check if collection should be paused due to generation limits. + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.get_dataloader_state() -> dict +``` + + + + + + +Get the current dataloader state for checkpointing. + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.get_weight_version() -> int +``` + + + + + + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.pause() -> None +``` + + + + + + +Pause trajectory collection. + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.prepare_for_refit() -> None +``` + + + + + + +Pause new generation starts and optionally wait for pending generations. + +For vLLM V1 async engine, leverages in-flight weight updates via collective_rpc, +allowing ongoing generations to continue with their current KV caches while +weights are updated. This significantly improves async performance. + +For non-async engines, waits for all pending generations to complete before refit. + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.resume() -> None +``` + + + + + + +Resume trajectory collection. + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.resume_after_refit() -> None +``` + + + + + + +Resume new generation starts after refit is complete. + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.set_weight_version( + version: int +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.start_collection( + dataloader: torchdata.stateful_dataloader.StatefulDataLoader +) -> None +``` + + + + + + +Start collecting trajectories from dataloader. + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.wait_for_pending_generations() -> None +``` + + + + + + +Wait for all in-flight generation threads to complete. + + + + + + + + + +```python +class nemo_rl.algorithms.async_utils.ReplayBuffer( + max_size: int +) +``` + + + + + + +Replay buffer storing per-prompt groups. + +A single entry corresponds to 1 prompt repeated by +grpo.num_generations_per_prompt (required to compute per-prompt advantages). + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.async_utils.ReplayBuffer.clear() -> None +``` + + + + + + +Clear the buffer. + + + + + + + +```python +nemo_rl.algorithms.async_utils.ReplayBuffer.get_debug_info() -> dict +``` + + + + + + +Get debug information about buffer state. + + + + + + + +```python +nemo_rl.algorithms.async_utils.ReplayBuffer.get_existing_target_weights() -> set[int] +``` + + + + + + +Get set of target weight versions that already have trajectories. + + + + + + + +```python +nemo_rl.algorithms.async_utils.ReplayBuffer.get_last_target_weight_already_generated() -> int +``` + + + + + + + + + + + + +```python +nemo_rl.algorithms.async_utils.ReplayBuffer.push_with_wait_signal( + trajectory: dict[str, typing.Any], + weight_version: int, + target_weight_version: int +) -> str +``` + + + + + + +Add a per-prompt trajectory group with metadata. + +**Parameters:** + + +data dict + + + +version of the model weights used for generation + + + +version of the model weights this trajectory is intended for training + + + + + + + + +```python +nemo_rl.algorithms.async_utils.ReplayBuffer.sample( + num_prompt_groups: int, + current_weight_version: int, + max_age_steps: int +) -> typing.Optional[dict[str, typing.Any]] +``` + + + + + + +Sample per-prompt trajectory groups intended for the current training step. + +Only returns trajectories with target_weight_version == current_weight_version. +If insufficient trajectories are available, returns None to stall training +until the remaining trajectories are generated. This ensures no trajectory +loses its last chance to be used for its intended training step. + +**Returns:** `Optional[dict[str, Any]]` + +Dictionary with 'trajectories' and 'avg_trajectory_age' keys, or None if insufficient data + + + + + + + +```python +nemo_rl.algorithms.async_utils.ReplayBuffer.size() -> int +``` + + + + + + +Return current buffer size. + + + + + + + + + +```python +nemo_rl.algorithms.async_utils.TokenizerType = PreTrainedTokenizerBase +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/distillation.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/distillation.mdx new file mode 100644 index 0000000..2dede47 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/distillation.mdx @@ -0,0 +1,326 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/distillation +title: nemo_rl.algorithms.distillation +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`DistillationConfig`](#nemo_rl-algorithms-distillation-DistillationConfig) | - | +| [`DistillationSaveState`](#nemo_rl-algorithms-distillation-DistillationSaveState) | - | +| [`MasterConfig`](#nemo_rl-algorithms-distillation-MasterConfig) | Main configuration structure. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_default_distillation_save_state`](#nemo_rl-algorithms-distillation-_default_distillation_save_state) | - | +| [`check_vocab_equality`](#nemo_rl-algorithms-distillation-check_vocab_equality) | Check if the vocab of the tokenizer (student) and the teacher tokenizer are equal. | +| [`distillation_train`](#nemo_rl-algorithms-distillation-distillation_train) | Run Distillation training algorithm. | +| [`setup`](#nemo_rl-algorithms-distillation-setup) | Main entry point for distillation algorithm. | +| [`validate`](#nemo_rl-algorithms-distillation-validate) | Run validation on the validation dataset. | + +### Data + +[`TokenizerType`](#nemo_rl-algorithms-distillation-TokenizerType) + +### API + + + + + +```python +class nemo_rl.algorithms.distillation.DistillationConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.distillation.DistillationSaveState +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.distillation.MasterConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Main configuration structure. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.distillation._default_distillation_save_state() -> nemo_rl.algorithms.distillation.DistillationSaveState +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.distillation.check_vocab_equality( + tokenizer: nemo_rl.algorithms.distillation.TokenizerType, + student_model_name: str, + teacher_model_name: str +) -> None +``` + + + + + + +Check if the vocab of the tokenizer (student) and the teacher tokenizer are equal. + + + + + + + + +```python +nemo_rl.algorithms.distillation.distillation_train( + student_policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, + teacher_policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, + student_generation: typing.Optional[nemo_rl.models.generation.interfaces.GenerationInterface], + dataloader: torchdata.stateful_dataloader.StatefulDataLoader, + val_dataloader: typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], + tokenizer: nemo_rl.algorithms.distillation.TokenizerType, + loss_fn: nemo_rl.algorithms.loss_functions.DistillationLossFn, + task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], + val_task_to_env: typing.Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]], + logger: nemo_rl.utils.logger.Logger, + checkpointer: nemo_rl.utils.checkpoint.CheckpointManager, + distillation_save_state: nemo_rl.algorithms.distillation.DistillationSaveState, + master_config: nemo_rl.algorithms.distillation.MasterConfig +) -> None +``` + + + + + + +Run Distillation training algorithm. + + + + + + + + +```python +nemo_rl.algorithms.distillation.setup( + master_config: nemo_rl.algorithms.distillation.MasterConfig, + tokenizer: nemo_rl.algorithms.distillation.TokenizerType, + train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset, + val_dataset: typing.Optional[nemo_rl.data.datasets.AllTaskProcessedDataset] +) -> tuple[nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, typing.Optional[nemo_rl.models.generation.interfaces.GenerationInterface], torchdata.stateful_dataloader.StatefulDataLoader, typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.DistillationLossFn, nemo_rl.utils.logger.Logger, nemo_rl.utils.checkpoint.CheckpointManager, nemo_rl.algorithms.distillation.DistillationSaveState, nemo_rl.algorithms.distillation.MasterConfig] +``` + + + + + + +Main entry point for distillation algorithm. + +**Returns:** `ColocatablePolicyInterface` + +tuple of student_policy, teacher_policy, student_generation, + + + + + + + + +```python +nemo_rl.algorithms.distillation.validate( + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + val_dataloader: typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], + tokenizer, + val_task_to_env: typing.Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]], + step: int, + master_config: nemo_rl.algorithms.distillation.MasterConfig +) -> tuple[dict[str, typing.Any], dict[str, typing.Any]] +``` + + + + + + +Run validation on the validation dataset. + + + + + + + + +```python +nemo_rl.algorithms.distillation.TokenizerType = TypeVar('TokenizerType', bound=PreTrainedTokenizerBase) +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/dpo.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/dpo.mdx new file mode 100644 index 0000000..3d57261 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/dpo.mdx @@ -0,0 +1,378 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/dpo +title: nemo_rl.algorithms.dpo +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`DPOConfig`](#nemo_rl-algorithms-dpo-DPOConfig) | - | +| [`DPOSaveState`](#nemo_rl-algorithms-dpo-DPOSaveState) | - | +| [`DPOValMetrics`](#nemo_rl-algorithms-dpo-DPOValMetrics) | - | +| [`MasterConfig`](#nemo_rl-algorithms-dpo-MasterConfig) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_default_dpo_save_state`](#nemo_rl-algorithms-dpo-_default_dpo_save_state) | - | +| [`add_ref_logprobs_to_data`](#nemo_rl-algorithms-dpo-add_ref_logprobs_to_data) | - | +| [`dpo_train`](#nemo_rl-algorithms-dpo-dpo_train) | - | +| [`setup`](#nemo_rl-algorithms-dpo-setup) | Main entry point for running DPO algorithm. | +| [`validate`](#nemo_rl-algorithms-dpo-validate) | - | +| [`validate_one_dataset`](#nemo_rl-algorithms-dpo-validate_one_dataset) | Run validation on one validation dataset. | + +### API + + + + + +```python +class nemo_rl.algorithms.dpo.DPOConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.dpo.DPOSaveState +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.dpo.DPOValMetrics +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.dpo.MasterConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.dpo._default_dpo_save_state() -> nemo_rl.algorithms.dpo.DPOSaveState +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.dpo.add_ref_logprobs_to_data( + dataloader, + policy, + master_config, + is_val = False +) +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.dpo.dpo_train( + policy, + train_dataloader, + val_dataloader, + tokenizer, + loss_fn, + master_config, + logger, + checkpointer, + dpo_save_state: nemo_rl.algorithms.dpo.DPOSaveState +) -> None +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.dpo.setup( + master_config: nemo_rl.algorithms.dpo.MasterConfig, + tokenizer: transformers.AutoTokenizer, + train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset, + val_dataset: dict[str, nemo_rl.data.datasets.AllTaskProcessedDataset] +) -> tuple[nemo_rl.models.policy.lm_policy.Policy, nemo_rl.distributed.virtual_cluster.RayVirtualCluster, torchdata.stateful_dataloader.StatefulDataLoader, dict[str, torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.DPOLossFn, nemo_rl.utils.logger.Logger, nemo_rl.utils.checkpoint.CheckpointManager, nemo_rl.algorithms.dpo.DPOSaveState, nemo_rl.algorithms.dpo.MasterConfig] +``` + + + + + + +Main entry point for running DPO algorithm. + +**Returns:** `tuple[Policy, RayVirtualCluster, StatefulDataLoader, dict[str, StatefulDataLoader], DPOLossFn, Logger, CheckpointManager, DPOSaveState, MasterConfig]` + +Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, master_config, logger + + + + + + + + +```python +nemo_rl.algorithms.dpo.validate( + policy: nemo_rl.models.policy.interfaces.PolicyInterface, + val_dataloader: dict[str, torchdata.stateful_dataloader.StatefulDataLoader], + tokenizer, + loss_fn, + step: int, + master_config: nemo_rl.algorithms.dpo.MasterConfig, + val_batches: int, + val_batch_size: int, + val_mbs: int, + logger: nemo_rl.utils.logger.Logger +) +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.dpo.validate_one_dataset( + policy: nemo_rl.models.policy.interfaces.PolicyInterface, + val_dataloader: torchdata.stateful_dataloader.StatefulDataLoader, + loss_fn, + step: int, + master_config: nemo_rl.algorithms.dpo.MasterConfig, + val_batches: int, + val_batch_size: int, + val_mbs: int, + dataset_name: str +) +``` + + + + + + +Run validation on one validation dataset. + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/grpo.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/grpo.mdx new file mode 100644 index 0000000..b8db0fe --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/grpo.mdx @@ -0,0 +1,864 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/grpo +title: nemo_rl.algorithms.grpo +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AdvEstimatorConfig`](#nemo_rl-algorithms-grpo-AdvEstimatorConfig) | Configuration for advantage estimator (GRPO or Reinforce++). | +| [`AsyncGRPOConfig`](#nemo_rl-algorithms-grpo-AsyncGRPOConfig) | - | +| [`GRPOConfig`](#nemo_rl-algorithms-grpo-GRPOConfig) | - | +| [`GRPOLoggerConfig`](#nemo_rl-algorithms-grpo-GRPOLoggerConfig) | - | +| [`GRPOSaveState`](#nemo_rl-algorithms-grpo-GRPOSaveState) | - | +| [`MasterConfig`](#nemo_rl-algorithms-grpo-MasterConfig) | - | +| [`RewardScalingConfig`](#nemo_rl-algorithms-grpo-RewardScalingConfig) | Configure linear reward scaling with clamping. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_create_advantage_estimator`](#nemo_rl-algorithms-grpo-_create_advantage_estimator) | Create and return an advantage estimator based on configuration. | +| [`_default_grpo_save_state`](#nemo_rl-algorithms-grpo-_default_grpo_save_state) | - | +| [`_extract_prompt_only_messages`](#nemo_rl-algorithms-grpo-_extract_prompt_only_messages) | Extract only prompt messages (user/system) from message logs. | +| [`_log_mixed_rewards_and_advantages_information`](#nemo_rl-algorithms-grpo-_log_mixed_rewards_and_advantages_information) | - | +| [`_should_log_nemo_gym_responses`](#nemo_rl-algorithms-grpo-_should_log_nemo_gym_responses) | - | +| [`_should_use_async_rollouts`](#nemo_rl-algorithms-grpo-_should_use_async_rollouts) | Determine if async rollouts should be used based on the configuration. | +| [`_should_use_nemo_gym`](#nemo_rl-algorithms-grpo-_should_use_nemo_gym) | Determine if NeMo-Gym should be used for rollouts and validation based on the configuration. | +| [`async_grpo_train`](#nemo_rl-algorithms-grpo-async_grpo_train) | Run asynchronous GRPO training with replay buffer. | +| [`dynamic_sampling`](#nemo_rl-algorithms-grpo-dynamic_sampling) | Implements the dynamic sampling algorithm to select prompts with non-zero standard deviation. | +| [`grpo_train`](#nemo_rl-algorithms-grpo-grpo_train) | Run GRPO training algorithm. | +| [`refit_policy_generation`](#nemo_rl-algorithms-grpo-refit_policy_generation) | Refit the policy generation interface with the latest policy weights. | +| [`scale_rewards`](#nemo_rl-algorithms-grpo-scale_rewards) | Linearly scales rewards from a source range to a target range. | +| [`setup`](#nemo_rl-algorithms-grpo-setup) | Main entry point for running GRPO algorithm. | +| [`validate`](#nemo_rl-algorithms-grpo-validate) | Run validation on the validation dataset. | + +### Data + +[`TokenizerType`](#nemo_rl-algorithms-grpo-TokenizerType) + +### API + + + + + +```python +class nemo_rl.algorithms.grpo.AdvEstimatorConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Configuration for advantage estimator (GRPO or Reinforce++). + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.grpo.AsyncGRPOConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.grpo.GRPOConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.grpo.GRPOLoggerConfig() +``` + + + + + + +**Bases:** [LoggerConfig](/nemo-rl/nemo_rl/utils/logger#nemo_rl-utils-logger-LoggerConfig) + + + + + + + + + +```python +class nemo_rl.algorithms.grpo.GRPOSaveState +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.grpo.MasterConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.grpo.RewardScalingConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Configure linear reward scaling with clamping. + +When `enabled` is True, each reward is clamped to the source interval +[source_min, source_max] and linearly mapped to the target interval +[target_min, target_max]. Refer to the scale_rewards function for the implementation. + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.grpo._create_advantage_estimator( + master_config: nemo_rl.algorithms.grpo.MasterConfig +) +``` + + + + + + +Create and return an advantage estimator based on configuration. + +**Parameters:** + + +The master configuration dictionary. + + +**Returns:** + +An advantage estimator instance (GRPOAdvantageEstimator or ReinforcePlusPlusAdvantageEstimator). + +**Raises:** + +- `ValueError`: If the advantage estimator name is not recognized. + + + + + + + + +```python +nemo_rl.algorithms.grpo._default_grpo_save_state() -> nemo_rl.algorithms.grpo.GRPOSaveState +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.grpo._extract_prompt_only_messages( + message_logs: list +) -> list +``` + + + + + + +Extract only prompt messages (user/system) from message logs. + +This is used to get prompt IDs for advantage estimation, excluding +any assistant responses. + +**Parameters:** + + +List of message logs, where each log is a list of messages. + + +**Returns:** `list` + +List of message logs containing only user and system messages. + + + + + + + + +```python +nemo_rl.algorithms.grpo._log_mixed_rewards_and_advantages_information( + logger: nemo_rl.utils.logger.Logger, + total_steps: int, + metrics: dict[str, typing.Any], + baseline: torch.Tensor, + advantages: torch.Tensor +) -> None +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.grpo._should_log_nemo_gym_responses( + master_config: nemo_rl.algorithms.grpo.MasterConfig +) -> bool +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.grpo._should_use_async_rollouts( + master_config: nemo_rl.algorithms.grpo.MasterConfig +) -> bool +``` + + + + + + +Determine if async rollouts should be used based on the configuration. + +Returns True if vLLM backend is used with async_engine enabled. + + + + + + + + +```python +nemo_rl.algorithms.grpo._should_use_nemo_gym( + master_config: nemo_rl.algorithms.grpo.MasterConfig +) -> bool +``` + + + + + + +Determine if NeMo-Gym should be used for rollouts and validation based on the configuration. + + + + + + + + +```python +nemo_rl.algorithms.grpo.async_grpo_train( + policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, + policy_generation: typing.Optional[nemo_rl.models.generation.interfaces.GenerationInterface], + dataloader: torchdata.stateful_dataloader.StatefulDataLoader, + val_dataloader: typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], + tokenizer: nemo_rl.algorithms.grpo.TokenizerType, + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], + val_task_to_env: typing.Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]], + logger: nemo_rl.utils.logger.Logger, + checkpointer: nemo_rl.utils.checkpoint.CheckpointManager, + grpo_save_state: nemo_rl.algorithms.grpo.GRPOSaveState, + master_config: nemo_rl.algorithms.grpo.MasterConfig, + max_trajectory_age_steps: int = 1 +) -> None +``` + + + + + + +Run asynchronous GRPO training with replay buffer. + +**Parameters:** + + +Training policy + + + +Generation interface + + + +Training data loader + + + +Validation data loader + + + +Tokenizer + + + +Loss function + + + +Training environments + + + +Validation environments + + + +Logger + + + +Checkpoint manager + + + +Training state + + + +Master configuration + + + +Maximum age (in training steps) for trajectories to be used in training + + + + + + + + + +```python +nemo_rl.algorithms.grpo.dynamic_sampling( + repeated_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], + std: torch.Tensor, + baseline: torch.Tensor, + dynamic_sampling_num_gen_batches: int, + master_config: nemo_rl.algorithms.grpo.MasterConfig, + timer: nemo_rl.utils.timer.Timer, + batch_cache: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec] +``` + + + + + + +Implements the dynamic sampling algorithm to select prompts with non-zero standard deviation. + +This function filters the current batch to retain only those prompts that have a non-zero standard deviation. +If the current batch has fewer number of prompts with non-zero standard deviation than the required batch size, defined as num_prompts_per_step * num_generations_per_prompt, +we store it in the batch_cache to be used in later iterations. +If the current batch has more number of prompts with non-zero standard deviation than the required batch size, defined as num_prompts_per_step * num_generations_per_prompt, +the batch is sliced to ensure batch size is num_prompts_per_step * num_generations_per_prompt. +is_batch_complete is set to False to indicate that the current batch is not enough to meet the required batch size. This is used as a signal in the GRPO training loop +to continue sampling or proceed to training. +This approach is based on the dynamic sampling algorithm from the DAPO paper: +https://arxiv.org/pdf/2503.14476. + +**Parameters:** + + +The current batch of data containing prompts, responses, rewards, baselines, and std. + + + +Tensor representing the standard deviation for each prompt group. + + + +Baseline values for each prompt group. + + + +Number of generation batches processed at the current step. + + + +Configuration containing GRPO and policy settings. + + + +Cache storing previously selected prompts with non-zero std. + + +**Returns:** `BatchedDataDict[DatumSpec]` + +A tuple containing: +- repeated_batch (BatchedDataDict[DatumSpec]): Updated batch with selected prompts. +- is_batch_complete (bool): Indicates if the batch has enough samples with non-zero std for training. +- batch_cache (BatchedDataDict[DatumSpec]): Updated cache for future iterations. + + + + + + + + +```python +nemo_rl.algorithms.grpo.grpo_train( + policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, + policy_generation: typing.Optional[nemo_rl.models.generation.interfaces.GenerationInterface], + dataloader: torchdata.stateful_dataloader.StatefulDataLoader, + val_dataloader: typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], + tokenizer: nemo_rl.algorithms.grpo.TokenizerType, + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], + val_task_to_env: typing.Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]], + logger: nemo_rl.utils.logger.Logger, + checkpointer: nemo_rl.utils.checkpoint.CheckpointManager, + grpo_save_state: nemo_rl.algorithms.grpo.GRPOSaveState, + master_config: nemo_rl.algorithms.grpo.MasterConfig +) -> None +``` + + + + + + +Run GRPO training algorithm. + + + + + + + + +```python +nemo_rl.algorithms.grpo.refit_policy_generation( + policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + colocated_inference: bool, + _refit_buffer_size_gb: typing.Optional[int] = None, + timer: typing.Optional[nemo_rl.utils.timer.Timer] = None, + kv_scales: typing.Optional[dict[str, float]] = None +) -> None +``` + + + + + + +Refit the policy generation interface with the latest policy weights. + +**Parameters:** + + +The policy to provide weights to the inference engine. + + + +The inference engine to refit. + + + +The size of the buffer to use for refitting. +If it is None, the buffer size will be computed by the remaining memory. +This parameter is primarily used for testing. + + + +Optional Timer used to time the prepare/transfer/update phase + + + +Optional dictionary of KV cache scales for FP8 quantization. + + + + + + + + + +```python +nemo_rl.algorithms.grpo.scale_rewards( + repeated_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], + reward_scaling_cfg: nemo_rl.algorithms.grpo.RewardScalingConfig +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec] +``` + + + + + + +Linearly scales rewards from a source range to a target range. + +If `reward_scaling.enabled` is True, each reward in `repeated_batch["total_reward"]` +is clamped to the configured source interval [source_min, source_max] and then +rescaled to the target interval [target_min, target_max]. + + + + + + + + +```python +nemo_rl.algorithms.grpo.setup( + master_config: nemo_rl.algorithms.grpo.MasterConfig, + tokenizer: nemo_rl.algorithms.grpo.TokenizerType, + dataset: nemo_rl.data.datasets.AllTaskProcessedDataset, + val_dataset: typing.Optional[nemo_rl.data.datasets.AllTaskProcessedDataset], + processor: typing.Optional[transformers.AutoProcessor] = None +) -> tuple[nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, typing.Optional[nemo_rl.models.generation.interfaces.GenerationInterface], tuple[nemo_rl.distributed.virtual_cluster.RayVirtualCluster, nemo_rl.distributed.virtual_cluster.RayVirtualCluster], torchdata.stateful_dataloader.StatefulDataLoader, typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.ClippedPGLossFn, nemo_rl.utils.logger.Logger, nemo_rl.utils.checkpoint.CheckpointManager, nemo_rl.algorithms.grpo.GRPOSaveState, nemo_rl.algorithms.grpo.MasterConfig] +``` + + + + + + +Main entry point for running GRPO algorithm. + +**Returns:** `tuple[ColocatablePolicyInterface, Optional[GenerationInterface], tuple[RayVirtualCluster, RayVirtualCluster], StatefulDataLoader, Optional[StatefulDataLoader], ClippedPGLossFn, Logger, CheckpointManager, GRPOSaveState, MasterConfig]` + +tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, logger, master_config, val_dataloader + + + + + + + + +```python +nemo_rl.algorithms.grpo.validate( + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + val_dataloader: typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], + tokenizer, + val_task_to_env: typing.Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]], + step: int, + master_config: nemo_rl.algorithms.grpo.MasterConfig, + logger: typing.Optional[nemo_rl.utils.logger.Logger] = None +) -> tuple[dict[str, typing.Any], dict[str, typing.Any]] +``` + + + + + + +Run validation on the validation dataset. + + + + + + + + +```python +nemo_rl.algorithms.grpo.TokenizerType = TypeVar('TokenizerType', bound=PreTrainedTokenizerBase) +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/interfaces.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/interfaces.mdx new file mode 100644 index 0000000..7976052 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/interfaces.mdx @@ -0,0 +1,123 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/interfaces +title: nemo_rl.algorithms.interfaces +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`LossFunction`](#nemo_rl-algorithms-interfaces-LossFunction) | Signature for loss functions used in reinforcement learning algorithms. | +| [`LossType`](#nemo_rl-algorithms-interfaces-LossType) | - | + +### API + + + + + +```python +class nemo_rl.algorithms.interfaces.LossFunction() +``` + + + + + + +Protocol + +Signature for loss functions used in reinforcement learning algorithms. + +Loss functions compute a scalar loss value and associated metrics from +model logprobs and other data contained in a BatchedDataDict. + + + + + + + + +```python +nemo_rl.algorithms.interfaces.LossFunction.__call__( + next_token_logits: torch.Tensor, + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict, + global_valid_seqs: torch.Tensor, + global_valid_toks: torch.Tensor +) -> tuple[torch.Tensor, dict[str, typing.Any]] +``` + + + + + + +Compute loss and metrics from logprobs and other data. + +**Parameters:** + + +Logits from the model, typically with shape [batch_size, seq_len, vocab_size]. + For each position (b, i), contains the logit distribution over the entire vocabulary + for predicting the next token (at position i+1). For example, if processing "The cat sat on", + then next_token_logits[b, 3] would contain the logits for predicting the word + that follows "on". + + + +Dictionary containing all relevant data for loss computation + such as rewards, values, actions, advantages, masks, and other + algorithm-specific information needed for the particular loss calculation. + + + +torch.Tensor +this tensor should contain the number of valid sequences in the microbatch. +It's used for global normalization for losses/metrics that are computed at the sequence level +and needs to be aggregated across all microbatches. + + + +torch.Tensor +This tensor should contain the number of valid tokens in the microbatch. +It's used for global normalization for losses/metrics that are computed at the token level +and needs to be aggregated across all microbatches. + + +**Returns:** `tuple[torch.Tensor, dict[str, Any]]` + +(loss, metrics) +- loss: A scalar tensor representing the loss value to be minimized during training +- metrics: A dictionary of metrics related to the loss computation, which may include + component losses, statistics about gradients/rewards, and other diagnostic information + + + + + + + + + +```python +class nemo_rl.algorithms.interfaces.LossType +``` + + + + + + +**Bases:** `enum.Enum` + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/loss_functions.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/loss_functions.mdx new file mode 100644 index 0000000..f8307d1 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/loss_functions.mdx @@ -0,0 +1,875 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/loss_functions +title: nemo_rl.algorithms.loss_functions +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ClippedPGLossConfig`](#nemo_rl-algorithms-loss_functions-ClippedPGLossConfig) | - | +| [`ClippedPGLossDataDict`](#nemo_rl-algorithms-loss_functions-ClippedPGLossDataDict) | Required keys for the Clipped Policy Gradient loss function. | +| [`ClippedPGLossFn`](#nemo_rl-algorithms-loss_functions-ClippedPGLossFn) | Generalized Clipped Policy Gradient loss function w/ KL regularization. | +| [`DPOLossConfig`](#nemo_rl-algorithms-loss_functions-DPOLossConfig) | - | +| [`DPOLossDataDict`](#nemo_rl-algorithms-loss_functions-DPOLossDataDict) | Required keys for the DPO loss function. | +| [`DPOLossFn`](#nemo_rl-algorithms-loss_functions-DPOLossFn) | Direct Preference Optimization (DPO) loss function. | +| [`DistillationLossConfig`](#nemo_rl-algorithms-loss_functions-DistillationLossConfig) | - | +| [`DistillationLossDataDict`](#nemo_rl-algorithms-loss_functions-DistillationLossDataDict) | - | +| [`DistillationLossFn`](#nemo_rl-algorithms-loss_functions-DistillationLossFn) | Distillation loss function. | +| [`NLLLoss`](#nemo_rl-algorithms-loss_functions-NLLLoss) | Negative Log Likelihood Loss function. | +| [`PreferenceLoss`](#nemo_rl-algorithms-loss_functions-PreferenceLoss) | Preference Loss function. | +| [`PreferenceLossDataDict`](#nemo_rl-algorithms-loss_functions-PreferenceLossDataDict) | Required keys for the preference loss function. | +| [`SequencePackingLossWrapper`](#nemo_rl-algorithms-loss_functions-SequencePackingLossWrapper) | - | + +### Data + +[`Tensor`](#nemo_rl-algorithms-loss_functions-Tensor) + +### API + + + + + +```python +class nemo_rl.algorithms.loss_functions.ClippedPGLossConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.ClippedPGLossDataDict +``` + + + + + + +**Bases:** `typing.TypedDict` + +Required keys for the Clipped Policy Gradient loss function. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.ClippedPGLossFn( + cfg: nemo_rl.algorithms.loss_functions.ClippedPGLossConfig +) +``` + + + + + + +**Bases:** [LossFunction](/nemo-rl/nemo_rl/algorithms/interfaces#nemo_rl-algorithms-interfaces-LossFunction) + +Generalized Clipped Policy Gradient loss function w/ KL regularization. + +This implements: + +- PPO (Clipped) - https://arxiv.org/abs/1707.06347 +- GRPO - https://arxiv.org/abs/2402.03300 +- REINFORCE/RLOO (set disable_ppo_ratio = True and ignores ratio_clip_min/ratio_clip_max) - https://arxiv.org/abs/2402.14740 +- GSPO (set sequence_level_importance_ratios = True and token_level_loss = False) - https://arxiv.org/abs/2507.18071 +- Truly on-policy (set force_on_policy_ratio = True to force ratio = 1.0, requires one update per rollout) + +Formula: +L(θ) = E_t [ min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t) ] - β * KL(π_θ || π_ref) + +where: +- r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t) is the probability ratio +- A_t is the advantage estimate +- ε is the clip parameter (ratio_clip_min/ratio_clip_max) + - As proposed in the DAPO paper (https://arxiv.org/pdf/2503.14476), + we allow setting a distinct minimum and maximum value for the clip parameter (set to the same value for PPO/GRPO/etc.) + - ratio_clip_min: minimum value for the clip parameter + - ratio_clip_max: maximum value for the clip parameter +- β is the KL penalty coefficient (reference_policy_kl_penalty) +- KL(π_θ || π_ref) is the KL divergence between the current policy and reference policy (Schulman Approx.) + +For REINFORCE/RLOO (when disable_ppo_ratio=True), the formula simplifies to: +L(θ) = E_t [ π_θ(a_t|s_t) * A_t ] - β * KL(π_θ || π_ref) + +Also supports "Dual-Clipping" from https://arxiv.org/pdf/1912.09729, which +imposes an additional upper bound on the probability ratio when advantages are negative. +This prevents excessive policy updates. $rA << 0$ -> $cA$(clipped) +The loss function is modified to the following when A_t < 0: +L(θ) = E_t [ max(min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t), c * A_t) ] - β * KL(π_θ || π_ref) + +where: +- c is the dual-clip parameter (ratio_clip_c), which must be greater than 1 and is + usually set as 3 empirically. + +Due to potential numerical instability, we cast the logits to float32 before computing the loss. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.loss_functions.ClippedPGLossFn.__call__( + next_token_logits: nemo_rl.algorithms.loss_functions.Tensor, + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.ClippedPGLossDataDict], + global_valid_seqs: torch.Tensor, + global_valid_toks: torch.Tensor, + vocab_parallel_rank: typing.Optional[int] = None, + vocab_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None +) -> tuple[torch.Tensor, dict] +``` + + + + + + +Clipped Policy Gradient RL loss function. + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.DPOLossConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.DPOLossDataDict +``` + + + + + + +**Bases:** `typing.TypedDict` + +Required keys for the DPO loss function. + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.DPOLossFn( + cfg: nemo_rl.algorithms.loss_functions.DPOLossConfig +) +``` + + + + + + +**Bases:** [PreferenceLoss](#nemo_rl-algorithms-loss_functions-PreferenceLoss) + +Direct Preference Optimization (DPO) loss function. + +This loss function implements the DPO algorithm as described in: +"Direct Preference Optimization: Your Language Model is Secretly a Reward Model" +(https://arxiv.org/abs/2305.18290) + +The loss combines two main components: +1. Preference Loss: Optimizes the model to prefer chosen responses over rejected ones +2. SFT Loss (optional): Auxiliary supervised fine-tuning loss on chosen responses + +The total loss is computed as: +L(θ) = w_p * L_pref(θ) + w_s * L_sft(θ) + +where: +- w_p is the preference_loss_weight +- w_s is the sft_loss_weight +- L_pref(θ) is the preference loss term +- L_sft(θ) is the supervised fine-tuning loss term + +The preference loss term is computed as: +L_pref(θ) = -E[log(σ(β * (r_chosen - r_rejected)))] + +where: +- σ is the sigmoid function +- β is the reference_policy_kl_penalty +- r_chosen and r_rejected are the rewards for chosen and rejected responses +- The rewards are computed as the sum of log probability differences between + the current policy and reference policy + +If preference_average_log_probs is True, the rewards are averaged over tokens: +r = (1/n) * Σ_t (log π_θ(a_t|s_t) - log π_ref(a_t|s_t)) + +Otherwise, the rewards are summed over tokens. + +The SFT loss term is a standard negative log likelihood loss on the chosen responses. +If sft_average_log_probs is True, the loss is averaged over tokens. + +**Parameters:** + + +Configuration dictionary containing: +- reference_policy_kl_penalty (float): Strength of the KL penalty term (β) +- preference_loss_weight (float): Weight for the preference loss term (w_p) +- sft_loss_weight (float): Weight for the SFT loss term (w_s) +- preference_average_log_probs (bool): Whether to average log probs across tokens in preference loss +- sft_average_log_probs (bool): Whether to average log probs across tokens in SFT loss + + +**Returns:** + +tuple[torch.Tensor, dict]: A tuple containing: +- The total loss value +- A dictionary with metrics including: + - loss: Total loss value + - sft_loss: SFT loss component + - preference_loss: Preference loss component + - accuracy: Fraction of examples where chosen response has higher reward + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.loss_functions.DPOLossFn.__call__( + next_token_logits: nemo_rl.algorithms.loss_functions.Tensor, + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.DPOLossDataDict], + global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor, + global_valid_toks: nemo_rl.algorithms.loss_functions.Tensor | None, + vocab_parallel_rank: typing.Optional[int] = None, + vocab_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None +) -> tuple[torch.Tensor, dict[str, typing.Any]] +``` + + + + + + + + + + + + +```python +nemo_rl.algorithms.loss_functions.DPOLossFn._dpo_loss( + next_token_logits: nemo_rl.algorithms.loss_functions.Tensor, + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.DPOLossDataDict], + global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor, + vocab_parallel_rank: typing.Optional[int] = None, + vocab_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None +) -> tuple[nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor] +``` + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.DistillationLossConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.DistillationLossDataDict +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.DistillationLossFn( + cfg: nemo_rl.algorithms.loss_functions.DistillationLossConfig +) +``` + + + + + + +**Bases:** [LossFunction](/nemo-rl/nemo_rl/algorithms/interfaces#nemo_rl-algorithms-interfaces-LossFunction) + +Distillation loss function. + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.loss_functions.DistillationLossFn.__call__( + next_token_logits: torch.Tensor, + data: nemo_rl.algorithms.loss_functions.DistillationLossDataDict, + global_valid_seqs: torch.Tensor, + global_valid_toks: torch.Tensor, + vocab_parallel_rank: typing.Optional[int] = None, + vocab_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None +) -> tuple[torch.Tensor, dict[str, typing.Any]] +``` + + + + + + +Compute distillation loss between teacher and student logits. + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.NLLLoss() +``` + + + + + + +**Bases:** [LossFunction](/nemo-rl/nemo_rl/algorithms/interfaces#nemo_rl-algorithms-interfaces-LossFunction) + +Negative Log Likelihood Loss function. + + + + + + + + +```python +nemo_rl.algorithms.loss_functions.NLLLoss.__call__( + next_token_logits: nemo_rl.algorithms.loss_functions.Tensor, + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor | None, + global_valid_toks: nemo_rl.algorithms.loss_functions.Tensor, + vocab_parallel_rank: typing.Optional[int] = None, + vocab_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, + dpo_loss: bool = False, + dpo_average_log_probs: bool = False +) -> tuple[torch.Tensor, dict[str, typing.Any]] +``` + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.PreferenceLoss() +``` + + + + + + +**Bases:** [LossFunction](/nemo-rl/nemo_rl/algorithms/interfaces#nemo_rl-algorithms-interfaces-LossFunction) + +Preference Loss function. + +Optimizes the model to prefer chosen responses over rejected ones + +The preference loss is computed as: +L_pref(θ) = -E[log(σ(β * (r_chosen - r_rejected)))] + +where: +- σ is the sigmoid function +- β is a scaling factor (ex: `reference_policy_kl_penalty` in DPO) +- r_chosen and r_rejected are the rewards for chosen and rejected responses + +**Returns:** + +tuple[torch.Tensor, dict]: A tuple containing: +- The preference loss value +- A dictionary with metrics including: + - loss: Preference loss + - accuracy: Fraction of examples where chosen response has higher reward + + + + + + + + +```python +nemo_rl.algorithms.loss_functions.PreferenceLoss.__call__( + rewards: nemo_rl.algorithms.loss_functions.Tensor, + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.PreferenceLossDataDict], + global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor, + global_valid_toks: nemo_rl.algorithms.loss_functions.Tensor | None +) -> tuple[torch.Tensor, dict[str, typing.Any]] +``` + + + + + + + + + + + + +```python +nemo_rl.algorithms.loss_functions.PreferenceLoss._preference_loss( + rewards: nemo_rl.algorithms.loss_functions.Tensor, + sample_mask: nemo_rl.algorithms.loss_functions.Tensor, + global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor, + beta: float = 1.0 +) -> tuple[nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor] +``` + + + + + + + + + + + + +```python +nemo_rl.algorithms.loss_functions.PreferenceLoss.split_output_tensor( + tensor: nemo_rl.algorithms.loss_functions.Tensor +) -> tuple[nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor] +``` + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.PreferenceLossDataDict +``` + + + + + + +**Bases:** `typing.TypedDict` + +Required keys for the preference loss function. + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.SequencePackingLossWrapper( + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + cu_seqlens_q: nemo_rl.algorithms.loss_functions.Tensor, + cu_seqlens_q_padded: typing.Optional[nemo_rl.algorithms.loss_functions.Tensor] = None +) +``` + + + + + + + + + + +```python +nemo_rl.algorithms.loss_functions.SequencePackingLossWrapper.__call__( + next_token_logits: nemo_rl.algorithms.loss_functions.Tensor, + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor | None, + global_valid_toks: nemo_rl.algorithms.loss_functions.Tensor | None, + vocab_parallel_rank: typing.Optional[int] = None, + vocab_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None +) -> tuple[nemo_rl.algorithms.loss_functions.Tensor, dict[str, typing.Any]] +``` + + + + + + +Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid excessive padding. + + + + + + + + + +```python +nemo_rl.algorithms.loss_functions.Tensor = TypeVar('Tensor', bound=(torch.Tensor)) +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/reward_functions.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/reward_functions.mdx new file mode 100644 index 0000000..ffcae23 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/reward_functions.mdx @@ -0,0 +1,102 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/reward_functions +title: nemo_rl.algorithms.reward_functions +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`RewardShapingConfig`](#nemo_rl-algorithms-reward_functions-RewardShapingConfig) | Configuration for reward function processing. | + +### Functions + +| Name | Description | +|------|-------------| +| [`apply_reward_shaping`](#nemo_rl-algorithms-reward_functions-apply_reward_shaping) | Process rewards by applying penalties for responses exceeding max_response_length. Currently, this function only supports DAPO reward shaping as illustrated in the DAPO paper : https://arxiv.org/pdf/2503.14476. | + +### Data + +[`Tensor`](#nemo_rl-algorithms-reward_functions-Tensor) + +### API + + + + + +```python +class nemo_rl.algorithms.reward_functions.RewardShapingConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Configuration for reward function processing. + +This configuration enables custom reward shaping, currently supporting DAPO-style +penalties for responses that exceed the maximum response length threshold. + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.reward_functions.apply_reward_shaping( + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict, + cfg: nemo_rl.algorithms.reward_functions.RewardShapingConfig +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict +``` + + + + + + +Process rewards by applying penalties for responses exceeding max_response_length. Currently, this function only supports DAPO reward shaping as illustrated in the DAPO paper : https://arxiv.org/pdf/2503.14476. + +Nonetheless, it can be potentially extended to support any custom reward logic. + + + + + + + + +```python +nemo_rl.algorithms.reward_functions.Tensor = TypeVar('Tensor', bound=(torch.Tensor)) +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/rm.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/rm.mdx new file mode 100644 index 0000000..ed41f3a --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/rm.mdx @@ -0,0 +1,320 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/rm +title: nemo_rl.algorithms.rm +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MasterConfig`](#nemo_rl-algorithms-rm-MasterConfig) | - | +| [`RMConfig`](#nemo_rl-algorithms-rm-RMConfig) | - | +| [`RMSaveState`](#nemo_rl-algorithms-rm-RMSaveState) | - | +| [`RMValMetrics`](#nemo_rl-algorithms-rm-RMValMetrics) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_default_rm_save_state`](#nemo_rl-algorithms-rm-_default_rm_save_state) | - | +| [`rm_train`](#nemo_rl-algorithms-rm-rm_train) | - | +| [`setup`](#nemo_rl-algorithms-rm-setup) | Main entry point for running RM algorithm. | +| [`validate`](#nemo_rl-algorithms-rm-validate) | - | +| [`validate_one_dataset`](#nemo_rl-algorithms-rm-validate_one_dataset) | Run validation on one validation dataset. | + +### API + + + + + +```python +class nemo_rl.algorithms.rm.MasterConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.rm.RMConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.rm.RMSaveState +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.rm.RMValMetrics +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.rm._default_rm_save_state() -> nemo_rl.algorithms.rm.RMSaveState +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.rm.rm_train( + policy, + train_dataloader, + val_dataloader, + tokenizer, + loss_fn, + master_config, + logger, + checkpointer, + rm_save_state +) +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.rm.setup( + master_config: nemo_rl.algorithms.rm.MasterConfig, + tokenizer: transformers.AutoTokenizer, + train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset, + val_dataset: dict[str, nemo_rl.data.datasets.AllTaskProcessedDataset] +) -> tuple[nemo_rl.models.policy.lm_policy.Policy, nemo_rl.distributed.virtual_cluster.RayVirtualCluster, torchdata.stateful_dataloader.StatefulDataLoader, dict[str, torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.PreferenceLoss, nemo_rl.algorithms.rm.MasterConfig, nemo_rl.utils.logger.Logger, nemo_rl.data.interfaces.TaskDataSpec, nemo_rl.algorithms.rm.RMSaveState] +``` + + + + + + +Main entry point for running RM algorithm. + +**Returns:** `tuple[Policy, RayVirtualCluster, StatefulDataLoader, dict[str, StatefulDataLoader], PreferenceLoss, MasterConfig, Logger, TaskDataSpec, RMSaveState]` + +Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, master_config, logger + + + + + + + + +```python +nemo_rl.algorithms.rm.validate( + policy: nemo_rl.models.policy.interfaces.PolicyInterface, + val_dataloader: dict[str, torchdata.stateful_dataloader.StatefulDataLoader], + tokenizer, + loss_fn, + step: int, + master_config: nemo_rl.algorithms.rm.MasterConfig, + val_batches: int, + val_batch_size: int, + val_mbs: int, + logger: nemo_rl.utils.logger.Logger +) +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.rm.validate_one_dataset( + policy: nemo_rl.models.policy.interfaces.PolicyInterface, + val_dataloader: torchdata.stateful_dataloader.StatefulDataLoader, + loss_fn, + step: int, + master_config: nemo_rl.algorithms.rm.MasterConfig, + val_batches: int, + val_batch_size: int, + val_mbs: int, + dataset_name: str +) +``` + + + + + + +Run validation on one validation dataset. + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/sft.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/sft.mdx new file mode 100644 index 0000000..d9a3bd6 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/sft.mdx @@ -0,0 +1,258 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/sft +title: nemo_rl.algorithms.sft +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MasterConfig`](#nemo_rl-algorithms-sft-MasterConfig) | - | +| [`SFTConfig`](#nemo_rl-algorithms-sft-SFTConfig) | - | +| [`SFTSaveState`](#nemo_rl-algorithms-sft-SFTSaveState) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_default_sft_save_state`](#nemo_rl-algorithms-sft-_default_sft_save_state) | - | +| [`setup`](#nemo_rl-algorithms-sft-setup) | Main entry point for running SFT algorithm. | +| [`sft_train`](#nemo_rl-algorithms-sft-sft_train) | - | +| [`validate`](#nemo_rl-algorithms-sft-validate) | Run validation on the validation dataset. | + +### API + + + + + +```python +class nemo_rl.algorithms.sft.MasterConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.sft.SFTConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.sft.SFTSaveState +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.sft._default_sft_save_state() -> nemo_rl.algorithms.sft.SFTSaveState +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.sft.setup( + master_config: nemo_rl.algorithms.sft.MasterConfig, + tokenizer: transformers.AutoTokenizer, + train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset, + val_dataset: typing.Optional[nemo_rl.data.datasets.AllTaskProcessedDataset] +) -> tuple[nemo_rl.models.policy.lm_policy.Policy, nemo_rl.distributed.virtual_cluster.RayVirtualCluster, torchdata.stateful_dataloader.StatefulDataLoader, typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.NLLLoss, nemo_rl.utils.logger.Logger, nemo_rl.utils.checkpoint.CheckpointManager, nemo_rl.algorithms.sft.SFTSaveState, nemo_rl.algorithms.sft.MasterConfig] +``` + + + + + + +Main entry point for running SFT algorithm. + +**Returns:** `tuple[Policy, RayVirtualCluster, StatefulDataLoader, Optional[StatefulDataLoader], NLLLoss, Logger, CheckpointManager, SFTSaveState, MasterConfig]` + +Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, master_config, logger + + + + + + + + +```python +nemo_rl.algorithms.sft.sft_train( + policy, + train_dataloader, + val_dataloader, + tokenizer, + loss_fn, + master_config, + logger, + checkpointer, + sft_save_state: nemo_rl.algorithms.sft.SFTSaveState +) -> None +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.sft.validate( + policy: nemo_rl.models.policy.interfaces.PolicyInterface, + val_dataloader: typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], + tokenizer, + loss_fn, + step: int, + master_config: nemo_rl.algorithms.sft.MasterConfig, + val_batches: int, + val_batch_size: int, + val_mbs: int +) +``` + + + + + + +Run validation on the validation dataset. + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/utils.mdx new file mode 100644 index 0000000..200ecb0 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/utils.mdx @@ -0,0 +1,379 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/utils +title: nemo_rl.algorithms.utils +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`calculate_baseline_and_std_per_prompt`](#nemo_rl-algorithms-utils-calculate_baseline_and_std_per_prompt) | Function to compute a baseline for each (prompt, response) pair in the batch. | +| [`calculate_kl`](#nemo_rl-algorithms-utils-calculate_kl) | Calculates a per-token estimate of the KL Divergence between two logprobs. | +| [`get_tokenizer`](#nemo_rl-algorithms-utils-get_tokenizer) | Get the tokenizer and set pad token to eos token if it is not already set. | +| [`log_generation_metrics_to_wandb`](#nemo_rl-algorithms-utils-log_generation_metrics_to_wandb) | Log generation metrics to wandb. | +| [`masked_mean`](#nemo_rl-algorithms-utils-masked_mean) | Computes the mean of a microbatch, using a global statistic as the normalization factor. | +| [`maybe_pad_last_batch`](#nemo_rl-algorithms-utils-maybe_pad_last_batch) | Pads the given batch so that its size is divisible by (mbs * dp_size). | +| [`print_performance_metrics`](#nemo_rl-algorithms-utils-print_performance_metrics) | Print performance metrics for GRPO. | +| [`set_seed`](#nemo_rl-algorithms-utils-set_seed) | Sets the seed for python, numpy, and pytorch. | +| [`surpress_user_warnings`](#nemo_rl-algorithms-utils-surpress_user_warnings) | - | + +### API + + + + + +```python +nemo_rl.algorithms.utils.calculate_baseline_and_std_per_prompt( + prompts: torch.Tensor, + rewards: torch.Tensor, + valid_mask: torch.Tensor, + leave_one_out_baseline: bool = True +) -> tuple[torch.Tensor, torch.Tensor] +``` + + + + + + +Function to compute a baseline for each (prompt, response) pair in the batch. + +The same baseline is calculated for each prompt. Samples set to 0 in 'valid_mask' +are not included in the baseline calculation. + +prompts: tensor (b, s) Tensor of prompts the model used. May be on any device +rewards: tensor (b,) Float-valued rewards. May be on any device +valid_mask: tensor (b,) Vector of 0/1, where 0 is to ignore and 1 is to keep +leave_one_out_baseline: bool Compute an unbiased baseline by leaving out the sample that + the baseline is for (from RLOO https://arxiv.org/abs/2402.14740) + +Returns: +tensor (b,), tensor (b,) of baselines and std on the same device as 'rewards' + + + + + + + + +```python +nemo_rl.algorithms.utils.calculate_kl( + logprobs: torch.Tensor, + logprobs_reference: torch.Tensor, + kl_type: str = 'k3', + input_clamp_value: float | None = 20.0, + output_clamp_value: float | None = 10.0 +) -> torch.Tensor +``` + + + + + + +Calculates a per-token estimate of the KL Divergence between two logprobs. + +From Schulman 2020, http://joschu.net/blog/kl-approx.html. + +**Parameters:** + + +torch.Tensor (b, s) + + + +torch.Tensor (b, s) + + + +Type of KL approximation to use. Valid values: "k1", "k2", "k3". + + + +Optional clamping value for logr to prevent numerical instability. + If None, no clamping is applied. + + + +Optional clamping value for kl to prevent numerical instability. + If None, no clamping is applied. + + +**Returns:** `torch.Tensor` + +torch.Tensor: Per-token KL penalty values (b, s) + + + + + + + + +```python +nemo_rl.algorithms.utils.get_tokenizer( + tokenizer_config: nemo_rl.models.policy.TokenizerConfig, + get_processor: bool = False +) -> transformers.PreTrainedTokenizerBase +``` + + + + + + +Get the tokenizer and set pad token to eos token if it is not already set. + +This function initializes a tokenizer from the Hugging Face transformers library +and configures it with appropriate chat templates and padding tokens. + +**Parameters:** + + +A dictionary containing tokenizer configuration. +Required keys: + - name: The name or path of the pretrained tokenizer +Optional keys: + - chat_template: The chat template to use. Can be: + - None: Uses a passthrough template that just returns message content + - "default": Uses the tokenizer's default template + - A custom jinja2 template string + If not specified, the tokenizer's default template will be used. + + + +Whether to return a processor (via AutoProcessor) instead of a tokenizer. + + +**Returns:** `PreTrainedTokenizerBase` + +The configured tokenizer instance + +**Examples:** + + + +```python +>>> from transformers import AutoTokenizer +>>> from nemo_rl.algorithms.utils import get_tokenizer +>>> # not specifying a chat template uses the tokenizer's default +>>> config = {"name": "meta-llama/Llama-3.2-1B-Instruct"} +>>> tokenizer = get_tokenizer(config) +No chat template provided, using tokenizer's default +>>> messages = [ +... {"role": "system", "content": "You are a helpful AI assistant."}, +... {"role": "user", "content": "Hello!"} +... ] +>>> formatted = tokenizer.apply_chat_template(messages, tokenize=False) +>>> assert formatted == AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct").apply_chat_template(messages, tokenize=False) + +>>> # Using a passthrough template +>>> config = { +... "name": "meta-llama/Llama-3.2-1B-Instruct", +... "chat_template": None +... } +>>> tokenizer = get_tokenizer(config) +Using passthrough chat template +>>> formatted = tokenizer.apply_chat_template(messages, tokenize=False) +>>> assert formatted == "".join(msg["content"] for msg in messages) + +>>> # Using a custom template +>>> config = { +... "name": "meta-llama/Llama-3.2-1B-Instruct", +... "chat_template": "{% for message in messages %}{{ ' START: ' + message['content'] + ' END.' }}{% endfor %}" +... } +>>> tokenizer = get_tokenizer(config) +Using custom chat template +>>> formatted = tokenizer.apply_chat_template(messages, tokenize=False) +>>> assert formatted == " START: You are a helpful AI assistant. END. START: Hello! END." + +>>> # Requesting a processor (for multimodal models like Qwen-VL) +>>> config = {"name": "Qwen/Qwen2.5-VL-3B-Instruct"} +>>> processor = get_tokenizer(config, get_processor=True) +No chat template provided, using tokenizer's default +>>> messages = [ +... {"role": "system", "content": "You are a helpful AI assistant."}, +... {"role": "user", "content": "Hello!"} +... ] +>>> formatted = processor.tokenizer.apply_chat_template(messages, tokenize=False) +>>> assert formatted == AutoTokenizer.from_pretrained( +... "Qwen/Qwen2.5-VL-3B-Instruct", trust_remote_code=True +... ).apply_chat_template(messages, tokenize=False) +>>> assert processor.pad_token_id == processor.tokenizer.pad_token_id +>>> +``` + + + + + + + + + + +```python +nemo_rl.algorithms.utils.log_generation_metrics_to_wandb( + generation_logger_metrics: dict[str, dict[int, list[typing.Any]]], + step: int, + timeline_interval: float, + logger: nemo_rl.utils.logger.Logger +) -> None +``` + + + + + + +Log generation metrics to wandb. + +**Parameters:** + + +Dictionary of generation logger metrics + + + +Global step value + + + +Interval between timeline points (in seconds) + + + +Logger instance + + + + + + + + + +```python +nemo_rl.algorithms.utils.masked_mean( + values: torch.Tensor, + mask: torch.Tensor, + dim: typing.Optional[int] = None, + global_normalization_factor: typing.Optional[torch.Tensor | float] = None +) +``` + + + + + + +Computes the mean of a microbatch, using a global statistic as the normalization factor. + + + + + + + + +```python +nemo_rl.algorithms.utils.maybe_pad_last_batch( + batch: dict, + dp_size: int, + mbs: int +) -> dict +``` + + + + + + +Pads the given batch so that its size is divisible by (mbs * dp_size). + +**Parameters:** + + +The batch to pad. + + + +Data parallel size. + + + +Micro batch size. + + +**Returns:** `dict` + +The padded batch. + + + + + + + + +```python +nemo_rl.algorithms.utils.print_performance_metrics( + train_results: dict[str, float], + metrics: dict[str, typing.Any], + timing_metrics: dict[str, float], + master_config: dict +) -> dict[str, float] +``` + + + + + + +Print performance metrics for GRPO. + + + + + + + + +```python +nemo_rl.algorithms.utils.set_seed( + seed: int +) -> None +``` + + + + + + +Sets the seed for python, numpy, and pytorch. + + + + + + + + +```python +nemo_rl.algorithms.utils.surpress_user_warnings( + f +) +``` + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data.mdx new file mode 100644 index 0000000..3cafa95 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data.mdx @@ -0,0 +1,466 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data +title: nemo_rl.data +--- + +## Subpackages + +- **[`nemo_rl.data.datasets`](/nemo-rl/nemo_rl/data/datasets)** +- **[`nemo_rl.data.packing`](/nemo-rl/nemo_rl/data/packing)** + +## Submodules + +- **[`nemo_rl.data.chat_templates`](/nemo-rl/nemo_rl/data/chat_templates)** +- **[`nemo_rl.data.collate_fn`](/nemo-rl/nemo_rl/data/collate_fn)** +- **[`nemo_rl.data.interfaces`](/nemo-rl/nemo_rl/data/interfaces)** +- **[`nemo_rl.data.llm_message_utils`](/nemo-rl/nemo_rl/data/llm_message_utils)** +- **[`nemo_rl.data.multimodal_utils`](/nemo-rl/nemo_rl/data/multimodal_utils)** +- **[`nemo_rl.data.processors`](/nemo-rl/nemo_rl/data/processors)** +- **[`nemo_rl.data.utils`](/nemo-rl/nemo_rl/data/utils)** + +## Package Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AIMEEvalDataConfig`](#nemo_rl-data-AIMEEvalDataConfig) | Config for AIME datasets. | +| [`DataConfig`](#nemo_rl-data-DataConfig) | - | +| [`GPQAEvalDataConfig`](#nemo_rl-data-GPQAEvalDataConfig) | Config for GPQA datasets. | +| [`LocalMathEvalDataConfig`](#nemo_rl-data-LocalMathEvalDataConfig) | Config for local math datasets loaded from files. | +| [`MMLUEvalDataConfig`](#nemo_rl-data-MMLUEvalDataConfig) | Config for MMLU and multilingual MMLU datasets. | +| [`MMLUProEvalDataConfig`](#nemo_rl-data-MMLUProEvalDataConfig) | Config for MMLU Pro dataset. | +| [`MathEvalDataConfig`](#nemo_rl-data-MathEvalDataConfig) | Config for Math datasets. | +| [`PreferenceDatasetConfig`](#nemo_rl-data-PreferenceDatasetConfig) | - | +| [`ResponseDatasetConfig`](#nemo_rl-data-ResponseDatasetConfig) | - | + +### Data + +[`EvalDataConfigType`](#nemo_rl-data-EvalDataConfigType) + +### API + + + + + +```python +class nemo_rl.data.AIMEEvalDataConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Config for AIME datasets. + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.DataConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.GPQAEvalDataConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Config for GPQA datasets. + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.LocalMathEvalDataConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Config for local math datasets loaded from files. + +dataset_name can be a URL or local file path. +Requires additional fields: problem_key, solution_key, file_format, split. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.MMLUEvalDataConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Config for MMLU and multilingual MMLU datasets. + +Supports dataset_name: "mmlu" or "mmlu_{language}" where language is one of: +AR-XY, BN-BD, DE-DE, EN-US, ES-LA, FR-FR, HI-IN, ID-ID, IT-IT, JA-JP, +KO-KR, PT-BR, ZH-CN, SW-KE, YO-NG + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.MMLUProEvalDataConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Config for MMLU Pro dataset. + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.MathEvalDataConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Config for Math datasets. + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.PreferenceDatasetConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.ResponseDatasetConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.data.EvalDataConfigType = MMLUEvalDataConfig | MMLUProEvalDataConfig | AIMEEvalDataConfig | GPQAEvalDataCo... +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/chat_templates.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/chat_templates.mdx new file mode 100644 index 0000000..11e5f15 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/chat_templates.mdx @@ -0,0 +1,35 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/chat_templates +title: nemo_rl.data.chat_templates +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`COMMON_CHAT_TEMPLATES`](#nemo_rl-data-chat_templates-COMMON_CHAT_TEMPLATES) | - | + +### API + + + + + +```python +class nemo_rl.data.chat_templates.COMMON_CHAT_TEMPLATES() +``` + + + + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/collate_fn.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/collate_fn.mdx new file mode 100644 index 0000000..56b6bb7 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/collate_fn.mdx @@ -0,0 +1,166 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/collate_fn +title: nemo_rl.data.collate_fn +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`eval_collate_fn`](#nemo_rl-data-collate_fn-eval_collate_fn) | Collate function for evaluation. | +| [`preference_collate_fn`](#nemo_rl-data-collate_fn-preference_collate_fn) | Collate function for preference data training. | +| [`rl_collate_fn`](#nemo_rl-data-collate_fn-rl_collate_fn) | Collate function for RL training. | + +### Data + +[`TokenizerType`](#nemo_rl-data-collate_fn-TokenizerType) + +### API + + + + + +```python +nemo_rl.data.collate_fn.eval_collate_fn( + data_batch: list[nemo_rl.data.interfaces.DatumSpec] +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] +``` + + + + + + +Collate function for evaluation. + +Takes a list of data samples and combines them into a single batched dictionary +for model evaluation. + +Examples: + + +```python +>>> import torch +>>> from nemo_rl.data.collate_fn import eval_collate_fn +>>> from nemo_rl.data.interfaces import DatumSpec +>>> data_batch = [ +... DatumSpec( +... message_log=[{"role": "user", "content": "Hello", "token_ids": torch.tensor([1, 2, 3])}], +... extra_env_info={'ground_truth': '1'}, +... idx=0, +... ), +... DatumSpec( +... message_log=[{"role": "assistant", "content": "Hi there", "token_ids": torch.tensor([4, 5, 6, 7])}], +... extra_env_info={'ground_truth': '2'}, +... idx=1, +... ), +... ] +>>> output = eval_collate_fn(data_batch) +>>> output['message_log'][0] +[{'role': 'user', 'content': 'Hello', 'token_ids': tensor([1, 2, 3])}] +>>> output['message_log'][1] +[{'role': 'assistant', 'content': 'Hi there', 'token_ids': tensor([4, 5, 6, 7])}] +>>> output['extra_env_info'] +[{'ground_truth': '1'}, {'ground_truth': '2'}] +>>> output['idx'] +[0, 1] +``` + + + +**Parameters:** + + +List of data samples with message_log, extra_env_info, and idx fields. + + +**Returns:** `BatchedDataDict[Any]` + +BatchedDataDict with message_log, extra_env_info, and idx fields. + + + + + + + + +```python +nemo_rl.data.collate_fn.preference_collate_fn( + data_batch: list[nemo_rl.data.interfaces.PreferenceDatumSpec], + tokenizer: nemo_rl.data.collate_fn.TokenizerType, + make_sequence_length_divisible_by: int, + add_loss_mask: bool +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] +``` + + + + + + +Collate function for preference data training. + +This function separates the chosen and rejected responses to create +two examples per prompt. The chosen and rejected examples are interleaved +along the batch dimension, resulting in a batch size of 2 * len(data_batch). + +Returns: + BatchedDataDict with input_ids, input_lengths, token_mask (optional), and sample_mask fields. + +**Parameters:** + + +List of data samples with message_log_chosen, message_log_rejected, length_chosen, length_rejected, loss_multiplier, idx, and task_name fields. + + + +Tokenizer for text processing + + + +Make the sequence length divisible by this value + + + +Whether to add a token_mask to the returned data + + + + + + + + + +```python +nemo_rl.data.collate_fn.rl_collate_fn( + data_batch: list[nemo_rl.data.interfaces.DatumSpec] +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] +``` + + + + + + +Collate function for RL training. + + + + + + + + +```python +nemo_rl.data.collate_fn.TokenizerType = Union[PreTrainedTokenizerBase, AutoProcessor] +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets.mdx new file mode 100644 index 0000000..88450e5 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets.mdx @@ -0,0 +1,37 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets +title: nemo_rl.data.datasets +--- + +## Subpackages + +- **[`nemo_rl.data.datasets.eval_datasets`](/nemo-rl/nemo_rl/data/datasets/eval_datasets)** +- **[`nemo_rl.data.datasets.preference_datasets`](/nemo-rl/nemo_rl/data/datasets/preference_datasets)** +- **[`nemo_rl.data.datasets.response_datasets`](/nemo-rl/nemo_rl/data/datasets/response_datasets)** + +## Submodules + +- **[`nemo_rl.data.datasets.processed_dataset`](/nemo-rl/nemo_rl/data/datasets/processed_dataset)** +- **[`nemo_rl.data.datasets.raw_dataset`](/nemo-rl/nemo_rl/data/datasets/raw_dataset)** +- **[`nemo_rl.data.datasets.utils`](/nemo-rl/nemo_rl/data/datasets/utils)** + +## Package Contents + +### Data + +[`__all__`](#nemo_rl-data-datasets-__all__) + +### API + + + + + +```python +nemo_rl.data.datasets.__all__ = ['AllTaskProcessedDataset', 'load_eval_dataset', 'load_preference_dataset', 'loa... +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets.mdx new file mode 100644 index 0000000..433590f --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets.mdx @@ -0,0 +1,60 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/eval_datasets +title: nemo_rl.data.datasets.eval_datasets +--- + +## Submodules + +- **[`nemo_rl.data.datasets.eval_datasets.aime`](/nemo-rl/nemo_rl/data/datasets/eval_datasets/aime)** +- **[`nemo_rl.data.datasets.eval_datasets.gpqa`](/nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa)** +- **[`nemo_rl.data.datasets.eval_datasets.local_math_dataset`](/nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset)** +- **[`nemo_rl.data.datasets.eval_datasets.math`](/nemo-rl/nemo_rl/data/datasets/eval_datasets/math)** +- **[`nemo_rl.data.datasets.eval_datasets.mmlu`](/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu)** +- **[`nemo_rl.data.datasets.eval_datasets.mmlu_pro`](/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`load_eval_dataset`](#nemo_rl-data-datasets-eval_datasets-load_eval_dataset) | Loads evaluation dataset. | + +### Data + +[`__all__`](#nemo_rl-data-datasets-eval_datasets-__all__) + +### API + + + + + +```python +nemo_rl.data.datasets.eval_datasets.load_eval_dataset( + data_config +) +``` + + + + + + +Loads evaluation dataset. + + + + + + + + +```python +nemo_rl.data.datasets.eval_datasets.__all__ = ['AIMEDataset', 'GPQADataset', 'LocalMathDataset', 'MathDataset', 'MMLUDataset',... +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/aime.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/aime.mdx new file mode 100644 index 0000000..155c936 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/aime.mdx @@ -0,0 +1,64 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/aime +title: nemo_rl.data.datasets.eval_datasets.aime +--- + +AIME dataset. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AIMEDataset`](#nemo_rl-data-datasets-eval_datasets-aime-AIMEDataset) | - | + +### API + + + + + +```python +class nemo_rl.data.datasets.eval_datasets.aime.AIMEDataset( + variant: typing.Literal['2024', '2025'] = '2025', + prompt_file: typing.Optional[str] = None, + system_prompt_file: typing.Optional[str] = None +) +``` + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.eval_datasets.aime.AIMEDataset._rekey( + data: dict[str, typing.Any] +) +``` + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa.mdx new file mode 100644 index 0000000..d1ca3a9 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa.mdx @@ -0,0 +1,64 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa +title: nemo_rl.data.datasets.eval_datasets.gpqa +--- + +GPQA dataset and its variants. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`GPQADataset`](#nemo_rl-data-datasets-eval_datasets-gpqa-GPQADataset) | - | + +### API + + + + + +```python +class nemo_rl.data.datasets.eval_datasets.gpqa.GPQADataset( + variant: typing.Literal['diamond', 'main'] = 'diamond', + prompt_file: typing.Optional[str] = None, + system_prompt_file: typing.Optional[str] = None +) +``` + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.eval_datasets.gpqa.GPQADataset._rekey( + data: dict[str, typing.Any] +) +``` + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset.mdx new file mode 100644 index 0000000..e6d6754 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset.mdx @@ -0,0 +1,65 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset +title: nemo_rl.data.datasets.eval_datasets.local_math_dataset +--- + +Local math dataset. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`LocalMathDataset`](#nemo_rl-data-datasets-eval_datasets-local_math_dataset-LocalMathDataset) | - | + +### API + + + + + +```python +class nemo_rl.data.datasets.eval_datasets.local_math_dataset.LocalMathDataset( + data_path: str, + problem_key: str, + solution_key: str, + split: typing.Optional[str] = None, + file_format: typing.Literal['csv', 'json'] = 'csv', + prompt_file: typing.Optional[str] = None, + system_prompt_file: typing.Optional[str] = None +) +``` + + + + + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.eval_datasets.local_math_dataset.LocalMathDataset._rekey( + data: dict[str, typing.Any] +) +``` + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/math.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/math.mdx new file mode 100644 index 0000000..c00f375 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/math.mdx @@ -0,0 +1,61 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/math +title: nemo_rl.data.datasets.eval_datasets.math +--- + +Math dataset and its variants. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MathDataset`](#nemo_rl-data-datasets-eval_datasets-math-MathDataset) | - | + +### API + + + + + +```python +class nemo_rl.data.datasets.eval_datasets.math.MathDataset( + variant: typing.Literal['math_test', 'math_500_test'] = 'math_test', + prompt_file: typing.Optional[str] = None, + system_prompt_file: typing.Optional[str] = None +) +``` + + + + + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.eval_datasets.math.MathDataset._rekey( + data: dict[str, typing.Any] +) +``` + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu.mdx new file mode 100644 index 0000000..1114133 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu.mdx @@ -0,0 +1,61 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu +title: nemo_rl.data.datasets.eval_datasets.mmlu +--- + +MMLU dataset and its variants. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MMLUDataset`](#nemo_rl-data-datasets-eval_datasets-mmlu-MMLUDataset) | - | + +### API + + + + + +```python +class nemo_rl.data.datasets.eval_datasets.mmlu.MMLUDataset( + language: typing.Literal['AR-XY', 'BN-BD', 'DE-DE', 'EN-US', 'ES-LA', 'FR-FR', 'HI-IN', 'ID-ID', 'IT-IT', 'JA-JP', 'KO-KR', 'PT-BR', 'ZH-CN', 'SW-KE', 'YO-NG'] = 'EN-US', + prompt_file: typing.Optional[str] = None, + system_prompt_file: typing.Optional[str] = None +) +``` + + + + + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.eval_datasets.mmlu.MMLUDataset._rekey( + data: dict[str, typing.Any] +) +``` + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro.mdx new file mode 100644 index 0000000..998a593 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro.mdx @@ -0,0 +1,60 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro +title: nemo_rl.data.datasets.eval_datasets.mmlu_pro +--- + +MMLU-Pro dataset. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MMLUProDataset`](#nemo_rl-data-datasets-eval_datasets-mmlu_pro-MMLUProDataset) | - | + +### API + + + + + +```python +class nemo_rl.data.datasets.eval_datasets.mmlu_pro.MMLUProDataset( + prompt_file: str, + system_prompt_file: typing.Optional[str] = None +) +``` + + + + + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.eval_datasets.mmlu_pro.MMLUProDataset._rekey( + data: dict[str, typing.Any] +) +``` + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets.mdx new file mode 100644 index 0000000..1b101aa --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets.mdx @@ -0,0 +1,72 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/preference_datasets +title: nemo_rl.data.datasets.preference_datasets +--- + +## Submodules + +- **[`nemo_rl.data.datasets.preference_datasets.binary_preference_dataset`](/nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset)** +- **[`nemo_rl.data.datasets.preference_datasets.helpsteer3`](/nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3)** +- **[`nemo_rl.data.datasets.preference_datasets.preference_dataset`](/nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset)** +- **[`nemo_rl.data.datasets.preference_datasets.tulu3`](/nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`load_preference_dataset`](#nemo_rl-data-datasets-preference_datasets-load_preference_dataset) | Loads preference dataset. | + +### Data + +[`DATASET_REGISTRY`](#nemo_rl-data-datasets-preference_datasets-DATASET_REGISTRY) + +[`__all__`](#nemo_rl-data-datasets-preference_datasets-__all__) + +### API + + + + + +```python +nemo_rl.data.datasets.preference_datasets.load_preference_dataset( + data_config: nemo_rl.data.PreferenceDatasetConfig +) +``` + + + + + + +Loads preference dataset. + + + + + + + + +```python +nemo_rl.data.datasets.preference_datasets.DATASET_REGISTRY = {'HelpSteer3': HelpSteer3Dataset, 'Tulu3Preference': Tulu3PreferenceDataset, 'Bi... +``` + + + + + + + + + +```python +nemo_rl.data.datasets.preference_datasets.__all__ = ['BinaryPreferenceDataset', 'HelpSteer3Dataset', 'PreferenceDataset', 'Tulu3Pref... +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.mdx new file mode 100644 index 0000000..762ddd7 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.mdx @@ -0,0 +1,102 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset +title: nemo_rl.data.datasets.preference_datasets.binary_preference_dataset +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BinaryPreferenceDataset`](#nemo_rl-data-datasets-preference_datasets-binary_preference_dataset-BinaryPreferenceDataset) | Dataset class for binary preference data which can be loaded from a JSON file. | + +### API + + + + + +```python +class nemo_rl.data.datasets.preference_datasets.binary_preference_dataset.BinaryPreferenceDataset( + data_path: str, + prompt_key: str = 'prompt', + chosen_key: str = 'chosen', + rejected_key: str = 'rejected', + subset: typing.Optional[str] = None, + split: typing.Optional[str] = None, + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Dataset class for binary preference data which can be loaded from a JSON file. + +This class handles loading of preference data for DPO and RM training. +It will be converted to the format of PreferenceDataset through the `to_preference_data_format` function. + +The input JSONL files should contain valid JSON objects formatted like this: +{ + prompt_key: str, # The input prompt/context + chosen_key: str, # The preferred/winning response + rejected_key: str, # The non-preferred/losing response +} +Please refer to https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/dpo.md#datasets for more details. + +**Parameters:** + + +Path to the dataset JSON file + + + +Key for the input prompt/context, default is "prompt" + + + +Key for the preferred/winning response, default is "chosen" + + + +Key for the non-preferred/losing response, default is "rejected" + + + +Optional subset name for the dataset, used for HuggingFace datasets + + + +Optional split name for the dataset, used for HuggingFace datasets + + + + + + + + + + + + +```python +nemo_rl.data.datasets.preference_datasets.binary_preference_dataset.BinaryPreferenceDataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3.mdx new file mode 100644 index 0000000..a88c29e --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3.mdx @@ -0,0 +1,66 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3 +title: nemo_rl.data.datasets.preference_datasets.helpsteer3 +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`HelpSteer3Dataset`](#nemo_rl-data-datasets-preference_datasets-helpsteer3-HelpSteer3Dataset) | HelpSteer3 preference dataset for DPO training. | + +### API + + + + + +```python +class nemo_rl.data.datasets.preference_datasets.helpsteer3.HelpSteer3Dataset( + split: str = 'train', + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +HelpSteer3 preference dataset for DPO training. + +**Parameters:** + + +Split name for the dataset, default is "train" + + + + + + + + + + + + +```python +nemo_rl.data.datasets.preference_datasets.helpsteer3.HelpSteer3Dataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset.mdx new file mode 100644 index 0000000..0264eec --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset.mdx @@ -0,0 +1,77 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset +title: nemo_rl.data.datasets.preference_datasets.preference_dataset +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`PreferenceDataset`](#nemo_rl-data-datasets-preference_datasets-preference_dataset-PreferenceDataset) | Dataset class for preference data which can be loaded from a JSON file. | + +### API + + + + + +```python +class nemo_rl.data.datasets.preference_datasets.preference_dataset.PreferenceDataset( + data_path: str, + subset: typing.Optional[str] = None, + split: typing.Optional[str] = None, + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Dataset class for preference data which can be loaded from a JSON file. + +This class handles loading of preference data for DPO and RM training. +The input JSONL files should contain valid JSON objects formatted like this: +{ + "context": list[dict], # The prompt message (including previous turns, if any) + "completions": [ # The list of completions + { + "rank": 0, # The rank of the completion (lower rank is preferred) + "completion": list[dict], # The completion message(s) + }, + { + "rank": 1, # The rank of the completion (lower rank is preferred) + "completion": list[dict], # The completion message(s) + }, + ... # More completions can be added if needed + ] +} +Please refer to https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/dpo.md#datasets for more details. + +**Parameters:** + + +Path to the dataset JSON file + + + +Optional subset name for the dataset, used for HuggingFace datasets + + + +Optional split name for the dataset, used for HuggingFace datasets + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3.mdx new file mode 100644 index 0000000..0a7c89c --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3.mdx @@ -0,0 +1,59 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3 +title: nemo_rl.data.datasets.preference_datasets.tulu3 +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`Tulu3PreferenceDataset`](#nemo_rl-data-datasets-preference_datasets-tulu3-Tulu3PreferenceDataset) | Tulu3 preference dataset for DPO training. | + +### API + + + + + +```python +class nemo_rl.data.datasets.preference_datasets.tulu3.Tulu3PreferenceDataset( + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Tulu3 preference dataset for DPO training. + + + + + + + + + + + +```python +nemo_rl.data.datasets.preference_datasets.tulu3.Tulu3PreferenceDataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/processed_dataset.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/processed_dataset.mdx new file mode 100644 index 0000000..130991c --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/processed_dataset.mdx @@ -0,0 +1,135 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/processed_dataset +title: nemo_rl.data.datasets.processed_dataset +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AllTaskProcessedDataset`](#nemo_rl-data-datasets-processed_dataset-AllTaskProcessedDataset) | Dataset for processing single or multi-task data with task-specific tokenization and processing. | + +### Data + +[`TokenizerType`](#nemo_rl-data-datasets-processed_dataset-TokenizerType) + +### API + + + + + +```python +class nemo_rl.data.datasets.processed_dataset.AllTaskProcessedDataset( + dataset: datasets.Dataset | typing.Any, + tokenizer: nemo_rl.data.datasets.processed_dataset.TokenizerType, + default_task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + task_data_processors: dict[str, tuple[nemo_rl.data.interfaces.TaskDataSpec, nemo_rl.data.interfaces.TaskDataProcessFnCallable]] | nemo_rl.data.interfaces.TaskDataProcessFnCallable, + max_seq_length: typing.Optional[int] = None +) +``` + + + + + + +Dataset for processing single or multi-task data with task-specific tokenization and processing. + +**Parameters:** + + +Input dataset containing raw data + + + +Tokenizer for text processing + + + +Default task processing specifications. +In the case of single-task, this is the spec used for processing all entries. +In the case of multi-task, any values not specified in the task-specific specs will be taken from the default spec. + + + +Either a single TaskDataProcessFnCallable for single-task, +or a dict mapping task names to (TaskDataSpec, TaskDataProcessFnCallable) for multi-task + + + +Maximum sequence length for tokenized outputs + + + + + + + +```python +nemo_rl.data.datasets.processed_dataset.AllTaskProcessedDataset.__getitem__( + idx: int +) -> nemo_rl.data.interfaces.DatumSpec +``` + + + + + + +Return a single prompt. + + + + + + + +```python +nemo_rl.data.datasets.processed_dataset.AllTaskProcessedDataset.__len__() -> int +``` + + + + + + + + + + + + +```python +nemo_rl.data.datasets.processed_dataset.AllTaskProcessedDataset.encode_single( + text: typing.Union[str, list[str]] +) -> tuple[list[int] | torch.Tensor, int] +``` + + + + + + +Takes either a single string or a list of strings that represent multiple turns for the same conversation. + +Returns a single (concatenated) list of tokenized ids and the length of the tokenized ids. + + + + + + + + + +```python +nemo_rl.data.datasets.processed_dataset.TokenizerType = Union[PreTrainedTokenizerBase, AutoProcessor] +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/raw_dataset.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/raw_dataset.mdx new file mode 100644 index 0000000..af7a37b --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/raw_dataset.mdx @@ -0,0 +1,94 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/raw_dataset +title: nemo_rl.data.datasets.raw_dataset +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`RawDataset`](#nemo_rl-data-datasets-raw_dataset-RawDataset) | - | + +### API + + + + + +```python +class nemo_rl.data.datasets.raw_dataset.RawDataset() +``` + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.raw_dataset.RawDataset.set_processor() +``` + + + + + + + + + + + + +```python +nemo_rl.data.datasets.raw_dataset.RawDataset.set_task_spec( + data_config: nemo_rl.data.ResponseDatasetConfig | nemo_rl.data.PreferenceDatasetConfig +) +``` + + + + + + + + + + + + +```python +nemo_rl.data.datasets.raw_dataset.RawDataset.split_train_validation( + test_size: float, + seed: int +) +``` + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets.mdx new file mode 100644 index 0000000..3892488 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets.mdx @@ -0,0 +1,82 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets +title: nemo_rl.data.datasets.response_datasets +--- + +## Submodules + +- **[`nemo_rl.data.datasets.response_datasets.aime24`](/nemo-rl/nemo_rl/data/datasets/response_datasets/aime24)** +- **[`nemo_rl.data.datasets.response_datasets.clevr`](/nemo-rl/nemo_rl/data/datasets/response_datasets/clevr)** +- **[`nemo_rl.data.datasets.response_datasets.dapo_math`](/nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math)** +- **[`nemo_rl.data.datasets.response_datasets.deepscaler`](/nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler)** +- **[`nemo_rl.data.datasets.response_datasets.geometry3k`](/nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k)** +- **[`nemo_rl.data.datasets.response_datasets.helpsteer3`](/nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3)** +- **[`nemo_rl.data.datasets.response_datasets.nemogym_dataset`](/nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset)** +- **[`nemo_rl.data.datasets.response_datasets.oai_format_dataset`](/nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset)** +- **[`nemo_rl.data.datasets.response_datasets.oasst`](/nemo-rl/nemo_rl/data/datasets/response_datasets/oasst)** +- **[`nemo_rl.data.datasets.response_datasets.openmathinstruct2`](/nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2)** +- **[`nemo_rl.data.datasets.response_datasets.refcoco`](/nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco)** +- **[`nemo_rl.data.datasets.response_datasets.response_dataset`](/nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset)** +- **[`nemo_rl.data.datasets.response_datasets.squad`](/nemo-rl/nemo_rl/data/datasets/response_datasets/squad)** +- **[`nemo_rl.data.datasets.response_datasets.tulu3`](/nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`load_response_dataset`](#nemo_rl-data-datasets-response_datasets-load_response_dataset) | Loads response dataset. | + +### Data + +[`DATASET_REGISTRY`](#nemo_rl-data-datasets-response_datasets-DATASET_REGISTRY) + +[`__all__`](#nemo_rl-data-datasets-response_datasets-__all__) + +### API + + + + + +```python +nemo_rl.data.datasets.response_datasets.load_response_dataset( + data_config: nemo_rl.data.ResponseDatasetConfig +) +``` + + + + + + +Loads response dataset. + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.DATASET_REGISTRY = {'AIME2024': AIME2024Dataset, 'clevr-cogent': CLEVRCoGenTDataset, 'DAPOMath17K':... +``` + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.__all__ = ['AIME2024Dataset', 'CLEVRCoGenTDataset', 'DAPOMath17KDataset', 'DAPOMathAIME202... +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/aime24.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/aime24.mdx new file mode 100644 index 0000000..334fee4 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/aime24.mdx @@ -0,0 +1,66 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/aime24 +title: nemo_rl.data.datasets.response_datasets.aime24 +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AIME2024Dataset`](#nemo_rl-data-datasets-response_datasets-aime24-AIME2024Dataset) | Simple wrapper around the AIME2024 dataset with train split. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.aime24.AIME2024Dataset( + repeat: int = 16, + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the AIME2024 dataset with train split. + +**Parameters:** + + +Number of times to repeat the dataset, default is 16 + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.aime24.AIME2024Dataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/clevr.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/clevr.mdx new file mode 100644 index 0000000..2bf7236 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/clevr.mdx @@ -0,0 +1,97 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/clevr +title: nemo_rl.data.datasets.response_datasets.clevr +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`CLEVRCoGenTDataset`](#nemo_rl-data-datasets-response_datasets-clevr-CLEVRCoGenTDataset) | Simple wrapper around the CLEVR-CoGenT dataset. | + +### Functions + +| Name | Description | +|------|-------------| +| [`format_answer_fromtags`](#nemo_rl-data-datasets-response_datasets-clevr-format_answer_fromtags) | Extract content between <answer> tags and strip whitespace. | +| [`format_clevr_cogent_dataset`](#nemo_rl-data-datasets-response_datasets-clevr-format_clevr_cogent_dataset) | Format the CLEVR-CoGenT dataset into an OpenAI-API-like message log. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.clevr.CLEVRCoGenTDataset( + split: str = 'train', + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the CLEVR-CoGenT dataset. + +**Parameters:** + + +Split name for the dataset, default is "train" + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.clevr.format_answer_fromtags( + answer: str +) -> str +``` + + + + + + +Extract content between <answer> tags and strip whitespace. + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.clevr.format_clevr_cogent_dataset( + example: dict[str, typing.Any], + return_pil: bool = False +) -> dict[str, typing.Any] +``` + + + + + + +Format the CLEVR-CoGenT dataset into an OpenAI-API-like message log. + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math.mdx new file mode 100644 index 0000000..6866067 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math.mdx @@ -0,0 +1,84 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math +title: nemo_rl.data.datasets.response_datasets.dapo_math +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`DAPOMath17KDataset`](#nemo_rl-data-datasets-response_datasets-dapo_math-DAPOMath17KDataset) | Simple wrapper around the DAPO Math 17K dataset with train split. | +| [`DAPOMathAIME2024Dataset`](#nemo_rl-data-datasets-response_datasets-dapo_math-DAPOMathAIME2024Dataset) | - | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.dapo_math.DAPOMath17KDataset( + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the DAPO Math 17K dataset with train split. + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.dapo_math.DAPOMath17KDataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + + + + + + +```python +class nemo_rl.data.datasets.response_datasets.dapo_math.DAPOMathAIME2024Dataset( + kwargs = {} +) +``` + + + + + + +**Bases:** [DAPOMath17KDataset](#nemo_rl-data-datasets-response_datasets-dapo_math-DAPOMath17KDataset) + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler.mdx new file mode 100644 index 0000000..e1a6e7d --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler.mdx @@ -0,0 +1,59 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler +title: nemo_rl.data.datasets.response_datasets.deepscaler +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`DeepScalerDataset`](#nemo_rl-data-datasets-response_datasets-deepscaler-DeepScalerDataset) | Simple wrapper around the DeepScaler dataset with train split. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.deepscaler.DeepScalerDataset( + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the DeepScaler dataset with train split. + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.deepscaler.DeepScalerDataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k.mdx new file mode 100644 index 0000000..a4be5a2 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k.mdx @@ -0,0 +1,76 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k +title: nemo_rl.data.datasets.response_datasets.geometry3k +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`Geometry3KDataset`](#nemo_rl-data-datasets-response_datasets-geometry3k-Geometry3KDataset) | Simple wrapper around the Geometry3K dataset. | + +### Functions + +| Name | Description | +|------|-------------| +| [`format_geometry3k_dataset`](#nemo_rl-data-datasets-response_datasets-geometry3k-format_geometry3k_dataset) | Format the Geometry3K dataset into an OpenAI-API-like message log. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.geometry3k.Geometry3KDataset( + split: str = 'train', + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the Geometry3K dataset. + +**Parameters:** + + +Split name for the dataset, default is "train" + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.geometry3k.format_geometry3k_dataset( + example: dict[str, typing.Any], + return_pil: bool = False +) -> dict[str, typing.Any] +``` + + + + + + +Format the Geometry3K dataset into an OpenAI-API-like message log. + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3.mdx new file mode 100644 index 0000000..7176bf4 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3.mdx @@ -0,0 +1,66 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3 +title: nemo_rl.data.datasets.response_datasets.helpsteer3 +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`HelpSteer3Dataset`](#nemo_rl-data-datasets-response_datasets-helpsteer3-HelpSteer3Dataset) | Simple wrapper around the HelpSteer3 dataset with preference subset. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.helpsteer3.HelpSteer3Dataset( + split: str = 'train', + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the HelpSteer3 dataset with preference subset. + +**Parameters:** + + +Split name for the dataset, default is "train" + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.helpsteer3.HelpSteer3Dataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset.mdx new file mode 100644 index 0000000..54915b0 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset.mdx @@ -0,0 +1,54 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset +title: nemo_rl.data.datasets.response_datasets.nemogym_dataset +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`NemoGymDataset`](#nemo_rl-data-datasets-response_datasets-nemogym_dataset-NemoGymDataset) | Simple wrapper around the Nemo Gym dataset. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.nemogym_dataset.NemoGymDataset( + data_path: str, + repeat: int = 1, + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the Nemo Gym dataset. + +**Parameters:** + + +Path to the dataset JSONL file + + + +Number of times to repeat the dataset, default is 1 + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset.mdx new file mode 100644 index 0000000..f1c75e2 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset.mdx @@ -0,0 +1,214 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset +title: nemo_rl.data.datasets.response_datasets.oai_format_dataset +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`OpenAIFormatDataset`](#nemo_rl-data-datasets-response_datasets-oai_format_dataset-OpenAIFormatDataset) | This class is used to load an SFT dataset in the OpenAI format. | +| [`PreservingDataset`](#nemo_rl-data-datasets-response_datasets-oai_format_dataset-PreservingDataset) | A dataset wrapper that preserves original dict structure without None-filling. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.oai_format_dataset.OpenAIFormatDataset( + data_path: str, + chat_key: str = 'messages', + system_key: str | None = None, + system_prompt: str | None = None, + tool_key: str | None = 'tools', + use_preserving_dataset: bool = False, + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +This class is used to load an SFT dataset in the OpenAI format. + +The dataset should be in the following format: +{ + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + {"role": "assistant", "content": "The capital of France is Paris."} + ] +} +Please refer to https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#openai-format-datasets-with-tool-calling-support for more details. + +**Parameters:** + + +Path to the dataset JSON file + + + +Key for the messages list in the dataset (default: "messages") + + + +Optional key for system prompt in the dataset + + + +Optional system prompt to add if not in the dataset + + + +Key for tools in the dataset (default: "tools") + + + +If True, uses PreservingDataset to maintain +heterogeneous schemas (e.g., for tool calls with varying argument +structures). If False, uses standard HuggingFace dataset loading. +Default is False for backward compatibility. + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.oai_format_dataset.OpenAIFormatDataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + + + + + + +```python +class nemo_rl.data.datasets.response_datasets.oai_format_dataset.PreservingDataset( + data: list[dict[str, typing.Any]] +) +``` + + + + + + +A dataset wrapper that preserves original dict structure without None-filling. + +Unlike HuggingFace's Dataset class which enforces schema uniformity across all samples +(filling missing keys with None), this class maintains the exact structure of each sample. +This is critical for heterogeneous data like tool calls where different samples may have +different argument structures. + + + + + + +```python +nemo_rl.data.datasets.response_datasets.oai_format_dataset.PreservingDataset.__getitem__( + idx: typing.Union[int, slice, list] +) -> typing.Union[dict[str, typing.Any], list[dict[str, typing.Any]]] +``` + + + + + + +Support integer indexing, slicing, and list indexing. + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.oai_format_dataset.PreservingDataset.__iter__() +``` + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.oai_format_dataset.PreservingDataset.__len__() -> int +``` + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.oai_format_dataset.PreservingDataset.map( + function: typing.Callable, + args = (), + kwargs = {} +) -> nemo_rl.data.datasets.response_datasets.oai_format_dataset.PreservingDataset +``` + + + + + + +Apply a function to each sample in the dataset. + +**Parameters:** + + +Function to apply to each sample + + + +If True, pass index as second argument to function + + +**Returns:** `PreservingDataset` + +New PreservingDataset with transformed samples + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oasst.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oasst.mdx new file mode 100644 index 0000000..8a92408 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oasst.mdx @@ -0,0 +1,127 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/oasst +title: nemo_rl.data.datasets.response_datasets.oasst +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`OasstDataset`](#nemo_rl-data-datasets-response_datasets-oasst-OasstDataset) | Simple wrapper around the OASST dataset. | + +### Functions + +| Name | Description | +|------|-------------| +| [`get_data_records`](#nemo_rl-data-datasets-response_datasets-oasst-get_data_records) | - | +| [`parse_conversations`](#nemo_rl-data-datasets-response_datasets-oasst-parse_conversations) | Recusive function that returns all the sub converstaions in a list starting from node tree_obj. | + +### Data + +[`SYSTEM_PROMPT`](#nemo_rl-data-datasets-response_datasets-oasst-SYSTEM_PROMPT) + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.oasst.OasstDataset( + split_validation_size: float = 0.05, + seed: int = 42, + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the OASST dataset. + +**Parameters:** + + +Size of the validation data, default is 0.05 + + + +Seed for train/validation split when split_validation_size > 0, default is 42 + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.oasst.get_data_records( + objs, + task_name: str = 'oasst' +) +``` + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.oasst.parse_conversations( + tree_obj, + first: bool = False +) +``` + + + + + + +Recusive function that returns all the sub converstaions in a list starting from node tree_obj. + +**Parameters:** + + +current conversation node + + +**Returns:** + +a list of sub conversation threads including the current conversation node + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.oasst.SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. The ass... +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2.mdx new file mode 100644 index 0000000..6124a92 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2.mdx @@ -0,0 +1,84 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2 +title: nemo_rl.data.datasets.response_datasets.openmathinstruct2 +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`OpenMathInstruct2Dataset`](#nemo_rl-data-datasets-response_datasets-openmathinstruct2-OpenMathInstruct2Dataset) | Simple wrapper around the OpenMathInstruct2 dataset. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.openmathinstruct2.OpenMathInstruct2Dataset( + output_key: str = 'expected_answer', + split: str = 'train_1M', + split_validation_size: float = 0.05, + seed: int = 42, + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the OpenMathInstruct2 dataset. + +**Parameters:** + + +Key for the output text, default is "expected_answer" + + + +Split name for the dataset, default is "train_1M" + + + +Size of the validation data, default is 0.05 + + + +Seed for train/validation split when split_validation_size > 0, default is 42 + + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.openmathinstruct2.OpenMathInstruct2Dataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco.mdx new file mode 100644 index 0000000..34e0b8d --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco.mdx @@ -0,0 +1,160 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco +title: nemo_rl.data.datasets.response_datasets.refcoco +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`RefCOCODataset`](#nemo_rl-data-datasets-response_datasets-refcoco-RefCOCODataset) | Simple wrapper around the RefCOCO dataset. | + +### Functions + +| Name | Description | +|------|-------------| +| [`download_and_unzip`](#nemo_rl-data-datasets-response_datasets-refcoco-download_and_unzip) | Downloads a zip file from a given URL to a target directory and unzips it into a specified subdirectory within the target directory, showing download progress. | +| [`format_refcoco_dataset`](#nemo_rl-data-datasets-response_datasets-refcoco-format_refcoco_dataset) | Format the RefCOCO dataset from huggingface. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.refcoco.RefCOCODataset( + split: str = 'train', + download_dir: str = './coco_images', + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the RefCOCO dataset. + +**Parameters:** + + +Split name for the dataset, default is "train" + + + +Directory to download the dataset to, default is "./coco_images" + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.refcoco.RefCOCODataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.refcoco.download_and_unzip( + url: str, + target_directory: str, + subdir_name: str = '.' +) +``` + + + + + + +Downloads a zip file from a given URL to a target directory and unzips it into a specified subdirectory within the target directory, showing download progress. + +**Parameters:** + + +The URL of the zip file to download. + + + +The directory where the zip file will be downloaded + and unzipped. + + + +The name of the subdirectory within the target_directory + where the contents of the zip file will be unzipped. + Defaults to "train". + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.refcoco.format_refcoco_dataset( + example: dict[str, typing.Any], + width: int = 256, + height: int = 256, + caption_type: str = 'random' +) -> dict[str, typing.Any] +``` + + + + + + +Format the RefCOCO dataset from huggingface. + +This should be replaced with our own curated RefCOCO/+/g dataset soon + +**Parameters:** + + +The example to format. + + + +The width of the resized image. + + + +The height of the resized image. + + + +The type of caption to use. + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset.mdx new file mode 100644 index 0000000..09bbf07 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset.mdx @@ -0,0 +1,104 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset +title: nemo_rl.data.datasets.response_datasets.response_dataset +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ResponseDataset`](#nemo_rl-data-datasets-response_datasets-response_dataset-ResponseDataset) | Dataset class for response data which can be loaded from a JSON file. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.response_dataset.ResponseDataset( + data_path: str, + input_key: str = 'input', + output_key: str = 'output', + subset: typing.Optional[str] = None, + split: typing.Optional[str] = None, + split_validation_size: float = 0, + seed: int = 42, + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Dataset class for response data which can be loaded from a JSON file. + +This class handles loading of response data for SFT and RL training. +The input JSONL files should contain valid JSON objects formatted like this: +{ + input_key: str, # The input prompt/context + output_key: str, # The output response/answer +} +Please refer to https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets for more details. + +**Parameters:** + + +Path to the dataset JSON file + + + +Key for the input text, default is "input" + + + +Key for the output text, default is "output" + + + +Optional subset name for the dataset, used for HuggingFace datasets + + + +Optional split name for the dataset, used for HuggingFace datasets + + + +Size of the validation data, default is 0 + + + +Seed for train/validation split when split_validation_size > 0, default is 42 + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.response_dataset.ResponseDataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/squad.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/squad.mdx new file mode 100644 index 0000000..f201b41 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/squad.mdx @@ -0,0 +1,66 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/squad +title: nemo_rl.data.datasets.response_datasets.squad +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`SquadDataset`](#nemo_rl-data-datasets-response_datasets-squad-SquadDataset) | Simple wrapper around the squad dataset. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.squad.SquadDataset( + split: str = 'train', + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the squad dataset. + +**Parameters:** + + +Split name for the dataset, default is "train" + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.squad.SquadDataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3.mdx new file mode 100644 index 0000000..68dfa11 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3.mdx @@ -0,0 +1,76 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3 +title: nemo_rl.data.datasets.response_datasets.tulu3 +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`Tulu3SftMixtureDataset`](#nemo_rl-data-datasets-response_datasets-tulu3-Tulu3SftMixtureDataset) | Simple wrapper around the Tulu3 SFT mixture dataset with train split. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.tulu3.Tulu3SftMixtureDataset( + split_validation_size: float = 0.05, + seed: int = 42, + max_samples: int | None = None, + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the Tulu3 SFT mixture dataset with train split. + +**Parameters:** + + +Size of the validation data, default is 0.05 + + + +Seed for train/validation split when split_validation_size > 0, default is 42 + + + +Optional maximum number of samples to use from the dataset + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.tulu3.Tulu3SftMixtureDataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/utils.mdx new file mode 100644 index 0000000..d5a02db --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/utils.mdx @@ -0,0 +1,191 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/utils +title: nemo_rl.data.datasets.utils +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`assert_no_double_bos`](#nemo_rl-data-datasets-utils-assert_no_double_bos) | Assert that there are no double starting BOS tokens in the message. | +| [`extract_necessary_env_names`](#nemo_rl-data-datasets-utils-extract_necessary_env_names) | Extract the necessary environment names from the data config. | +| [`load_dataset_from_path`](#nemo_rl-data-datasets-utils-load_dataset_from_path) | Load a dataset from a local file, huggingface dataset, or Arrow dataset (saved with save_to_disk). | +| [`pil_to_base64`](#nemo_rl-data-datasets-utils-pil_to_base64) | Converts a PIL Image object to a base64 encoded string. | +| [`update_single_dataset_config`](#nemo_rl-data-datasets-utils-update_single_dataset_config) | Fill the single dataset config with default dataset config. | + +### Data + +[`TokenizerType`](#nemo_rl-data-datasets-utils-TokenizerType) + +### API + + + + + +```python +nemo_rl.data.datasets.utils.assert_no_double_bos( + token_ids: torch.Tensor, + tokenizer: nemo_rl.data.datasets.utils.TokenizerType +) -> None +``` + + + + + + +Assert that there are no double starting BOS tokens in the message. + +**Parameters:** + + +List of token IDs + + + +Tokenizer + + + + + + + + + +```python +nemo_rl.data.datasets.utils.extract_necessary_env_names( + data_config: dict +) -> list[str] +``` + + + + + + +Extract the necessary environment names from the data config. + +Some environments are set in env_configs but not used in the data config. +This function extracts the necessary environment names from the data config. + +**Parameters:** + + +The data config. + + +**Returns:** `list[str]` + +The necessary environment names. + + + + + + + + +```python +nemo_rl.data.datasets.utils.load_dataset_from_path( + data_path: str, + data_subset: typing.Optional[str] = None, + data_split: typing.Optional[str] = 'train' +) +``` + + + + + + +Load a dataset from a local file, huggingface dataset, or Arrow dataset (saved with save_to_disk). + +**Parameters:** + + +The path to the dataset. + + + +The subset to load from the dataset. Only supported for huggingface datasets. + + + +The split to load from the dataset. + + + + + + + + + +```python +nemo_rl.data.datasets.utils.pil_to_base64( + image: PIL.Image.Image, + format: str = 'PNG' +) -> str +``` + + + + + + +Converts a PIL Image object to a base64 encoded string. + +**Parameters:** + + +The PIL Image object to convert. + + + +The image format (e.g., "PNG", "JPEG"). Defaults to "PNG". + + +**Returns:** `str` + +A base64 encoded string representation of the image. + + + + + + + + +```python +nemo_rl.data.datasets.utils.update_single_dataset_config( + data_config: dict, + default_data_config: dict +) -> None +``` + + + + + + +Fill the single dataset config with default dataset config. + + + + + + + + +```python +nemo_rl.data.datasets.utils.TokenizerType = Union[PreTrainedTokenizerBase, AutoProcessor] +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/interfaces.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/interfaces.mdx new file mode 100644 index 0000000..b435173 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/interfaces.mdx @@ -0,0 +1,284 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/interfaces +title: nemo_rl.data.interfaces +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`DatumSpec`](#nemo_rl-data-interfaces-DatumSpec) | - | +| [`PreferenceDatumSpec`](#nemo_rl-data-interfaces-PreferenceDatumSpec) | - | +| [`TaskDataProcessFnCallable`](#nemo_rl-data-interfaces-TaskDataProcessFnCallable) | A callable that processes a loaded datum dictionary into a DatumSpec. | +| [`TaskDataSpec`](#nemo_rl-data-interfaces-TaskDataSpec) | - | + +### Data + +[`FlatMessagesType`](#nemo_rl-data-interfaces-FlatMessagesType) + +[`LLMMessageLogType`](#nemo_rl-data-interfaces-LLMMessageLogType) + +[`PathLike`](#nemo_rl-data-interfaces-PathLike) + +[`TokenizerType`](#nemo_rl-data-interfaces-TokenizerType) + +[`VLMMessageLogType`](#nemo_rl-data-interfaces-VLMMessageLogType) + +### API + + + + + +```python +class nemo_rl.data.interfaces.DatumSpec +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.interfaces.PreferenceDatumSpec +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.interfaces.TaskDataProcessFnCallable() +``` + + + + + + +Protocol + +A callable that processes a loaded datum dictionary into a DatumSpec. + + + + + + +```python +nemo_rl.data.interfaces.TaskDataProcessFnCallable.__call__( + datum_dict: dict[str, typing.Any], + task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + tokenizer: nemo_rl.data.interfaces.TokenizerType, + max_seq_length: int | None, + idx: int +) -> nemo_rl.data.interfaces.DatumSpec +``` + + + + + + + + + + + + + + +```python +class nemo_rl.data.interfaces.TaskDataSpec( + task_name: typing.Optional[str] = None, + prompt_file: typing.Optional[nemo_rl.data.interfaces.PathLike] = None, + system_prompt_file: typing.Optional[nemo_rl.data.interfaces.PathLike] = None +) +``` + + + + + + +Dataclass + + + + + + + + + + + + + +```python +nemo_rl.data.interfaces.TaskDataSpec.__post_init__() -> None +``` + + + + + + + + + + + + +```python +nemo_rl.data.interfaces.TaskDataSpec.copy_defaults( + from_spec: nemo_rl.data.interfaces.TaskDataSpec +) -> None +``` + + + + + + +Apply default values from another Task instance for any None attributes. + + + + + + + + + +```python +nemo_rl.data.interfaces.FlatMessagesType = dict[str, Union[list[str], torch.Tensor]] +``` + + + + + + + + + +```python +nemo_rl.data.interfaces.LLMMessageLogType = list[dict[str, Union[str, torch.Tensor]]] +``` + + + + + + + + + +```python +nemo_rl.data.interfaces.PathLike = Union[str, 'os.PathLike[Any]'] +``` + + + + + + + + + +```python +nemo_rl.data.interfaces.TokenizerType = PreTrainedTokenizerBase +``` + + + + + + + + + +```python +nemo_rl.data.interfaces.VLMMessageLogType = list[dict[str, Union[str, torch.Tensor, PackedTensor]]] +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/llm_message_utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/llm_message_utils.mdx new file mode 100644 index 0000000..49a124a --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/llm_message_utils.mdx @@ -0,0 +1,548 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/llm_message_utils +title: nemo_rl.data.llm_message_utils +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_pad_tensor`](#nemo_rl-data-llm_message_utils-_pad_tensor) | Pad a tensor to the specified length. | +| [`_validate_tensor_consistency`](#nemo_rl-data-llm_message_utils-_validate_tensor_consistency) | Validate that all tensors have consistent dtypes and devices. | +| [`add_loss_mask_to_message_log`](#nemo_rl-data-llm_message_utils-add_loss_mask_to_message_log) | Add token-level loss masks to each message in a message log. | +| [`batched_message_log_to_flat_message`](#nemo_rl-data-llm_message_utils-batched_message_log_to_flat_message) | Process and pad a batch of message logs for model input. | +| [`get_first_index_that_differs`](#nemo_rl-data-llm_message_utils-get_first_index_that_differs) | Get the first index that differs between two strings. | +| [`get_formatted_message_log`](#nemo_rl-data-llm_message_utils-get_formatted_message_log) | Format and tokenize chat messages using the specified template. | +| [`get_images_from_message`](#nemo_rl-data-llm_message_utils-get_images_from_message) | Get all images from a message log item. | +| [`get_keys_from_message_log`](#nemo_rl-data-llm_message_utils-get_keys_from_message_log) | Return a new LLMMessageLogType containing only the specified keys from each message. | +| [`message_log_shape`](#nemo_rl-data-llm_message_utils-message_log_shape) | Get the shape of the tensors in the message log. | +| [`message_log_to_flat_messages`](#nemo_rl-data-llm_message_utils-message_log_to_flat_messages) | Converts a message log (sequence of message turns) into a flattened representation. | +| [`remap_dataset_keys`](#nemo_rl-data-llm_message_utils-remap_dataset_keys) | Remap dataset keys as per mapping. | + +### Data + +[`Tensor`](#nemo_rl-data-llm_message_utils-Tensor) + +[`TokenizerType`](#nemo_rl-data-llm_message_utils-TokenizerType) + +### API + + + + + +```python +nemo_rl.data.llm_message_utils._pad_tensor( + tensor: nemo_rl.data.llm_message_utils.Tensor, + max_len: int, + pad_side: str, + pad_value: int = 0 +) -> nemo_rl.data.llm_message_utils.Tensor +``` + + + + + + +Pad a tensor to the specified length. + +**Parameters:** + + +Tensor to pad + + + +Length to pad to + + + +Whether to pad on the 'left' or 'right' + + + +Value to use for padding + + +**Returns:** `Tensor` + +torch.Tensor: Padded tensor + + + + + + + + +```python +nemo_rl.data.llm_message_utils._validate_tensor_consistency( + tensors: list[nemo_rl.data.llm_message_utils.Tensor] +) -> None +``` + + + + + + +Validate that all tensors have consistent dtypes and devices. + +**Parameters:** + + +List of tensors to validate + + +**Raises:** + +- `RuntimeError`: If tensors have different dtypes or devices + + + + + + + + +```python +nemo_rl.data.llm_message_utils.add_loss_mask_to_message_log( + batch_message_log: list[nemo_rl.data.interfaces.LLMMessageLogType], + roles_to_train_on: list[str] = ['assistant'], + only_unmask_final: bool = False +) -> None +``` + + + + + + +Add token-level loss masks to each message in a message log. + +**Parameters:** + + +List of message dictionaries containing token IDs and metadata + + + +List of strings indicating which speakers to unmask. Default: ["assistant"] + + + +If True, only unmask the final message in the log. Default: False + + + + + + + + + +```python +nemo_rl.data.llm_message_utils.batched_message_log_to_flat_message( + message_log_batch: list[nemo_rl.data.interfaces.LLMMessageLogType], + pad_value_dict: typing.Optional[dict[str, int]] = None, + make_sequence_length_divisible_by: int = 1 +) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.FlatMessagesType], nemo_rl.data.llm_message_utils.Tensor] +``` + + + + + + +Process and pad a batch of message logs for model input. + +For each message log in the batch: +1. Converts it to a flat representation using message_log_to_flat_messages +2. Pads all resulting tensors to the same length for batching +3. Returns a BatchedDataDict and sequence lengths tensor + +Padding is always applied to the right side of sequences. + +Examples: + + +```python +>>> import torch +>>> from nemo_rl.data.llm_message_utils import batched_message_log_to_flat_message +>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict +>>> # Create a batch of two message logs with different lengths +>>> message_log_batch = [ +... # First conversation +... [ +... {'role': 'user', 'content': 'What is 2+2?', 'token_ids': torch.tensor([1, 2, 3, 4, 5])}, +... {'role': 'assistant', 'content': '4', 'token_ids': torch.tensor([6, 7])} +... ], +... # Second conversation +... [ +... {'role': 'user', 'content': 'Solve x+10=15', 'token_ids': torch.tensor([1, 8, 9, 10, 11, 12])}, +... {'role': 'assistant', 'content': 'x=5', 'token_ids': torch.tensor([13, 14, 15])} +... ] +... ] +>>> pad_value_dict = {'token_ids': 0} +>>> batched_flat, input_lengths = batched_message_log_to_flat_message(message_log_batch, pad_value_dict) +>>> batched_flat['token_ids'][0].tolist() +[1, 2, 3, 4, 5, 6, 7, 0, 0] +>>> batched_flat['token_ids'][1].tolist() +[1, 8, 9, 10, 11, 12, 13, 14, 15] +>>> batched_flat['content'][0] +['What is 2+2?', '4'] +>>> batched_flat['content'][1] +['Solve x+10=15', 'x=5'] +>>> batched_flat['role'] +[['user', 'assistant'], ['user', 'assistant']] +>>> input_lengths +tensor([7, 9], dtype=torch.int32) +>>> +>>> # Multimodal example: include images on both conversations and verify packing +>>> from nemo_rl.data.multimodal_utils import PackedTensor +>>> mm_batch = [ +... [ +... {'role': 'user', 'content': 'look', 'token_ids': torch.tensor([1, 2, 3]), 'images': PackedTensor(torch.randn(2, 3, 4, 4), dim_to_pack=0)}, +... {'role': 'assistant', 'content': 'ok', 'token_ids': torch.tensor([4])} +... ], +... [ +... {'role': 'user', 'content': 'again', 'token_ids': torch.tensor([5, 6]), 'images': PackedTensor(torch.randn(1, 3, 4, 4), dim_to_pack=0)}, +... {'role': 'assistant', 'content': 'fine', 'token_ids': torch.tensor([7, 8])} +... ] +... ] +>>> mm_flat, mm_lengths = batched_message_log_to_flat_message(mm_batch, pad_value_dict={'token_ids': 0}) +>>> isinstance(mm_flat['images'], PackedTensor) +True +>>> tuple(mm_flat['images'].as_tensor().shape) # 2 + 1 images +(3, 3, 4, 4) +>>> mm_lengths +tensor([4, 4], dtype=torch.int32) +>>> +``` + + + +**Parameters:** + + +List of LLMMessageLogType (each a conversation with multiple turns) + + + +Dictionary mapping keys to padding values (default is 0) + + + +forces the data to be divisible by this value + + +**Returns:** `BatchedDataDict[FlatMessagesType]` + +BatchedDataDict[FlatMessagesType]: Dictionary containing padded stacked tensors + +**Raises:** + +- `RuntimeError`: If tensors have different dtypes or devices + + + + + + + + +```python +nemo_rl.data.llm_message_utils.get_first_index_that_differs( + str1: str, + str2: str +) -> int +``` + + + + + + +Get the first index that differs between two strings. + + + + + + + + +```python +nemo_rl.data.llm_message_utils.get_formatted_message_log( + message_log: nemo_rl.data.interfaces.LLMMessageLogType, + tokenizer: nemo_rl.data.llm_message_utils.TokenizerType, + task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + add_bos_token: bool = True, + add_eos_token: bool = True, + add_generation_prompt: bool = False, + tools: typing.Optional[list[dict[str, typing.Any]]] = None +) -> nemo_rl.data.interfaces.LLMMessageLogType +``` + + + + + + +Format and tokenize chat messages using the specified template. + +Returns: + The message log with updated 'token_ids' and 'content' fields. + +**Parameters:** + + +List of message dicts with 'role' and 'content' keys + + + +Tokenizer for converting text to token IDs + + + +Task spec for this dataset. + + + +Whether to add bos token to first message if it is not already present. Default: True + + + +Whether to add eos token to last message if it is not already present. Default: True + + + +Whether to include assistant's generation prompt in user messages. Default: False + + + +Optional list of tool/function definitions to pass to the chat template. Default: None + + + + + + + + + +```python +nemo_rl.data.llm_message_utils.get_images_from_message( + message: dict[str, typing.Any] +) -> list[typing.Any] +``` + + + + + + +Get all images from a message log item. + + + + + + + + +```python +nemo_rl.data.llm_message_utils.get_keys_from_message_log( + message_log: nemo_rl.data.interfaces.LLMMessageLogType, + keys: list[str] +) -> nemo_rl.data.interfaces.LLMMessageLogType +``` + + + + + + +Return a new LLMMessageLogType containing only the specified keys from each message. + +**Parameters:** + + +Original message log to extract keys from + + + +List of keys to keep in each message + + +**Returns:** `LLMMessageLogType` + +New list with only specified keys + + + + + + + + +```python +nemo_rl.data.llm_message_utils.message_log_shape( + message_log: nemo_rl.data.interfaces.LLMMessageLogType +) -> list[dict[str, torch.Size]] +``` + + + + + + +Get the shape of the tensors in the message log. + +This utility function examines each message in the message log and reports +the shape of tensor values or recursively processes list values. + +**Parameters:** + + +The message log to analyze + + +**Returns:** `list[dict[str, torch.Size]]` + +List of dictionaries containing tensor shapes for each key in messages + + + + + + + + +```python +nemo_rl.data.llm_message_utils.message_log_to_flat_messages( + message_log: nemo_rl.data.interfaces.LLMMessageLogType +) -> nemo_rl.data.interfaces.FlatMessagesType +``` + + + + + + +Converts a message log (sequence of message turns) into a flattened representation. + +This function takes a message log (list of dict messages with 'role', 'content', 'token_ids', etc.) +and converts it to a flat dictionary where all tensors of the same key are concatenated and +all strings of the same key are put into lists. + +Examples: + + +```python +>>> import torch +>>> from nemo_rl.data.llm_message_utils import message_log_to_flat_messages +>>> # Create a simple message log with two messages +>>> message_log = [ +... {'role': 'user', 'content': 'Hello', 'token_ids': torch.tensor([1, 2, 3])}, +... {'role': 'assistant', 'content': 'Hi there', 'token_ids': torch.tensor([4, 5, 6, 7])} +... ] +>>> flat_msgs = message_log_to_flat_messages(message_log) +>>> flat_msgs['role'] +['user', 'assistant'] +>>> flat_msgs['content'] +['Hello', 'Hi there'] +>>> flat_msgs['token_ids'] +tensor([1, 2, 3, 4, 5, 6, 7]) +>>> +>>> # Multimodal example: +>>> from nemo_rl.data.multimodal_utils import PackedTensor +>>> img1 = torch.randn(2, 3, 4, 4) +>>> img2 = torch.randn(3, 3, 4, 4) +>>> mm_log = [ +... {'role': 'user', 'content': 'see', 'token_ids': torch.tensor([1]), 'images': PackedTensor(img1, dim_to_pack=0)}, +... {'role': 'assistant', 'content': 'ok', 'token_ids': torch.tensor([2, 3]), 'images': PackedTensor(img2, dim_to_pack=0)}, +... ] +>>> flat_mm = message_log_to_flat_messages(mm_log) +>>> tuple(flat_mm['images'].as_tensor().shape) +(5, 3, 4, 4) +>>> +``` + + + +**Parameters:** + + +List of message dictionaries with 'role', 'content', and potentially 'token_ids' + + +**Returns:** `FlatMessagesType` + +Dictionary mapping keys to concatenated tensors and string lists + + + + + + + + +```python +nemo_rl.data.llm_message_utils.remap_dataset_keys( + dataset: datasets.Dataset, + mapping_dict: dict[str, str] +) -> datasets.Dataset +``` + + + + + + +Remap dataset keys as per mapping. + +**Parameters:** + + +The input dataset to remap keys in + + + +A dictionary mapping input keys to output keys + + +**Returns:** `Dataset` + +A new dataset with remapped keys + + + + + + + + +```python +nemo_rl.data.llm_message_utils.Tensor = torch.Tensor +``` + + + + + + + + + +```python +nemo_rl.data.llm_message_utils.TokenizerType = PreTrainedTokenizerBase +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/multimodal_utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/multimodal_utils.mdx new file mode 100644 index 0000000..89f71d0 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/multimodal_utils.mdx @@ -0,0 +1,298 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/multimodal_utils +title: nemo_rl.data.multimodal_utils +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`PackedTensor`](#nemo_rl-data-multimodal_utils-PackedTensor) | Wrapper around a list of torch tensors and a dimension along which to pack the tensors. | + +### Functions + +| Name | Description | +|------|-------------| +| [`get_dim_to_pack_along`](#nemo_rl-data-multimodal_utils-get_dim_to_pack_along) | Special considerations for packing certain keys from certain processors. | +| [`get_multimodal_keys_from_processor`](#nemo_rl-data-multimodal_utils-get_multimodal_keys_from_processor) | Get keys of the multimodal data that can be used as model inputs. | +| [`resolve_to_image`](#nemo_rl-data-multimodal_utils-resolve_to_image) | Resolve the image path to a PIL.Image object. | + +### API + + + + + +```python +class nemo_rl.data.multimodal_utils.PackedTensor( + tensors: typing.Union[torch.Tensor, list[typing.Optional[torch.Tensor]], list[None]], + dim_to_pack: int +) +``` + + + + + + +Wrapper around a list of torch tensors and a dimension along which to pack the tensors. + +This class is used to wrap a list of tensors along with a `dim_to_pack` parameter. +It can be used for data that can be packed along different dimensions (such as multimodal data). + +`dim_to_pack` is used to specify the dimension along which to pack the tensors. + +The list of tensors can be returned as a single packed tensor by calling `as_tensor` which will concatenate the tensors along the `dim_to_pack` dimension. + + + + + + + + +```python +nemo_rl.data.multimodal_utils.PackedTensor.__len__() -> int +``` + + + + + + + + + + + + +```python +nemo_rl.data.multimodal_utils.PackedTensor.as_tensor( + device: typing.Optional[torch.device] = None +) -> typing.Optional[torch.Tensor] +``` + + + + + + + + + + + + +```python +nemo_rl.data.multimodal_utils.PackedTensor.concat( + from_packed_tensors: list[nemo_rl.data.multimodal_utils.PackedTensor] +) -> nemo_rl.data.multimodal_utils.PackedTensor +``` + + + + + + +classmethod + +Concatenate a list of PackedTensor objects into a single PackedTensor. + +The underlying tensors from the PackedTensors are combined into a single list of tensors and used to create a new PackedTensor. + +Each batch must have the same dim_to_pack. + +Example: + + +```python +>>> import torch +>>> from nemo_rl.data.multimodal_utils import PackedTensor +>>> p1 = PackedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])], dim_to_pack=0) +>>> p2 = PackedTensor([torch.tensor([7, 8, 9])], dim_to_pack=0) +>>> p3 = PackedTensor.concat([p1, p2]) +>>> p3.tensors +[tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9])] +>>> p3.as_tensor() +tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]) +>>> +``` + + + + + + + + + +```python +nemo_rl.data.multimodal_utils.PackedTensor.empty_like( + other: nemo_rl.data.multimodal_utils.PackedTensor +) -> nemo_rl.data.multimodal_utils.PackedTensor +``` + + + + + + +classmethod + +Return a new PackedTensor with same length and dim_to_pack as `other`, with all entries None. + + + + + + + +```python +nemo_rl.data.multimodal_utils.PackedTensor.flattened_concat( + from_packed_tensors: list[nemo_rl.data.multimodal_utils.PackedTensor] +) -> nemo_rl.data.multimodal_utils.PackedTensor +``` + + + + + + +classmethod + +Given a list of PackedTensor objects, flattens each PackedTensor and then concatenates them into a single PackedTensor. + +Each PackedTensor is first flattened by packing along the PackedTensor's `dim_to_pack` dimension. Then, the resulting flattened tensors are used to create a new PackedTensor. + +This is different from `PackedTensor.concat` which simply extends the underlying list of tensors. This is important because the `slice` and `__len__` methods operate on the underlying list of tensors. Note, however, that calling `as_tensor` on the resulting PackedTensor will result in the same tensor as `concat`. + +Each batch must have the same dim_to_pack. + +Example: + + +```python +>>> import torch +>>> from nemo_rl.data.multimodal_utils import PackedTensor +>>> p1 = PackedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])], dim_to_pack=0) +>>> p2 = PackedTensor([torch.tensor([7, 8, 9])], dim_to_pack=0) +>>> p3 = PackedTensor.flattened_concat([p1, p2]) +>>> p3.tensors +[tensor([1, 2, 3, 4, 5, 6]), tensor([7, 8, 9])] +>>> p3.as_tensor() +tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]) +>>> +``` + + + + + + + + + +```python +nemo_rl.data.multimodal_utils.PackedTensor.slice( + indices: typing.Union[list[int], torch.Tensor] +) -> nemo_rl.data.multimodal_utils.PackedTensor +``` + + + + + + + + + + + + +```python +nemo_rl.data.multimodal_utils.PackedTensor.to( + device: str | torch.device +) -> nemo_rl.data.multimodal_utils.PackedTensor +``` + + + + + + + + + + + + + + +```python +nemo_rl.data.multimodal_utils.get_dim_to_pack_along( + processor, + key: str +) -> int +``` + + + + + + +Special considerations for packing certain keys from certain processors. + +In most cases, the packed items are along dim 0 + + + + + + + + +```python +nemo_rl.data.multimodal_utils.get_multimodal_keys_from_processor( + processor +) -> list[str] +``` + + + + + + +Get keys of the multimodal data that can be used as model inputs. + +This will be used in the data_processor function to determine which keys to use as model inputs. + + + + + + + + +```python +nemo_rl.data.multimodal_utils.resolve_to_image( + image_path_or_image: str | PIL.Image.Image +) -> PIL.Image.Image +``` + + + + + + +Resolve the image path to a PIL.Image object. + +image_path can be either: +- path to local file +- url to image +- base64 encoded image + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing.mdx new file mode 100644 index 0000000..8161181 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing.mdx @@ -0,0 +1,30 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/packing +title: nemo_rl.data.packing +--- + +## Submodules + +- **[`nemo_rl.data.packing.algorithms`](/nemo-rl/nemo_rl/data/packing/algorithms)** +- **[`nemo_rl.data.packing.metrics`](/nemo-rl/nemo_rl/data/packing/metrics)** + +## Package Contents + +### Data + +[`__all__`](#nemo_rl-data-packing-__all__) + +### API + + + + + +```python +nemo_rl.data.packing.__all__ = ['PackingAlgorithm', 'SequencePacker', 'ConcatenativePacker', 'FirstFitDecreasin... +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/algorithms.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/algorithms.mdx new file mode 100644 index 0000000..7337f16 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/algorithms.mdx @@ -0,0 +1,791 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/packing/algorithms +title: nemo_rl.data.packing.algorithms +--- + +Sequence packing algorithms for efficient batching of variable-length sequences. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ConcatenativePacker`](#nemo_rl-data-packing-algorithms-ConcatenativePacker) | Concatenative packing algorithm. | +| [`FirstFitDecreasingPacker`](#nemo_rl-data-packing-algorithms-FirstFitDecreasingPacker) | First-Fit Decreasing (FFD) algorithm for sequence packing. | +| [`FirstFitPacker`](#nemo_rl-data-packing-algorithms-FirstFitPacker) | Base class for First-Fit algorithms. | +| [`FirstFitShufflePacker`](#nemo_rl-data-packing-algorithms-FirstFitShufflePacker) | First-Fit Shuffle algorithm for sequence packing. | +| [`ModifiedFirstFitDecreasingPacker`](#nemo_rl-data-packing-algorithms-ModifiedFirstFitDecreasingPacker) | Modified First-Fit Decreasing (MFFD) algorithm for sequence packing. | +| [`PackingAlgorithm`](#nemo_rl-data-packing-algorithms-PackingAlgorithm) | Enum for supported sequence packing algorithms. | +| [`SequencePacker`](#nemo_rl-data-packing-algorithms-SequencePacker) | Abstract base class for sequence packing algorithms. | + +### Functions + +| Name | Description | +|------|-------------| +| [`get_packer`](#nemo_rl-data-packing-algorithms-get_packer) | Factory function to get a sequence packer based on the algorithm. | + +### API + + + + + +```python +class nemo_rl.data.packing.algorithms.ConcatenativePacker() +``` + + + + + + +**Bases:** [SequencePacker](#nemo_rl-data-packing-algorithms-SequencePacker) + +Concatenative packing algorithm. + +This algorithm simply concatenates sequences in order until reaching the bin capacity, +then starts a new bin. It doesn't try to optimize the packing in any way. + +Time complexity: O(n) where n is the number of sequences. + +Example: + + +```python +>>> examples = { +... "sequence_lengths": [4, 1, 3, 2, 1, 3, 4, 5] +... } +>>> # If packed with seq_length=5: +... {"bins": [ [0, 1], [2, 3], [4, 5], [6], [7] ]} +>>> # If packed with seq_length=8: +... {"bins": [ [0, 1, 2], [3, 4, 5], [6], [7] ]} +``` + + + + + + + + + + +```python +nemo_rl.data.packing.algorithms.ConcatenativePacker._pack_implementation( + sequence_lengths: typing.List[int] +) -> typing.List[typing.List[int]] +``` + + + + + + +Pack sequences using the Concatenative algorithm. + +**Parameters:** + + +A list of sequence lengths to pack. + + +**Returns:** `List[List[int]]` + +A list of bins, where each bin is a list of indices into the original + + + + + + + + + +```python +class nemo_rl.data.packing.algorithms.FirstFitDecreasingPacker() +``` + + + + + + +**Bases:** [FirstFitPacker](#nemo_rl-data-packing-algorithms-FirstFitPacker) + +First-Fit Decreasing (FFD) algorithm for sequence packing. + +This algorithm sorts sequences by length in descending order and then +places each sequence into the first bin where it fits. + +Time complexity: O(n log n) for sorting + O(n * m) for packing, +where n is the number of sequences and m is the number of bins. + + + + + + +```python +nemo_rl.data.packing.algorithms.FirstFitDecreasingPacker._prepare_sequences( + sequence_lengths: typing.List[int] +) -> typing.List[typing.Tuple[int, int]] +``` + + + + + + +Prepare sequences for packing by sorting them in descending order. + +**Parameters:** + + +A list of sequence lengths to pack. + + +**Returns:** `List[Tuple[int, int]]` + +A list of (length, index) pairs sorted by length in descending order. + + + + + + + + + +```python +class nemo_rl.data.packing.algorithms.FirstFitPacker() +``` + + + + + + +**Bases:** [SequencePacker](#nemo_rl-data-packing-algorithms-SequencePacker) + +Base class for First-Fit algorithms. + +First-Fit algorithms place each sequence into the first bin where it fits. +If no bin can fit the sequence, a new bin is created. + +This is an abstract base class that provides the common implementation for +First-Fit variants. Subclasses must implement the _prepare_sequences method +to determine the order in which sequences are processed. + + + + + + +```python +nemo_rl.data.packing.algorithms.FirstFitPacker._pack_implementation( + sequence_lengths: typing.List[int] +) -> typing.List[typing.List[int]] +``` + + + + + + +Pack sequences using the First-Fit algorithm. + +**Parameters:** + + +A list of sequence lengths to pack. + + +**Returns:** `List[List[int]]` + +A list of bins, where each bin is a list of indices into the original + + + + + + + +```python +nemo_rl.data.packing.algorithms.FirstFitPacker._prepare_sequences( + sequence_lengths: typing.List[int] +) -> typing.List[typing.Tuple[int, int]] +``` + + + + + + +Prepare sequences for packing. + +This method determines the order in which sequences are processed. +Subclasses must override this method. + +**Parameters:** + + +A list of sequence lengths to pack. + + +**Returns:** `List[Tuple[int, int]]` + +A list of (length, index) pairs. + + + + + + + + + +```python +class nemo_rl.data.packing.algorithms.FirstFitShufflePacker() +``` + + + + + + +**Bases:** [FirstFitPacker](#nemo_rl-data-packing-algorithms-FirstFitPacker) + +First-Fit Shuffle algorithm for sequence packing. + +This algorithm randomly shuffles the sequences and then places each +sequence into the first bin where it fits. + +Time complexity: O(n * m) for packing, where n is the number of sequences +and m is the number of bins. + + + + + + +```python +nemo_rl.data.packing.algorithms.FirstFitShufflePacker._prepare_sequences( + sequence_lengths: typing.List[int] +) -> typing.List[typing.Tuple[int, int]] +``` + + + + + + +Prepare sequences for packing by randomly shuffling them. + +**Parameters:** + + +A list of sequence lengths to pack. + + +**Returns:** `List[Tuple[int, int]]` + +A list of (length, index) pairs in random order. + + + + + + + + + +```python +class nemo_rl.data.packing.algorithms.ModifiedFirstFitDecreasingPacker() +``` + + + + + + +**Bases:** [SequencePacker](#nemo_rl-data-packing-algorithms-SequencePacker) + +Modified First-Fit Decreasing (MFFD) algorithm for sequence packing. + +This algorithm implements the Johnson & Garey (1985) Modified First-Fit-Decreasing +heuristic. It classifies items into four categories (large, medium, small, tiny) +and uses a sophisticated 5-phase packing strategy to achieve better bin utilization +than standard First-Fit Decreasing. + +The algorithm phases: +1. Classify items by size relative to bin capacity +2. Create one bin per large item +3. Add medium items to large bins (forward pass) +4. Add pairs of small items to bins with medium items (backward pass) +5. Greedily fit remaining items +6. Apply FFD to any leftovers + +Time complexity: O(n log n) for sorting + O(n * m) for packing, +where n is the number of sequences and m is the number of bins. + + + + + + +```python +nemo_rl.data.packing.algorithms.ModifiedFirstFitDecreasingPacker._classify_items( + items: typing.List[typing.Tuple[int, int]] +) -> typing.Tuple[typing.List[typing.Tuple[int, int]], typing.List[typing.Tuple[int, int]], typing.List[typing.Tuple[int, int]], typing.List[typing.Tuple[int, int]]] +``` + + + + + + +Split items into large / medium / small / tiny classes. + +Follows the classification used by Johnson & Garey: + large : (C/2, C] + medium : (C/3, C/2] + small : (C/6, C/3] + tiny : (0 , C/6] + +**Parameters:** + + +List of (index, size) tuples + + +**Returns:** `Tuple[List[Tuple[int, int]], List[Tuple[int, int]], List[Tuple[int, int]], List[Tuple[int, int]]]` + +Tuple of four lists (large, medium, small, tiny) without additional sorting. + + + + + + + +```python +nemo_rl.data.packing.algorithms.ModifiedFirstFitDecreasingPacker._pack_implementation( + sequence_lengths: typing.List[int] +) -> typing.List[typing.List[int]] +``` + + + + + + +Pack sequences using the Modified First-Fit Decreasing algorithm. + +**Parameters:** + + +A list of sequence lengths to pack. + + +**Returns:** `List[List[int]]` + +A list of bins, where each bin is a list of indices into the original + + + + + + + + + +```python +class nemo_rl.data.packing.algorithms.PackingAlgorithm +``` + + + + + + +**Bases:** `enum.Enum` + +Enum for supported sequence packing algorithms. + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.packing.algorithms.SequencePacker( + bin_capacity: int, + collect_metrics: bool = False, + min_bin_count: typing.Optional[int] = None, + bin_count_multiple: typing.Optional[int] = None +) +``` + + + + + + +Abstract + +Abstract base class for sequence packing algorithms. + +Sequence packing is the process of efficiently arranging sequences of different +lengths into fixed-capacity bins (batches) to maximize computational efficiency. + + + + + + +```python +nemo_rl.data.packing.algorithms.SequencePacker._adjust_bin_count( + bins: typing.List[typing.List[int]] +) -> typing.List[typing.List[int]] +``` + + + + + + +Adjust the number of bins to meet minimum and multiple constraints. + +This method preserves the existing bin packing as much as possible and only +moves sequences one at a time to create additional bins when needed. + +**Parameters:** + + +The original bins from the packing algorithm. + + +**Returns:** `List[List[int]]` + +Adjusted bins with minimal changes to meet constraints. + +**Raises:** + +- `ValueError`: If there aren't enough sequences to fill the required number of bins. + + + + + + + +```python +nemo_rl.data.packing.algorithms.SequencePacker._create_indexed_lengths( + sequence_lengths: typing.List[int], + reverse: bool = False +) -> typing.List[typing.Tuple[int, int]] +``` + + + + + + +Create a list of (length, index) pairs from sequence lengths. + +**Parameters:** + + +A list of sequence lengths. + + + +Whether to sort in descending order (True) or ascending order (False). + + +**Returns:** `List[Tuple[int, int]]` + +A list of (length, index) pairs, optionally sorted. + + + + + + + +```python +nemo_rl.data.packing.algorithms.SequencePacker._estimate_bins_needed( + sequence_lengths: typing.List[int] +) -> int +``` + + + + + + +Estimate the number of bins needed based on total length. + +**Parameters:** + + +A list of sequence lengths. + + +**Returns:** `int` + +Estimated number of bins needed. + + + + + + + +```python +nemo_rl.data.packing.algorithms.SequencePacker._pack_implementation( + sequence_lengths: typing.List[int] +) -> typing.List[typing.List[int]] +``` + + + + + + +abstract + +Implementation of the packing algorithm. + +**Parameters:** + + +A list of sequence lengths to pack. + + +**Returns:** `List[List[int]]` + +A list of bins, where each bin is a list of indices into the original + + + + + + + +```python +nemo_rl.data.packing.algorithms.SequencePacker._validate_sequence_lengths( + sequence_lengths: typing.List[int] +) -> None +``` + + + + + + +Validate that all sequence lengths are within bin capacity. + +**Parameters:** + + +A list of sequence lengths to validate. + + +**Raises:** + +- `ValueError`: If any sequence length exceeds bin capacity. + + + + + + + +```python +nemo_rl.data.packing.algorithms.SequencePacker.compute_metrics( + sequence_lengths: typing.List[int], + bins: typing.List[typing.List[int]] +) -> typing.Dict[str, float] +``` + + + + + + +Calculate metrics for a packing solution without updating the metrics tracker. + +**Parameters:** + + +List of sequence lengths + + + +List of bins, where each bin is a list of indices + + +**Returns:** `Dict[str, float]` + +Dictionary of packing metrics + + + + + + + +```python +nemo_rl.data.packing.algorithms.SequencePacker.get_aggregated_metrics() -> typing.Dict[str, float] +``` + + + + + + +Get aggregated metrics across all packing operations. + +**Returns:** `Dict[str, float]` + +Dictionary of aggregated metrics, or empty dict if not collecting + + + + + + + +```python +nemo_rl.data.packing.algorithms.SequencePacker.pack( + sequence_lengths: typing.List[int] +) -> typing.List[typing.List[int]] +``` + + + + + + +Pack sequences into bins and update metrics if enabled. + +**Parameters:** + + +A list of sequence lengths to pack. + + +**Returns:** `List[List[int]]` + +A list of bins, where each bin is a list of indices into the original + + + + + + + +```python +nemo_rl.data.packing.algorithms.SequencePacker.print_metrics() -> None +``` + + + + + + +Print the current metrics in a formatted way. + + + + + + + +```python +nemo_rl.data.packing.algorithms.SequencePacker.reset_metrics() -> None +``` + + + + + + +Reset collected metrics. + + + + + + + + + +```python +nemo_rl.data.packing.algorithms.get_packer( + algorithm: typing.Union[nemo_rl.data.packing.algorithms.PackingAlgorithm, str], + bin_capacity: int, + collect_metrics: bool = False, + min_bin_count: typing.Optional[int] = None, + bin_count_multiple: typing.Optional[int] = None +) -> nemo_rl.data.packing.algorithms.SequencePacker +``` + + + + + + +Factory function to get a sequence packer based on the algorithm. + +**Parameters:** + + +The packing algorithm to use. Can be either a PackingAlgorithm enum value + or a string (case-insensitive) matching one of the enum names. + + + +The maximum capacity of each bin. + + + +Whether to collect metrics across multiple packing operations. + + + +Minimum number of bins to create, even if fewer would suffice. + If None, no minimum is enforced. + + + +The total number of bins must be a multiple of this value. + If None, no multiple constraint is enforced. + + +**Returns:** `SequencePacker` + +A SequencePacker instance for the specified algorithm. + +**Raises:** + +- `ValueError`: If the algorithm is not recognized. + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/metrics.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/metrics.mdx new file mode 100644 index 0000000..3f1b4d0 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/metrics.mdx @@ -0,0 +1,177 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/packing/metrics +title: nemo_rl.data.packing.metrics +--- + +Metrics for evaluating sequence packing algorithms. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`PackingMetrics`](#nemo_rl-data-packing-metrics-PackingMetrics) | Class for tracking and computing metrics for sequence packing algorithms. | + +### API + + + + + +```python +class nemo_rl.data.packing.metrics.PackingMetrics() +``` + + + + + + +Class for tracking and computing metrics for sequence packing algorithms. + +This class provides methods to calculate various metrics that evaluate the +efficiency and effectiveness of sequence packing algorithms, such as bin +utilization, waste, and imbalance. + + + + + + +```python +nemo_rl.data.packing.metrics.PackingMetrics.calculate_stats_only( + sequence_lengths: typing.List[int], + bins: typing.List[typing.List[int]], + bin_capacity: int +) -> typing.Dict[str, float] +``` + + + + + + +Calculate metrics for a packing solution without updating the tracker. + +**Parameters:** + + +List of sequence lengths + + + +List of bins, where each bin is a list of indices + + + +Maximum capacity of each bin + + +**Returns:** `Dict[str, float]` + +Dictionary of metrics for this packing solution + + + + + + + +```python +nemo_rl.data.packing.metrics.PackingMetrics.get_aggregated_stats() -> typing.Dict[str, float] +``` + + + + + + +Get aggregated metrics across all packing operations. + +**Returns:** `Dict[str, float]` + +Dictionary of aggregated metrics + + + + + + + +```python +nemo_rl.data.packing.metrics.PackingMetrics.print_aggregated_stats() -> None +``` + + + + + + +Print the aggregated metrics in a formatted way. + + + + + + + +```python +nemo_rl.data.packing.metrics.PackingMetrics.reset() -> None +``` + + + + + + +Reset all metrics. + + + + + + + +```python +nemo_rl.data.packing.metrics.PackingMetrics.update( + sequence_lengths: typing.List[int], + bins: typing.List[typing.List[int]], + bin_capacity: int, + packing_time: typing.Optional[float] = None +) -> typing.Dict[str, float] +``` + + + + + + +Update metrics with a new packing solution. + +**Parameters:** + + +List of sequence lengths + + + +List of bins, where each bin is a list of indices + + + +Maximum capacity of each bin + + + +Optional time taken to compute the packing solution + + +**Returns:** `Dict[str, float]` + +Dictionary of metrics for this packing solution + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/processors.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/processors.mdx new file mode 100644 index 0000000..7660a0d --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/processors.mdx @@ -0,0 +1,353 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/processors +title: nemo_rl.data.processors +--- + +Contains data processors for evaluation. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_construct_multichoice_prompt`](#nemo_rl-data-processors-_construct_multichoice_prompt) | Construct prompt from question and options. | +| [`helpsteer3_data_processor`](#nemo_rl-data-processors-helpsteer3_data_processor) | Process a HelpSteer3 preference datum into a DatumSpec for GRPO training. | +| [`math_data_processor`](#nemo_rl-data-processors-math_data_processor) | Process a datum dictionary (directly loaded from dataset) into a DatumSpec for the Math Environment. | +| [`math_hf_data_processor`](#nemo_rl-data-processors-math_hf_data_processor) | Process a datum dictionary (directly loaded from data/hf_datasets/openmathinstruct2.py) into a DatumSpec for the Reward Model Environment. | +| [`multichoice_qa_processor`](#nemo_rl-data-processors-multichoice_qa_processor) | Process a datum dictionary (directly loaded from dataset) into a DatumSpec for multiple-choice problems. | +| [`nemo_gym_data_processor`](#nemo_rl-data-processors-nemo_gym_data_processor) | Process a datum dictionary (directly loaded from dataset) into a DatumSpec for Nemo Gym. | +| [`preference_preprocessor`](#nemo_rl-data-processors-preference_preprocessor) | Process a datum dictionary for RM/DPO training. | +| [`register_processor`](#nemo_rl-data-processors-register_processor) | - | +| [`sft_processor`](#nemo_rl-data-processors-sft_processor) | Process a datum dictionary for SFT training. | +| [`vlm_hf_data_processor`](#nemo_rl-data-processors-vlm_hf_data_processor) | Process a datum dictionary (directly loaded from response_datasets/<dataset_name>.py) into a DatumSpec for the VLM Environment. | + +### Data + +[`PROCESSOR_REGISTRY`](#nemo_rl-data-processors-PROCESSOR_REGISTRY) + +[`TokenizerType`](#nemo_rl-data-processors-TokenizerType) + +### API + + + + + +```python +nemo_rl.data.processors._construct_multichoice_prompt( + prompt: str, + question: str, + options: dict[str, str] +) -> str +``` + + + + + + +Construct prompt from question and options. + + + + + + + + +```python +nemo_rl.data.processors.helpsteer3_data_processor( + datum_dict: dict[str, typing.Any], + task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + tokenizer: nemo_rl.data.processors.TokenizerType, + max_seq_length: int, + idx: int +) -> nemo_rl.data.interfaces.DatumSpec +``` + + + + + + +Process a HelpSteer3 preference datum into a DatumSpec for GRPO training. + +This function converts HelpSteer3 preference data to work with GRPO by: +1. Using the context as the prompt +2. Using the preferred completion as the target response +3. Creating a reward signal based on preference scores + + + + + + + + +```python +nemo_rl.data.processors.math_data_processor( + datum_dict: dict[str, typing.Any], + task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + tokenizer: nemo_rl.data.processors.TokenizerType, + max_seq_length: int, + idx: int +) -> nemo_rl.data.interfaces.DatumSpec +``` + + + + + + +Process a datum dictionary (directly loaded from dataset) into a DatumSpec for the Math Environment. + + + + + + + + +```python +nemo_rl.data.processors.math_hf_data_processor( + datum_dict: dict[str, typing.Any], + task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + tokenizer: nemo_rl.data.processors.TokenizerType, + max_seq_length: int, + idx: int +) -> nemo_rl.data.interfaces.DatumSpec +``` + + + + + + +Process a datum dictionary (directly loaded from data/hf_datasets/openmathinstruct2.py) into a DatumSpec for the Reward Model Environment. + + + + + + + + +```python +nemo_rl.data.processors.multichoice_qa_processor( + datum_dict: dict[str, typing.Any], + task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + tokenizer: nemo_rl.data.processors.TokenizerType, + max_seq_length: int, + idx: int +) -> nemo_rl.data.interfaces.DatumSpec +``` + + + + + + +Process a datum dictionary (directly loaded from dataset) into a DatumSpec for multiple-choice problems. + + + + + + + + +```python +nemo_rl.data.processors.nemo_gym_data_processor( + datum_dict: dict[str, typing.Any], + task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + tokenizer: nemo_rl.data.processors.TokenizerType, + max_seq_length: int | None, + idx: int +) -> nemo_rl.data.interfaces.DatumSpec +``` + + + + + + +Process a datum dictionary (directly loaded from dataset) into a DatumSpec for Nemo Gym. + + + + + + + + +```python +nemo_rl.data.processors.preference_preprocessor( + datum_dict: dict[str, typing.Any], + task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + tokenizer, + max_seq_length: int, + idx: int +) -> nemo_rl.data.interfaces.PreferenceDatumSpec +``` + + + + + + +Process a datum dictionary for RM/DPO training. + +**Examples:** + + + +```python +>>> from transformers import AutoTokenizer +>>> from nemo_rl.data.interfaces import TaskDataSpec +>>> from nemo_rl.data.processors import preference_preprocessor +>>> +>>> # Initialize tokenizer and task spec +>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct") +>>> ## set a passthrough chat template for simplicity +>>> tokenizer.chat_template = "{% for message in messages %}{{ message['content'] }}{% endfor %}" +>>> task_spec = TaskDataSpec(task_name="test_preference") +>>> +>>> datum = { +... "context": [{"role": "user", "content": "What is 2+2?"}], +... "completions": [ +... {"rank": 0, "completion": [{"role": "assistant", "content": "4"}]}, +... {"rank": 1, "completion": [{"role": "assistant", "content": "5"}]} +... ] +... } +>>> +>>> processed = preference_preprocessor(datum, task_spec, tokenizer, max_seq_length=128, idx=0) # doctest: +ELLIPSIS + +... +>>> len(processed["message_log_chosen"]) +2 +>>> processed["message_log_chosen"][0]["content"] +'<|begin_of_text|>What is 2+2?' +>>> processed["message_log_chosen"][-1]["content"] +'4<|eot_id|>' +>>> processed["message_log_rejected"][-1]["content"] +'5<|eot_id|>' +>>> +>>> # context can also contain multiple turns +>>> datum = { +... "context": [{"role": "user", "content": "I have a question."}, {"role": "assistant", "content": "Sure!"}, {"role": "user", "content": "What is 2+2?"}], +... "completions": [ +... {"rank": 0, "completion": [{"role": "assistant", "content": "4"}]}, +... {"rank": 1, "completion": [{"role": "assistant", "content": "5"}]} +... ] +... } +>>> processed = preference_preprocessor(datum, task_spec, tokenizer, max_seq_length=128, idx=0) +>>> len(processed["message_log_chosen"]) +4 +>>> processed["message_log_chosen"][1]["content"] +'Sure!' +>>> processed["message_log_chosen"][-1]["content"] +'4<|eot_id|>' +>>> processed["message_log_rejected"][-1]["content"] +'5<|eot_id|>' +``` + + + + + + + + + + +```python +nemo_rl.data.processors.register_processor( + processor_name: str, + processor_function: nemo_rl.data.interfaces.TaskDataProcessFnCallable +) -> None +``` + + + + + + + + + + + + + +```python +nemo_rl.data.processors.sft_processor( + datum_dict: dict[str, typing.Any], + task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + tokenizer, + max_seq_length: int, + idx: int, + add_bos: bool = True, + add_eos: bool = True, + add_generation_prompt: bool = False +) -> nemo_rl.data.interfaces.DatumSpec +``` + + + + + + +Process a datum dictionary for SFT training. + + + + + + + + +```python +nemo_rl.data.processors.vlm_hf_data_processor( + datum_dict: dict[str, typing.Any], + task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + processor: transformers.AutoProcessor, + max_seq_length: int, + idx: int +) -> nemo_rl.data.interfaces.DatumSpec +``` + + + + + + +Process a datum dictionary (directly loaded from response_datasets/<dataset_name>.py) into a DatumSpec for the VLM Environment. + + + + + + + + +```python +nemo_rl.data.processors.PROCESSOR_REGISTRY: Dict[str, TaskDataProcessFnCallable] = cast(Dict[str, TaskDataProcessFnCallable], {'default': math_hf_data_processor, '... +``` + + + + + + + + + +```python +nemo_rl.data.processors.TokenizerType = PreTrainedTokenizerBase +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/utils.mdx new file mode 100644 index 0000000..386880c --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/utils.mdx @@ -0,0 +1,104 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/utils +title: nemo_rl.data.utils +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`setup_preference_data`](#nemo_rl-data-utils-setup_preference_data) | Setup preference data. | +| [`setup_response_data`](#nemo_rl-data-utils-setup_response_data) | Setup data with environments. | + +### API + + + + + +```python +nemo_rl.data.utils.setup_preference_data( + tokenizer: transformers.AutoTokenizer, + data_config: nemo_rl.data.DataConfig +) +``` + + + + + + +Setup preference data. + +This function is used to setup the preference data for the training and validation datasets. + +**Parameters:** + + +Tokenizer. + + + +Data config for preference dataset. + + +**Returns:** + +A tuple of (train dataset, validation dataset). + + + + + + + + +```python +nemo_rl.data.utils.setup_response_data( + tokenizer: transformers.AutoProcessor | transformers.AutoTokenizer, + data_config: nemo_rl.data.DataConfig, + env_configs: typing.Optional[dict[str, typing.Any]] = None, + is_vlm: bool = False +) -> typing.Union[tuple[nemo_rl.data.datasets.AllTaskProcessedDataset, typing.Optional[nemo_rl.data.datasets.AllTaskProcessedDataset]], tuple[nemo_rl.data.datasets.AllTaskProcessedDataset, typing.Optional[nemo_rl.data.datasets.AllTaskProcessedDataset], dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]]] +``` + + + + + + +Setup data with environments. + +This function is used to setup the data and environments for the training and validation datasets. + +**Parameters:** + + +Tokenizer or processor. + + + +Data config. + + + +Environment configs. +If None, no environments will be created. This is used for: +- Algorithms like SFT which do not need environments. +- Environments like NeMo-Gym which need to handle the environment creation outside of this function. + + + +Whether to use VLM training or not. + + +**Returns:** `Union[tuple[AllTaskProcessedDataset, Optional[AllTaskProcessedDataset]], tuple[AllTaskProcessedDataset, Optional[AllTaskProcessedDataset], dict[str, EnvironmentInterface], dict[str, EnvironmentInterface]]]` + +If env_configs is not None: +A tuple of (train dataset, validation dataset, task to environment, task to validation environment). + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed.mdx new file mode 100644 index 0000000..d615e36 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed.mdx @@ -0,0 +1,17 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/distributed +title: nemo_rl.distributed +--- + +## Submodules + +- **[`nemo_rl.distributed.batched_data_dict`](/nemo-rl/nemo_rl/distributed/batched_data_dict)** +- **[`nemo_rl.distributed.collectives`](/nemo-rl/nemo_rl/distributed/collectives)** +- **[`nemo_rl.distributed.model_utils`](/nemo-rl/nemo_rl/distributed/model_utils)** +- **[`nemo_rl.distributed.named_sharding`](/nemo-rl/nemo_rl/distributed/named_sharding)** +- **[`nemo_rl.distributed.ray_actor_environment_registry`](/nemo-rl/nemo_rl/distributed/ray_actor_environment_registry)** +- **[`nemo_rl.distributed.stateless_process_group`](/nemo-rl/nemo_rl/distributed/stateless_process_group)** +- **[`nemo_rl.distributed.virtual_cluster`](/nemo-rl/nemo_rl/distributed/virtual_cluster)** +- **[`nemo_rl.distributed.worker_group_utils`](/nemo-rl/nemo_rl/distributed/worker_group_utils)** +- **[`nemo_rl.distributed.worker_groups`](/nemo-rl/nemo_rl/distributed/worker_groups)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/batched_data_dict.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/batched_data_dict.mdx new file mode 100644 index 0000000..c134df2 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/batched_data_dict.mdx @@ -0,0 +1,671 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/distributed/batched_data_dict +title: nemo_rl.distributed.batched_data_dict +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BatchedDataDict`](#nemo_rl-distributed-batched_data_dict-BatchedDataDict) | - | +| [`DynamicBatchingArgs`](#nemo_rl-distributed-batched_data_dict-DynamicBatchingArgs) | Configuration settings for dynamic batching. | +| [`SequencePackingArgs`](#nemo_rl-distributed-batched_data_dict-SequencePackingArgs) | Configuration settings for sequence packing. | +| [`SlicedDataDict`](#nemo_rl-distributed-batched_data_dict-SlicedDataDict) | A specialized subclass of BatchedDataDict that represents a slice or shard of a larger batch. | + +### Data + +[`DictT`](#nemo_rl-distributed-batched_data_dict-DictT) + +### API + + + + + +```python +class nemo_rl.distributed.batched_data_dict.BatchedDataDict( + args = (), + kwargs = {} +) +``` + + + + + + +**Bases:** `UserDict`, `Generic[DictT]` + + + + + +Get the batch size of the batch. + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.all_gather( + group: torch.distributed.ProcessGroup +) -> typing_extensions.Self +``` + + + + + + +Gathers batches with possibly jagged leading dimensions across the DP ranks. + +If using reshard, it will treat PP as DP ranks. +Works with data that is either tensors or string lists. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.chunk( + rank: int, + chunks: int +) -> nemo_rl.distributed.batched_data_dict.SlicedDataDict +``` + + + + + + +Chunks a global batch into 'chunks' splits and returns the 'rank'th split batch=[A A A B B B D D E], rank=2, chunks=3 -> [D D E]. + +Requires all leading dimensions of tensors and lengths of lists to be the same over the batch +and the chunks must divide batch size. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.from_batches( + batches: typing.Sequence[typing.Mapping[typing.Any, typing.Any]], + pad_value_dict: typing.Optional[dict[str, int | float]] = None +) -> typing_extensions.Self +``` + + + + + + +classmethod + +Given a list of batches, stack the tensors/lists within and put them in a single dictionary. + +Pad sequences to the max length in the batch using either 0(default) or a non-default value for a given key provided in pad_value_dict. + +**Parameters:** + + +A list of dictionaries, each containing a batch of data. + + + +An optional dict mapping keys to non-default(0) padding values. + + +**Returns:** `Self` + +A new BatchedDataDict containing the stacked data. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.get_batch( + batch_idx, + batch_size = None +) -> nemo_rl.distributed.batched_data_dict.SlicedDataDict +``` + + + + + + +Slices a subbatch from the batch. + +**Parameters:** + + +the batch index to slice + + + +the size of the batch to be sliced + + +**Returns:** `SlicedDataDict` + +A new BatchedDataDict containing the sliced data + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.get_dict() -> dict[typing.Any, typing.Any] +``` + + + + + + +Get the underlying data dictionary. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.get_microbatch_iterator_dynamic_shapes_len() -> int +``` + + + + + + +Get the length of the microbatch iterator for dynamic shapes. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.get_microbatch_iterator_for_packable_sequences_len() -> tuple[int, int] +``` + + + + + + +Get the length of the microbatch iterator for sequence packing and the max packed seqlen. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.get_multimodal_dict( + as_tensors: bool = False, + device: typing.Optional[torch.device] = None +) -> dict[str, typing.Any] +``` + + + + + + +Return a regular dict of tensors or packed multimodal data items. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.make_microbatch_iterator( + microbatch_size: int +) -> typing.Iterator[nemo_rl.distributed.batched_data_dict.SlicedDataDict] +``` + + + + + + +Make an iterator over the batch that yields microbatches of size microbatch_size. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.make_microbatch_iterator_for_packable_sequences() -> typing.Iterator[nemo_rl.distributed.batched_data_dict.SlicedDataDict] +``` + + + + + + +Make an iterator over the batch that yields microbatches that can be packed into a given max_tokens_per_microbatch. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.make_microbatch_iterator_with_dynamic_shapes( + sequence_dim: int = 1 +) -> typing.Iterator[nemo_rl.distributed.batched_data_dict.SlicedDataDict] +``` + + + + + + +Makes an iterator that yields microbatchs of dynamic batch and sequence sizes. + +**Parameters:** + + +the index of the sequence dim for all tensors in the data dict + + +**Returns:** `Iterator[SlicedDataDict]` + +Iterator["SlicedDataDict"]: An iterator that yield dynamic microbatches + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.reorder_data( + reorded_indices: list[int] +) +``` + + + + + + +Reorders the data along the batch dimension by the given indices. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.repeat_interleave( + num_repeats: int +) -> typing_extensions.Self +``` + + + + + + +Repeats the batch num_repeats times. + +For each element in the batch, repeat each value num_repeats times. +i.e: +{"key": torch.tensor([1, 2, 3]), "other_key": [1, 2, 3]} -> {"key": torch.tensor([1, 1, 2, 2, 3, 3]), "other_key": [1, 1, 2, 2, 3, 3]} + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.select_indices( + indices: typing.Union[list[int], torch.Tensor] +) -> typing_extensions.Self +``` + + + + + + +Selects specific rows from the batch based on indices. + +**Parameters:** + + +A list or tensor of integer indices to select. + + +**Returns:** `Self` + +A new BatchedDataDict containing only the selected rows. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.shard_by_batch_size( + shards: int, + batch_size: typing.Optional[int] = None, + allow_uneven_shards: bool = False, + dynamic_batching_args: typing.Optional[nemo_rl.distributed.batched_data_dict.DynamicBatchingArgs] = None, + sequence_packing_args: typing.Optional[nemo_rl.distributed.batched_data_dict.SequencePackingArgs] = None +) -> list[nemo_rl.distributed.batched_data_dict.SlicedDataDict] | tuple[list[nemo_rl.distributed.batched_data_dict.SlicedDataDict], list[int]] +``` + + + + + + +Shards a batch by first dividing it into chunks of size batch_size, then further dividing each chunk into shards equal parts. Finally aggregates the sub-shards by their position. + +If batch_size is None, there will be no chunking beforehand (will default to the total batch size). + +For example, with data [A A B B C C D D], batch_size=2, shards=2: +- Element 0: [A B C D] (first elements from each chunk) +- Element 1: [A B C D] (second elements from each chunk) + +Examples: + + +```python +>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict +>>> # Create a batch of two message logs with different lengths +>>> batch = BatchedDataDict({ +... 'problem_id': [0, 0, 1, 1, 2, 2, 3, 3], +... 'arbitrary_data': [1, 2, 3, 4, 5, 6, 7, 8] +... }) +>>> shards = batch.shard_by_batch_size(shards=2) +>>> shards +[{'problem_id': [0, 0, 1, 1], 'arbitrary_data': [1, 2, 3, 4]}, {'problem_id': [2, 2, 3, 3], 'arbitrary_data': [5, 6, 7, 8]}] +>>> # Now say that I'm training with a GBS of 4 and I want to take gradients steps on problems 0 and 1 before 2 and 3 (problems are repeated because GRPO) +>>> # In the current case, problems 0 and 2 will be trained on first since they're the first elements in each DP rank's batch. +>>> # So, we'll use the batch_size argument to split the batch into chunks of size 4 first. +>>> shards = batch.shard_by_batch_size(shards=2, batch_size=4) +>>> shards +[{'problem_id': [0, 0, 2, 2], 'arbitrary_data': [1, 2, 5, 6]}, {'problem_id': [1, 1, 3, 3], 'arbitrary_data': [3, 4, 7, 8]}] +>>> # Now, the ranks have 0 and 1 first so when they split their batches into microbatches (of size 2 since GBS=4 and DP=2), they'll train on 0 and 1 first. +>>> # Another way to use this function is with the 'allow_uneven_shards' flag, which allows the last shard to be smaller than the others when necessary. +>>> # This is necessary in multi-turn rollouts when some sequences terminate early, leaving unclean batch sizes. +>>> batch = BatchedDataDict({ +... 'problem_id': [0, 1, 2, 3, 4], +... 'arbitrary_data': [10, 11, 12, 13, 14] +... }) +>>> shards = batch.shard_by_batch_size(shards=2, allow_uneven_shards=True) +>>> shards +[{'problem_id': [0, 1, 2], 'arbitrary_data': [10, 11, 12]}, {'problem_id': [3, 4], 'arbitrary_data': [13, 14]}] +>>> # This is incompatible with the batch_size argument +``` + + + +**Parameters:** + + +The number of shards to divide each batch_size chunk into. + + + +The size of each initial chunk. + + + +Whether to allow shards to be unevenly sized. + If True, the last shard may be smaller than the others. + + + +If passed, preprocess batch for dynamic batching. This + dict requires four keys: + 1. max_tokens_per_microbatch (int): the maximum + number of tokens in a microbatch + 2. sequence_length_round (int): round each all + sequence lengths to this multiple + 3. input_key (str): the key in the batch + which holds input ids. + 4. input_lengths_key (str): the key in the batch + which holds the sequence length per value. + The sequence dim index is assumed to be 1. + Cannot be passed with sequence_packing_args. + + + +If passed, preprocess batch for sequence packing. This + dict requires five keys: + 1. max_tokens_per_microbatch (int): the maximum + number of tokens in a microbatch + 2. input_key (str): the key in the batch + which holds input ids. + 3. input_lengths_key (str): the key in the batch + which holds the sequence length per value. + The sequence dim index is assumed to be 1. + 4. algorithm (str): the algorithm to use for sequence packing. + 5. sequence_length_pad_multiple (int): the multiple to pad each sequence to. + With CP enabled, this should be set to a multiple of 2*CP and SP. + Cannot be passed with dynamic_batching_args. + + +**Returns:** `list[SlicedDataDict] | tuple[list[SlicedDataDict], list[int]]` + +list[BatchedDataDict]: A list of BatchedDataDicts, length equal to shards. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.slice( + start: int, + end: int +) -> nemo_rl.distributed.batched_data_dict.SlicedDataDict +``` + + + + + + +Slices the batch from start to end. + +**Parameters:** + + +Starting index (inclusive) + + + +Ending index (exclusive) + + +**Returns:** `SlicedDataDict` + +A new BatchedDataDict containing the sliced data + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.to( + device: str | torch.device +) -> typing_extensions.Self +``` + + + + + + +Move tensors in batched dict to device. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.truncate_tensors( + dim: int, + truncated_len: int +) +``` + + + + + + +Truncates tensors in this dict of a given dim to a given length. + + + + + + + + + +```python +class nemo_rl.distributed.batched_data_dict.DynamicBatchingArgs +``` + + + + + + +**Bases:** `typing.TypedDict` + +Configuration settings for dynamic batching. + +Pass this to 'shard_by_batch_size()' to preprocess batches for dynamic batching. + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.distributed.batched_data_dict.SequencePackingArgs +``` + + + + + + +**Bases:** `typing.TypedDict` + +Configuration settings for sequence packing. + +Pass this to 'shard_by_batch_size()' to preprocess batches for sequence packing. + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.distributed.batched_data_dict.SlicedDataDict() +``` + + + + + + +**Bases:** [BatchedDataDict](#nemo_rl-distributed-batched_data_dict-BatchedDataDict) + +A specialized subclass of BatchedDataDict that represents a slice or shard of a larger batch. + +This class provides a distinct type to differentiate between full batches and sliced/sharded batches, which can be helpful for +type checking. + + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.DictT = TypeVar('DictT', bound=(Mapping[str, Any])) +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/collectives.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/collectives.mdx new file mode 100644 index 0000000..5d756ce --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/collectives.mdx @@ -0,0 +1,108 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/distributed/collectives +title: nemo_rl.distributed.collectives +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`gather_jagged_object_lists`](#nemo_rl-distributed-collectives-gather_jagged_object_lists) | Gathers jagged lists of picklable objects from all ranks and flattens them into a single list. | +| [`rebalance_nd_tensor`](#nemo_rl-distributed-collectives-rebalance_nd_tensor) | Takes tensors with variable leading sizes (at dim=0) and stacks them into a single tensor. | + +### Data + +[`T`](#nemo_rl-distributed-collectives-T) + +### API + + + + + +```python +nemo_rl.distributed.collectives.gather_jagged_object_lists( + local_objects: list[nemo_rl.distributed.collectives.T], + group: typing.Optional[torch.distributed.ProcessGroup] = None +) -> list[nemo_rl.distributed.collectives.T] +``` + + + + + + +Gathers jagged lists of picklable objects from all ranks and flattens them into a single list. + +This function handles the case where different GPUs have lists of different lengths +and combines them into a single list containing all objects from all ranks. + +For example, with 3 GPUs: + GPU0: [obj0, obj1] + GPU1: [obj2, obj3, obj4] + GPU2: [obj5] + +WARNING: synchronous + +**Parameters:** + + +List of objects to gather from current rank + + + +Optional process group + + +**Returns:** `list[T]` + +Flattened list of all objects from all ranks in order [rank0, rank1, ...] + + + + + + + + +```python +nemo_rl.distributed.collectives.rebalance_nd_tensor( + tensor: torch.Tensor, + group: typing.Optional[torch.distributed.ProcessGroup] = None +) -> torch.Tensor +``` + + + + + + +Takes tensors with variable leading sizes (at dim=0) and stacks them into a single tensor. + +This function handles the case where different GPUs have tensors with different batch sizes +and combines them into a single balanced tensor across all ranks. + +For example, with 3 GPUs: + GPU0: tensor of shape [3, D] + GPU1: tensor of shape [5, D] + GPU2: tensor of shape [2, D] + +NOTE: assumes all other (i.e., non-zero) dimensions are equal. + + + + + + + + +```python +nemo_rl.distributed.collectives.T = TypeVar('T') +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/model_utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/model_utils.mdx new file mode 100644 index 0000000..57539ab --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/model_utils.mdx @@ -0,0 +1,851 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/distributed/model_utils +title: nemo_rl.distributed.model_utils +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AllGatherCPTensor`](#nemo_rl-distributed-model_utils-AllGatherCPTensor) | - | +| [`ChunkedDistributedEntropy`](#nemo_rl-distributed-model_utils-ChunkedDistributedEntropy) | Compute H_all = sum_v p_v log p_v across TP with chunking over sequence. | +| [`ChunkedDistributedGatherLogprob`](#nemo_rl-distributed-model_utils-ChunkedDistributedGatherLogprob) | Compute distributed log-softmax once and gather logprobs at given global indices. | +| [`ChunkedDistributedLogprob`](#nemo_rl-distributed-model_utils-ChunkedDistributedLogprob) | Custom autograd function for computing log probabilities in a distributed setting. | +| [`DistributedLogprob`](#nemo_rl-distributed-model_utils-DistributedLogprob) | Custom autograd function for computing log probabilities in a distributed setting. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_compute_distributed_log_softmax`](#nemo_rl-distributed-model_utils-_compute_distributed_log_softmax) | Compute a stable distributed log softmax across tensor parallel workers. | +| [`_get_tokens_on_this_cp_rank`](#nemo_rl-distributed-model_utils-_get_tokens_on_this_cp_rank) | Get tokens on this context parallelism rank. | +| [`allgather_cp_sharded_tensor`](#nemo_rl-distributed-model_utils-allgather_cp_sharded_tensor) | - | +| [`distributed_vocab_topk`](#nemo_rl-distributed-model_utils-distributed_vocab_topk) | Compute global top-k over TP-sharded vocabulary logits. | +| [`dtensor_from_parallel_logits_to_logprobs`](#nemo_rl-distributed-model_utils-dtensor_from_parallel_logits_to_logprobs) | Get log probabilities from TP+CP sharded vocab logits. | +| [`from_parallel_logits_to_logprobs`](#nemo_rl-distributed-model_utils-from_parallel_logits_to_logprobs) | Get log probabilities from TP+CP sharded vocab logits. | +| [`from_parallel_logits_to_logprobs_packed_sequences`](#nemo_rl-distributed-model_utils-from_parallel_logits_to_logprobs_packed_sequences) | Get log probabilities from TP sharded vocab logits for packed sequences. | +| [`gather_logits_at_global_indices`](#nemo_rl-distributed-model_utils-gather_logits_at_global_indices) | Gather student logits at given global token indices under TP+CP sharding. | +| [`get_logprobs_from_vocab_parallel_logits`](#nemo_rl-distributed-model_utils-get_logprobs_from_vocab_parallel_logits) | Computes log probabilities from vocabulary-parallel logits. | + +### API + + + + + +```python +class nemo_rl.distributed.model_utils.AllGatherCPTensor() +``` + + + + + + +**Bases:** `Function` + + + + + +```python +nemo_rl.distributed.model_utils.AllGatherCPTensor.backward( + ctx, + grad_output +) +``` + + + + + + + + + + + + +```python +nemo_rl.distributed.model_utils.AllGatherCPTensor.forward( + ctx, + tensor, + cp_group: torch.distributed.ProcessGroup, + seq_dim = 1 +) +``` + + + + + + + + + + + + + + +```python +class nemo_rl.distributed.model_utils.ChunkedDistributedEntropy() +``` + + + + + + +**Bases:** `Function` + +Compute H_all = sum_v p_v log p_v across TP with chunking over sequence. + +Forward returns [B, S] tensor of global entropy; backward propagates through logits. + + + + + + +```python +nemo_rl.distributed.model_utils.ChunkedDistributedEntropy.backward( + ctx: typing.Any, + grad_outputs: torch.Tensor = () +) -> tuple[torch.Tensor, None, None, None] +``` + + + + + + +staticmethod + + + + + + + +```python +nemo_rl.distributed.model_utils.ChunkedDistributedEntropy.forward( + ctx: typing.Any, + vocab_parallel_logits: torch.Tensor, + chunk_size: int, + tp_group: torch.distributed.ProcessGroup, + inference_only: bool = False +) -> torch.Tensor +``` + + + + + + +staticmethod + + + + + + + + + +```python +class nemo_rl.distributed.model_utils.ChunkedDistributedGatherLogprob() +``` + + + + + + +**Bases:** `Function` + +Compute distributed log-softmax once and gather logprobs at given global indices. + +Forward computes per-chunk distributed log-softmax across TP, gathers selected +log probabilities at the provided global indices (shape [B, S, K]), and returns +a tensor of shape [B, S, K]. + + + + + + +```python +nemo_rl.distributed.model_utils.ChunkedDistributedGatherLogprob.backward( + ctx: typing.Any, + grad_outputs: torch.Tensor = () +) -> tuple[torch.Tensor, None, None, None, None, None, None] +``` + + + + + + +staticmethod + + + + + + + +```python +nemo_rl.distributed.model_utils.ChunkedDistributedGatherLogprob.forward( + ctx: typing.Any, + vocab_parallel_logits: torch.Tensor, + global_indices: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + chunk_size: int, + tp_group: torch.distributed.ProcessGroup, + inference_only: bool = False +) -> torch.Tensor +``` + + + + + + +staticmethod + + + + + + + + + +```python +class nemo_rl.distributed.model_utils.ChunkedDistributedLogprob() +``` + + + + + + +**Bases:** `Function` + +Custom autograd function for computing log probabilities in a distributed setting. + +The log probabilities computation is chunked in the sequence dimension +to mitigate GPU OOM (especially during backward pass). +In addition, logits casting from float16 or bfloat16 -> float32 is performed +inside the chunk loop to avoid materializing a whole float32 logits tensor. + +Adapted from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286 + + + + + + +```python +nemo_rl.distributed.model_utils.ChunkedDistributedLogprob.backward( + ctx: typing.Any, + grad_outputs: torch.Tensor = () +) -> tuple[torch.Tensor, None, None, None, None, None, None] +``` + + + + + + +staticmethod + + + + + + + +```python +nemo_rl.distributed.model_utils.ChunkedDistributedLogprob.forward( + ctx: typing.Any, + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + chunk_size: int, + tp_group: torch.distributed.ProcessGroup, + inference_only: bool = False +) -> torch.Tensor +``` + + + + + + +staticmethod + + + + + + + + + +```python +class nemo_rl.distributed.model_utils.DistributedLogprob() +``` + + + + + + +**Bases:** `Function` + +Custom autograd function for computing log probabilities in a distributed setting. + +Taken from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286 + + + + + + +```python +nemo_rl.distributed.model_utils.DistributedLogprob.backward( + ctx: typing.Any, + grad_outputs: torch.Tensor = () +) -> tuple[torch.Tensor, None, None, None, None, None, None] +``` + + + + + + +staticmethod + + + + + + + +```python +nemo_rl.distributed.model_utils.DistributedLogprob.forward( + ctx: typing.Any, + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + group: torch.distributed.ProcessGroup, + inference_only: bool = False +) -> torch.Tensor +``` + + + + + + +staticmethod + + + + + + + + + +```python +nemo_rl.distributed.model_utils._compute_distributed_log_softmax( + vocab_parallel_logits: torch.Tensor, + group: torch.distributed.ProcessGroup +) -> torch.Tensor +``` + + + + + + +Compute a stable distributed log softmax across tensor parallel workers. + +Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L265 + +**Parameters:** + + +Logits tensor with shape [batch_size, seq_length, vocab_size//TP] +where TP is the tensor parallel size. + + + +Process group for the all-reduce operations. + + +**Returns:** `torch.Tensor` + +torch.Tensor: Log softmax output with the same shape as input, but values represent +log probabilities normalized across the full vocabulary dimension. + + + + + + + + +```python +nemo_rl.distributed.model_utils._get_tokens_on_this_cp_rank( + input_ids: torch.Tensor, + cp_rank: int, + cp_size: int, + seq_dim: int = 1 +) -> torch.Tensor +``` + + + + + + +Get tokens on this context parallelism rank. + +Assumes that input_ids are already padded to a multiple of cp_size * 2 or cp_size == 1. + +**Parameters:** + + +Input token IDs [seq_length, ] + + + +Context parallelism rank + + + +Context parallelism size + + +**Returns:** `torch.Tensor` + +Tokens on this context parallelism rank [1, seq_length // cp_size] + + + + + + + + +```python +nemo_rl.distributed.model_utils.allgather_cp_sharded_tensor( + tensor, + cp_group, + seq_dim = 1 +) +``` + + + + + + + + + + + + + +```python +nemo_rl.distributed.model_utils.distributed_vocab_topk( + vocab_parallel_logits: torch.Tensor, + k: int, + tp_group: torch.distributed.ProcessGroup, + vocab_start_index: int, + vocab_end_index: int, + chunk_size: typing.Optional[int] = None +) -> tuple[torch.Tensor, torch.Tensor] +``` + + + + + + +Compute global top-k over TP-sharded vocabulary logits. + +**Parameters:** + + +[B, S, V_local] + + + +number of top tokens to select globally + + + +tensor-parallel process group + + + +global vocab start for this rank (inclusive) + + + +global vocab end for this rank (exclusive) + + + +optional chunk along sequence dim to bound memory + + +**Returns:** `torch.Tensor` + +[B, S, k] + + + + + + + + +```python +nemo_rl.distributed.model_utils.dtensor_from_parallel_logits_to_logprobs( + vocab_parallel_logits: torch.Tensor, + target: torch.distributed.tensor.DTensor | torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + tp_group: torch.distributed.ProcessGroup, + inference_only: bool = False, + seq_index: typing.Optional[torch.Tensor] = None, + chunk_size: typing.Optional[int] = None +) -> torch.Tensor +``` + + + + + + +Get log probabilities from TP+CP sharded vocab logits. + +**Parameters:** + + +Logits distributed across tensor parallel workers, +with shape [batch_size, seq_len, vocab_size/tp_size]. + + + +Target token indices with shape [batch_size, seq_len]. +NOTE: Must be the unmodified targets as this function will shift them internally. + + + +Starting vocabulary index for this worker's partition. + + + +Ending vocabulary index for this worker's partition. + + + +Process group for distributed communication. + + + +If True, tensors won't be saved for backward pass. Defaults to False. + + + +Sequence index tensor with shape [seq_len]. +It is only provided for cp sharded logits. It represents how tensor is sharded across the sequence dimension. + + + +Sequence dimension chunk size for computing the log probabilities. + + +**Returns:** `torch.Tensor` + +torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1]. +The sequence dimension is reduced by 1 due to the target shifting. + + + + + + + + +```python +nemo_rl.distributed.model_utils.from_parallel_logits_to_logprobs( + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + tp_group: torch.distributed.ProcessGroup, + inference_only: bool = False, + cp_group: typing.Optional[torch.distributed.ProcessGroup] = None, + chunk_size: typing.Optional[int] = None +) -> torch.Tensor +``` + + + + + + +Get log probabilities from TP+CP sharded vocab logits. + +Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L354 + +**Parameters:** + + +Logits tensor with shape [batch_size, seq_len // CP, vocab_size // TP] +where TP is the tensor parallel size. + + + +Target token indices with shape [batch_size, seq_len]. +NOTE: Must be the unmodified targets as this function will shift them internally. + + + +Starting vocabulary index for this worker's partition. + + + +Ending vocabulary index for this worker's partition. + + + +Process group for distributed communication. + + + +If True, tensors won't be saved for backward pass. Defaults to False. + + + +Context parallelism process group. Defaults to None. + + + +Sequence dimension chunk size for computing the log probabilities. + + +**Returns:** `torch.Tensor` + +torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1]. +The sequence dimension is reduced by 1 due to the target shifting. + + + + + + + + +```python +nemo_rl.distributed.model_utils.from_parallel_logits_to_logprobs_packed_sequences( + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + cu_seqlens_padded: torch.Tensor, + unpacked_seqlen: int, + vocab_start_index: int, + vocab_end_index: int, + group: torch.distributed.ProcessGroup, + inference_only: bool = False, + cp_group: typing.Optional[torch.distributed.ProcessGroup] = None, + chunk_size: typing.Optional[int] = None +) -> torch.Tensor +``` + + + + + + +Get log probabilities from TP sharded vocab logits for packed sequences. + +**Parameters:** + + +Packed logits tensor with shape [1, T // CP, vocab_size//TP] +where T is the total number of tokens across all packed sequences. + + + +Packed target token indices with shape [1, T]. +NOTE: Must be the unmodified targets as this function will shift them internally. + + + +Cumulative sequence lengths tensor with shape [batch_size + 1]. +cu_seqlens[i] indicates the start position of sequence i in the packed format. + + + +The length of the unpacked sequence tensor. + + + +Starting vocabulary index for this worker's partition. + + + +Ending vocabulary index for this worker's partition. + + + +Process group for distributed communication. + + + +If True, tensors won't be saved for backward pass. Defaults to False. + + + +Context parallelism process group. Defaults to None. + + + +Sequence dimension chunk size for computing the log probabilities. + + +**Returns:** `torch.Tensor` + +torch.Tensor: Unpacked log probabilities tensor with shape [batch_size, unpacked_seqlen-1]. +The total length is reduced by batch_size due to target shifting (one token per sequence). + + + + + + + + +```python +nemo_rl.distributed.model_utils.gather_logits_at_global_indices( + vocab_parallel_logits: torch.Tensor, + global_indices: torch.Tensor, + tp_group: typing.Optional[torch.distributed.ProcessGroup] = None, + cp_group: typing.Optional[torch.distributed.ProcessGroup] = None, + vocab_start_index: int, + vocab_end_index: int, + chunk_size: typing.Optional[int] = None +) -> torch.Tensor +``` + + + + + + +Gather student logits at given global token indices under TP+CP sharding. + +Differentiable w.r.t. vocab_parallel_logits. + +**Parameters:** + + +[B, S_cp, V_local] where S_cp is CP sharded sequence length + + + +[B, S_full, k] where S_full is full sequence length + + + +Optional tensor-parallel process group. If None, treats logits as full-vocab (no TP) and skips TP all-reduce. + + + +global vocab start for this rank (inclusive) + + + +global vocab end for this rank (exclusive) + + + +optional chunk along sequence dim to bound memory + + + +Optional context-parallel process group + + +**Returns:** `torch.Tensor` + +[B, S_full, k] + + + + + + + + +```python +nemo_rl.distributed.model_utils.get_logprobs_from_vocab_parallel_logits( + vocab_parallel_logits: torch.distributed.tensor.DTensor, + input_ids: torch.Tensor | torch.distributed.tensor.DTensor, + seq_index: typing.Optional[torch.Tensor] = None, + chunk_size: typing.Optional[int] = None +) +``` + + + + + + +Computes log probabilities from vocabulary-parallel logits. + +This function takes logits that are sharded across the vocabulary dimension (tensor parallel) +and computes the log probabilities for the given input IDs. + +**Parameters:** + + +Logits distributed across tensor parallel workers, +with shape [batch_size, seq_len, vocab_size/tp_size]. + + + +Input token IDs for which to compute log probabilities, +with shape [batch_size, seq_len]. + + + +Sequence index for the input IDs, +with shape [sequence_length]. + + + +Sequence dimension chunk size for computing log probabilities. + + +**Returns:** + +torch.Tensor: Log probabilities for the given input IDs. + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/named_sharding.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/named_sharding.mdx new file mode 100644 index 0000000..20d0bde --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/named_sharding.mdx @@ -0,0 +1,236 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/distributed/named_sharding +title: nemo_rl.distributed.named_sharding +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`NamedSharding`](#nemo_rl-distributed-named_sharding-NamedSharding) | Represents an N-dimensional arrangement of ranks with named axes, facilitating data sharding, replication, and collection based on these axes. | + +### API + + + + + +```python +class nemo_rl.distributed.named_sharding.NamedSharding( + layout: typing.Sequence[typing.Any] | numpy.ndarray, + names: list[str] +) +``` + + + + + + +Represents an N-dimensional arrangement of ranks with named axes, facilitating data sharding, replication, and collection based on these axes. + + + + + + + + + + + + +Returns the underlying NumPy array representing the layout. + + + +Returns the names of the axes. + + + +Returns the number of dimensions. + + + +Returns the shape of the rank layout. + + + +Returns the total number of ranks. + + + + + +```python +nemo_rl.distributed.named_sharding.NamedSharding.__eq__( + other: object +) -> bool +``` + + + + + + + + + + + + +```python +nemo_rl.distributed.named_sharding.NamedSharding.__repr__() -> str +``` + + + + + + + + + + + + +```python +nemo_rl.distributed.named_sharding.NamedSharding.get_axis_index( + name: str +) -> int +``` + + + + + + +Gets the numerical index of a named axis. + + + + + + + +```python +nemo_rl.distributed.named_sharding.NamedSharding.get_axis_size( + name: str +) -> int +``` + + + + + + +Gets the size of a named axis. + + + + + + + +```python +nemo_rl.distributed.named_sharding.NamedSharding.get_ranks( + kwargs: int = {} +) -> typing.Union[nemo_rl.distributed.named_sharding.NamedSharding, int] +``` + + + + + + +Gets the ranks corresponding to specific indices along named axes. + +**Parameters:** + + +Keyword arguments where the key is the axis name (e.g., "dp", "tp") + and the value is the index along that axis. + + +**Returns:** `Union[NamedSharding, int]` + +A new NamedSharding instance representing the subset of ranks. + +**Raises:** + +- `ValueError`: If an invalid axis name is provided or if an index is out of bounds. + + + + + + + +```python +nemo_rl.distributed.named_sharding.NamedSharding.get_ranks_by_coord( + coords: int = {} +) -> list[int] +``` + + + + + + +Gets all ranks that match the specified coordinates for named axes. + +**Parameters:** + + +Keyword arguments where the key is the axis name (e.g., "dp", "tp") + and the value is the integer coordinate along that axis. + Axes not specified will match all coordinates along that axis. + + +**Returns:** `list[int]` + +A sorted list of unique rank integers that match the given coordinate criteria. + +**Raises:** + +- `ValueError`: If an invalid axis name is provided. + + + + + + + +```python +nemo_rl.distributed.named_sharding.NamedSharding.get_worker_coords( + worker_id: int +) -> dict[str, int] +``` + + + + + + +Gets the coordinates of a specific worker ID in the sharding layout. + +**Parameters:** + + +The integer ID of the worker. + + +**Returns:** `dict[str, int]` + +A dictionary mapping axis names to their integer coordinates for the given worker_id. + +**Raises:** + +- `ValueError`: If the worker_id is not found in the layout. + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/ray_actor_environment_registry.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/ray_actor_environment_registry.mdx new file mode 100644 index 0000000..5396879 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/ray_actor_environment_registry.mdx @@ -0,0 +1,105 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/distributed/ray_actor_environment_registry +title: nemo_rl.distributed.ray_actor_environment_registry +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`get_actor_python_env`](#nemo_rl-distributed-ray_actor_environment_registry-get_actor_python_env) | - | + +### Data + +[`ACTOR_ENVIRONMENT_REGISTRY`](#nemo_rl-distributed-ray_actor_environment_registry-ACTOR_ENVIRONMENT_REGISTRY) + +[`MCORE_EXECUTABLE`](#nemo_rl-distributed-ray_actor_environment_registry-MCORE_EXECUTABLE) + +[`SGLANG_EXECUTABLE`](#nemo_rl-distributed-ray_actor_environment_registry-SGLANG_EXECUTABLE) + +[`USE_SYSTEM_EXECUTABLE`](#nemo_rl-distributed-ray_actor_environment_registry-USE_SYSTEM_EXECUTABLE) + +[`VLLM_EXECUTABLE`](#nemo_rl-distributed-ray_actor_environment_registry-VLLM_EXECUTABLE) + +### API + + + + + +```python +nemo_rl.distributed.ray_actor_environment_registry.get_actor_python_env( + actor_class_fqn: str +) -> str +``` + + + + + + + + + + + + + +```python +nemo_rl.distributed.ray_actor_environment_registry.ACTOR_ENVIRONMENT_REGISTRY: dict[str, str] = {'nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker': VLLM_EXECUTA... +``` + + + + + + + + + +```python +nemo_rl.distributed.ray_actor_environment_registry.MCORE_EXECUTABLE = PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.MCORE +``` + + + + + + + + + +```python +nemo_rl.distributed.ray_actor_environment_registry.SGLANG_EXECUTABLE = PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.SGLANG +``` + + + + + + + + + +```python +nemo_rl.distributed.ray_actor_environment_registry.USE_SYSTEM_EXECUTABLE = os.environ.get('NEMO_RL_PY_EXECUTABLES_SYSTEM', '0') == '1' +``` + + + + + + + + + +```python +nemo_rl.distributed.ray_actor_environment_registry.VLLM_EXECUTABLE = PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.VLLM +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/stateless_process_group.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/stateless_process_group.mdx new file mode 100644 index 0000000..ebd4125 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/stateless_process_group.mdx @@ -0,0 +1,73 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/distributed/stateless_process_group +title: nemo_rl.distributed.stateless_process_group +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`StatelessProcessGroup`](#nemo_rl-distributed-stateless_process_group-StatelessProcessGroup) | - | + +### API + + + + + +```python +class nemo_rl.distributed.stateless_process_group.StatelessProcessGroup( + master_address: str, + port: int, + rank: int, + world_size: int +) +``` + + + + + + + + + + + + +```python +nemo_rl.distributed.stateless_process_group.StatelessProcessGroup.broadcast( + tensor: torch.Tensor, + src: int, + stream: typing.Optional[torch.cuda.Stream] = None +) +``` + + + + + + + + + + + + +```python +nemo_rl.distributed.stateless_process_group.StatelessProcessGroup.init_nccl_communicator( + device: int +) +``` + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/virtual_cluster.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/virtual_cluster.mdx new file mode 100644 index 0000000..108df41 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/virtual_cluster.mdx @@ -0,0 +1,514 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/distributed/virtual_cluster +title: nemo_rl.distributed.virtual_cluster +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ClusterConfig`](#nemo_rl-distributed-virtual_cluster-ClusterConfig) | - | +| [`GetGPUIDActor`](#nemo_rl-distributed-virtual_cluster-GetGPUIDActor) | Util actor class to return GPU id of the current worker. | +| [`PY_EXECUTABLES`](#nemo_rl-distributed-virtual_cluster-PY_EXECUTABLES) | - | +| [`RayVirtualCluster`](#nemo_rl-distributed-virtual_cluster-RayVirtualCluster) | Creates a virtual distributed cluster using Ray placement groups. | +| [`ResourceInsufficientError`](#nemo_rl-distributed-virtual_cluster-ResourceInsufficientError) | Exception raised when the cluster does not have enough resources to satisfy the requested configuration. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_free_port_local`](#nemo_rl-distributed-virtual_cluster-_get_free_port_local) | - | +| [`_get_node_ip_and_free_port`](#nemo_rl-distributed-virtual_cluster-_get_node_ip_and_free_port) | - | +| [`_get_node_ip_local`](#nemo_rl-distributed-virtual_cluster-_get_node_ip_local) | - | +| [`init_ray`](#nemo_rl-distributed-virtual_cluster-init_ray) | Initialise Ray. | + +### Data + +[`dir_path`](#nemo_rl-distributed-virtual_cluster-dir_path) + +[`git_root`](#nemo_rl-distributed-virtual_cluster-git_root) + +[`logger`](#nemo_rl-distributed-virtual_cluster-logger) + +### API + + + + + +```python +class nemo_rl.distributed.virtual_cluster.ClusterConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.distributed.virtual_cluster.GetGPUIDActor() +``` + + + + + + +Util actor class to return GPU id of the current worker. + + + + + + +```python +nemo_rl.distributed.virtual_cluster.GetGPUIDActor.get_gpu_id() +``` + + + + + + + + + + + + + + +```python +class nemo_rl.distributed.virtual_cluster.PY_EXECUTABLES() +``` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.distributed.virtual_cluster.RayVirtualCluster( + bundle_ct_per_node_list: list[int], + use_gpus: bool = True, + max_colocated_worker_groups: int = 1, + num_gpus_per_node: int = 8, + name: str = '', + placement_group_strategy: str = 'SPREAD' +) +``` + + + + + + +Creates a virtual distributed cluster using Ray placement groups. + +This class simplifies distributed training setup by: +- Creating placement groups that represent logical compute nodes +- Allocating GPU and CPU resources for distributed workers +- Managing communication between distributed processes + +- Bundle: A resource allocation unit (ex: 4 GPUs on a single node) +- Worker: A process that performs computation (model training/inference) +- Node: A physical or virtual machine containing multiple bundles + + + + + + + + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.RayVirtualCluster.__del__() -> None +``` + + + + + + +Shutsdown the virtual cluster when the object is deleted or is garbage collected. + +This is an extra safety net in case the user forgets to call shutdown and the pointer to +the cluster is lost due to leaving a function scope. It's always recommended that the +user calls shutdown(). + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.RayVirtualCluster._create_placement_groups_internal( + strategy: str, + use_unified_pg: bool = False +) -> list[ray.util.placement_group.PlacementGroup] +``` + + + + + + +Internal method to create placement groups without retry logic. + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.RayVirtualCluster._get_sorted_bundle_indices() -> typing.Optional[list[int]] +``` + + + + + + +Gets the sorted bundle indices for the placement groups. + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.RayVirtualCluster._init_placement_groups( + strategy: str | None = None, + use_unified_pg: bool = False +) -> list[ray.util.placement_group.PlacementGroup] +``` + + + + + + +Creates placement groups based on whether cross-node model parallelism is needed. + +**Parameters:** + + +Ray placement group strategy (defaults to self.placement_group_strategy) + + + +If True, create a single unified placement group. + If False, create per-node placement groups. + + +**Returns:** `list[PlacementGroup]` + +List of placement groups + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.RayVirtualCluster.get_available_address_and_port( + pg_idx: int, + bundle_idx: int +) -> tuple[str, int] +``` + + + + + + +Gets an available address and port for the given placement group index and bundle index. + +**Returns:** `tuple[str, int]` + +Tuple of (address, port) + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.RayVirtualCluster.get_master_address_and_port() -> tuple[str, int] +``` + + + + + + +Gets the master address and port for the distributed training setup. + +**Returns:** `tuple[str, int]` + +Tuple of (address, port) + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.RayVirtualCluster.get_placement_groups() -> list[ray.util.placement_group.PlacementGroup] +``` + + + + + + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.RayVirtualCluster.node_count() -> int +``` + + + + + + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.RayVirtualCluster.shutdown() -> bool +``` + + + + + + +Cleans up and releases all resources associated with this virtual cluster. + +This includes removing all placement groups and resetting the internal state. + +This method is idempotent and can be safely called multiple times. + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.RayVirtualCluster.world_size() -> int +``` + + + + + + + + + + + + + + +```python +class nemo_rl.distributed.virtual_cluster.ResourceInsufficientError() +``` + + + + + + +Exception + +**Bases:** `Exception` + +Exception raised when the cluster does not have enough resources to satisfy the requested configuration. + + + + + + + + +```python +nemo_rl.distributed.virtual_cluster._get_free_port_local() -> int +``` + + + + + + + + + + + + + +```python +nemo_rl.distributed.virtual_cluster._get_node_ip_and_free_port() -> tuple[str, int] +``` + + + + + + + + + + + + + +```python +nemo_rl.distributed.virtual_cluster._get_node_ip_local() -> str +``` + + + + + + + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.init_ray( + log_dir: typing.Optional[str] = None +) -> None +``` + + + + + + +Initialise Ray. + +Try to attach to an existing local cluster. +If that cluster uses the same CUDA_VISIBLE_DEVICES or Slurm managed tag we will reuse it. +Otherwise, we will detach and start a fresh local cluster. + +**Parameters:** + + +Optional directory to store Ray logs and temp files. + + + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.dir_path = os.path.dirname(os.path.abspath(__file__)) +``` + + + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.git_root = os.path.abspath(os.path.join(dir_path, '../..')) +``` + + + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_group_utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_group_utils.mdx new file mode 100644 index 0000000..8519d0f --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_group_utils.mdx @@ -0,0 +1,81 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/distributed/worker_group_utils +title: nemo_rl.distributed.worker_group_utils +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`get_nsight_config_if_pattern_matches`](#nemo_rl-distributed-worker_group_utils-get_nsight_config_if_pattern_matches) | Check if worker name matches patterns in NRL_NSYS_WORKER_PATTERNS and return nsight config. | +| [`recursive_merge_options`](#nemo_rl-distributed-worker_group_utils-recursive_merge_options) | Recursively merge extra options into default options using OmegaConf. | + +### API + + + + + +```python +nemo_rl.distributed.worker_group_utils.get_nsight_config_if_pattern_matches( + worker_name: str +) -> dict[str, typing.Any] +``` + + + + + + +Check if worker name matches patterns in NRL_NSYS_WORKER_PATTERNS and return nsight config. + +**Parameters:** + + +Name of the worker to check against patterns + + +**Returns:** `dict[str, Any]` + +Dictionary containing {"nsight": config} if pattern matches, empty dict otherwise + + + + + + + + +```python +nemo_rl.distributed.worker_group_utils.recursive_merge_options( + default_options: dict[str, typing.Any], + extra_options: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + +Recursively merge extra options into default options using OmegaConf. + +**Parameters:** + + +Default options dictionary (lower precedence) + + + +Extra options provided by the caller (higher precedence) + + +**Returns:** `dict[str, Any]` + +Merged options dictionary with extra_options taking precedence over default_options + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_groups.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_groups.mdx new file mode 100644 index 0000000..e0205d5 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_groups.mdx @@ -0,0 +1,603 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/distributed/worker_groups +title: nemo_rl.distributed.worker_groups +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MultiWorkerFuture`](#nemo_rl-distributed-worker_groups-MultiWorkerFuture) | Container for Ray futures with associated worker information. | +| [`RayWorkerBuilder`](#nemo_rl-distributed-worker_groups-RayWorkerBuilder) | - | +| [`RayWorkerGroup`](#nemo_rl-distributed-worker_groups-RayWorkerGroup) | Manages a group of distributed Ray worker/actor processes that execute tasks in parallel. | + +### API + + + + + +```python +class nemo_rl.distributed.worker_groups.MultiWorkerFuture( + futures: list[ray.ObjectRef], + return_from_workers: typing.Optional[list[int]] = None, + called_workers: typing.Optional[list[int]] = None +) +``` + + + + + + +Dataclass + +Container for Ray futures with associated worker information. + + + + + + + + + + + + + + +```python +nemo_rl.distributed.worker_groups.MultiWorkerFuture.get_results( + worker_group: nemo_rl.distributed.worker_groups.RayWorkerGroup, + return_generators_as_proxies: bool = False +) -> list[typing.Any] +``` + + + + + + +Get results from the futures, optionally respecting tied workers. + +The method uses worker_group.worker_to_tied_group_index to identify which tied +worker group each worker belongs to, then selects only the first result from each group. + +**Parameters:** + + +The RayWorkerGroup that spawned the futures. The +mapping contained in worker_group.worker_to_tied_group_index +is required for the deduplication path. + + + +If True, and a future is an ObjectRefGenerator, + return the ObjectRefGenerator itself instead of consuming it. + + +**Returns:** `list[Any]` + +List of results + + + + + + + + + +```python +class nemo_rl.distributed.worker_groups.RayWorkerBuilder( + ray_actor_class_fqn: str, + args = (), + kwargs = {} +) +``` + + + + + + + + + + +```python +nemo_rl.distributed.worker_groups.RayWorkerBuilder.__call__( + placement_group: ray.util.placement_group.PlacementGroup, + placement_group_bundle_index: int, + num_gpus: float | int, + bundle_indices: typing.Optional[tuple[int, list[int]]] = None, + extra_options: typing.Any = {} +) -> ray.actor.ActorHandle +``` + + + + + + +Create a Ray worker with the specified configuration. + +Order of precedence for worker options configuration (from lowest to highest): +1. Options passed by the user to __call__ (extra_options) +2. Options required by the worker via configure_worker (may override user options with warning) +3. Options set by the RayWorkerBuilder.__call__ (specifically scheduling strategy) + +If the worker needs to override user-provided options, it should log a warning +to inform the user about the change and the reason for it. + +**Parameters:** + + +Ray placement group for resource allocation + + + +Index of the bundle in the placement group + + + +Number of GPUs to allocate to this worker (can be fractional) + + + +Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable) + + + +Additional options to pass to the Ray actor (may be overridden by actor's configure_worker(...) method) + + +**Returns:** `ray.actor.ActorHandle` + +A Ray actor reference to the created worker + + + + + + + +```python +nemo_rl.distributed.worker_groups.RayWorkerBuilder.create_worker_async( + placement_group: ray.util.placement_group.PlacementGroup, + placement_group_bundle_index: int, + num_gpus: float | int, + bundle_indices: typing.Optional[tuple[int, list[int]]] = None, + extra_options: typing.Any = {} +) -> tuple[ray.ObjectRef, ray.actor.ActorHandle] +``` + + + + + + +Create a Ray worker asynchronously, returning futures. + +This method returns immediately with futures that can be awaited later. + +**Parameters:** + + +Ray placement group for resource allocation + + + +Index of the bundle in the placement group + + + +Number of GPUs to allocate to this worker (can be fractional) + + + +Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable) + + + +Additional options to pass to the Ray actor + + +**Returns:** `tuple[ray.ObjectRef, ray.actor.ActorHandle]` + +Tuple of (worker_future, initializer_actor): +- worker_future: A Ray ObjectRef that will resolve to the worker actor +- initializer_actor: The initializer actor (needed to prevent GC) + + + + + + + + + +```python +class nemo_rl.distributed.worker_groups.RayWorkerGroup( + cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster, + remote_worker_builder: nemo_rl.distributed.worker_groups.RayWorkerBuilder, + workers_per_node: typing.Optional[typing.Union[int, list[int]]] = None, + name_prefix: str = '', + bundle_indices_list: typing.Optional[list[tuple[int, list[int]]]] = None, + sharding_annotations: typing.Optional[nemo_rl.distributed.named_sharding.NamedSharding] = None, + env_vars: dict[str, str] = {} +) +``` + + + + + + +Manages a group of distributed Ray worker/actor processes that execute tasks in parallel. + +This class creates and manages Ray actor instances that run on resources +allocated by a RayVirtualCluster. It handles: +- Worker creation and placement on specific GPU resources +- Setting up distributed training environment variables (rank, world size, etc.) +- Executing methods across all workers in parallel +- Collecting and aggregating results +- Support for tied worker groups where multiple workers process the same data + + + + + + + + + + + + +Number of data parallel shards. + + + + + + + + + + + +```python +nemo_rl.distributed.worker_groups.RayWorkerGroup._create_workers_from_bundle_indices( + remote_worker_builder: nemo_rl.distributed.worker_groups.RayWorkerBuilder, + bundle_indices_list: list[tuple[int, list[int]]], + env_vars: dict[str, str] = {} +) -> None +``` + + + + + + +Create workers based on explicit bundle indices for tied worker groups. + +**Parameters:** + + +Builder function for Ray actors + + + +List of (node_idx, local_bundle_indices) tuples, where each tuple + specifies a tied group with its node and local bundle indices. If the local_bundle_indices + spans multiple nodes, the node_idx will be the first node's index in the tied group. + + + + + + + + +```python +nemo_rl.distributed.worker_groups.RayWorkerGroup.get_all_worker_results( + future_bundle: nemo_rl.distributed.worker_groups.MultiWorkerFuture, + return_generators_as_proxies: bool = False +) -> list[typing.Any] +``` + + + + + + +Get results from all workers, optionally filtering to get just one result per tied worker group. + +**Parameters:** + + +MultiWorkerFuture containing futures and worker information. + + + +If True, and a future in the bundle is an ObjectRefGenerator, + return the ObjectRefGenerator itself instead of consuming it. + + +**Returns:** `list[Any]` + +List of results, deduplicated as specified in the future_bundle + + + + + + + +```python +nemo_rl.distributed.worker_groups.RayWorkerGroup.get_dp_leader_worker_idx( + dp_shard_idx: int +) -> int +``` + + + + + + +Returns the index of the primary worker for a given data parallel shard. + + + + + + + +```python +nemo_rl.distributed.worker_groups.RayWorkerGroup.run_all_workers_multiple_data( + method_name: str, + args = (), + run_rank_0_only_axes: list[str] | None = None, + common_kwargs: typing.Optional[dict[str, typing.Any]] = None, + kwargs = {} +) -> list[ray.ObjectRef] +``` + + + + + + +Run a method on all workers in parallel with different data. + +**Parameters:** + + +Name of the method to call on each worker + + + +List of arguments to pass to workers/groups + e.g. [[arg1_for_worker_1, arg1_for_worker_2], [arg2_for_worker_1, arg2_for_worker_2]] + + + +List of named axes for which only rank 0 should run the method. + + + +Keyword arguments to pass to all workers + + + +Keyword arguments to pass to workers/groups + e.g. {"key1": [value_for_worker_1, value_for_worker_2], "key2": [value_for_worker_1, value_for_worker_2]} + + +**Returns:** `list[ray.ObjectRef]` + +list[ray.ObjectRef]: A list of ray futures + + + + + + + +```python +nemo_rl.distributed.worker_groups.RayWorkerGroup.run_all_workers_sharded_data( + method_name: str, + args = (), + in_sharded_axes: list[str] | None = None, + replicate_on_axes: list[str] | None = None, + output_is_replicated: list[str] | None = None, + make_dummy_calls_to_free_axes: bool = False, + common_kwargs: typing.Optional[dict[str, typing.Any]] = None, + kwargs = {} +) -> nemo_rl.distributed.worker_groups.MultiWorkerFuture +``` + + + + + + +Run a method on all workers in parallel with sharded data. + +Axes in in_sharded_axes: Data is already split across these axes, so we just send the appropriate slice to each worker (along this axis) +Axes in replicate_on_axes: Data is replicated to all workers along these dimensions +Free axes (axes not in either list): Data is only sent to workers at index 0 of these axes + +**Parameters:** + + +Name of the method to call on each worker + + + +List of arguments to pass to workers/groups + e.g. [[arg1_for_worker_1, arg1_for_worker_2], [arg2_for_worker_1, arg2_for_worker_2]] + + + +List of axes that are sharded + + + +List of axes that are to be replicated + + + +List of axes along which the output is replicated (and we should just return the first result). + We also just return from rank 0 of free axes. + + + +Whether to make dummy calls (with None) to workers that + aren't rank 0 on 'free axes' (axes not in in_sharded_axes or replicate_on_axes). + + + +Keyword arguments to pass to all workers + + + +Keyword arguments to pass to workers/groups + e.g. {"key1": [value_for_worker_1, value_for_worker_2], "key2": [value_for_worker_1, value_for_worker_2]} + + +**Returns:** `MultiWorkerFuture` + +Object containing futures and their associated worker information + + + + + + + +```python +nemo_rl.distributed.worker_groups.RayWorkerGroup.run_all_workers_single_data( + method_name: str, + args = (), + run_rank_0_only_axes: list[str] | None = None, + kwargs = {} +) -> list[ray.ObjectRef] +``` + + + + + + +Run a method on all workers in parallel with the same data. + +**Parameters:** + + +Name of the method to call on each worker + + + +Arguments to pass to the method + + + +List of named axes for which only rank 0 should run the method. + + +**Returns:** `list[ray.ObjectRef]` + +list[ray.ObjectRef]: A list of ray futures + + + + + + + +```python +nemo_rl.distributed.worker_groups.RayWorkerGroup.run_single_worker_single_data( + method_name: str, + worker_idx: int, + args = (), + kwargs = {} +) -> ray.ObjectRef +``` + + + + + + +Run a method on a single, specific worker. + +**Parameters:** + + +Name of the method to call on the worker. + + + +The index of the worker to run the method on. + + + +Arguments to pass to the method. + + +**Returns:** `ray.ObjectRef` + +ray.ObjectRef: A Ray future for the result. + + + + + + + +```python +nemo_rl.distributed.worker_groups.RayWorkerGroup.shutdown( + cleanup_method: typing.Optional[str] = None, + timeout: typing.Optional[float] = 30.0, + force: bool = False +) -> bool +``` + + + + + + +Shutdown all workers in the worker group. + +**Parameters:** + + +Optional method name to call on each worker before termination. + If provided, this method will be called on each worker to allow + for graceful cleanup. + + + +Timeout in seconds for graceful shutdown. Only applicable if cleanup_method is provided. + If None, wait indefinitely for workers to complete their cleanup. + + + +If True, forcefully terminate workers with ray.kill() even if cleanup_method is provided. + If cleanup_method is None, workers are always forcefully terminated. + + +**Returns:** `bool` + +True if all workers were successfully shut down + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments.mdx new file mode 100644 index 0000000..de15010 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments.mdx @@ -0,0 +1,19 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments +title: nemo_rl.environments +--- + +## Submodules + +- **[`nemo_rl.environments.code_environment`](/nemo-rl/nemo_rl/environments/code_environment)** +- **[`nemo_rl.environments.code_jaccard_environment`](/nemo-rl/nemo_rl/environments/code_jaccard_environment)** +- **[`nemo_rl.environments.dapo_math_verifier`](/nemo-rl/nemo_rl/environments/dapo_math_verifier)** +- **[`nemo_rl.environments.interfaces`](/nemo-rl/nemo_rl/environments/interfaces)** +- **[`nemo_rl.environments.math_environment`](/nemo-rl/nemo_rl/environments/math_environment)** +- **[`nemo_rl.environments.metrics`](/nemo-rl/nemo_rl/environments/metrics)** +- **[`nemo_rl.environments.nemo_gym`](/nemo-rl/nemo_rl/environments/nemo_gym)** +- **[`nemo_rl.environments.reward_model_environment`](/nemo-rl/nemo_rl/environments/reward_model_environment)** +- **[`nemo_rl.environments.rewards`](/nemo-rl/nemo_rl/environments/rewards)** +- **[`nemo_rl.environments.utils`](/nemo-rl/nemo_rl/environments/utils)** +- **[`nemo_rl.environments.vlm_environment`](/nemo-rl/nemo_rl/environments/vlm_environment)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_environment.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_environment.mdx new file mode 100644 index 0000000..5c46941 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_environment.mdx @@ -0,0 +1,290 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/code_environment +title: nemo_rl.environments.code_environment +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`CodeEnvConfig`](#nemo_rl-environments-code_environment-CodeEnvConfig) | - | +| [`CodeEnvMetadata`](#nemo_rl-environments-code_environment-CodeEnvMetadata) | - | +| [`CodeEnvironment`](#nemo_rl-environments-code_environment-CodeEnvironment) | Code execution environment that maintains state between steps. | +| [`CodeExecutionWorker`](#nemo_rl-environments-code_environment-CodeExecutionWorker) | Helper class to process individual code execution steps. | + +### API + + + + + +```python +class nemo_rl.environments.code_environment.CodeEnvConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.environments.code_environment.CodeEnvMetadata +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.environments.code_environment.CodeEnvironment( + cfg: nemo_rl.environments.code_environment.CodeEnvConfig +) +``` + + + + + + +**Bases:** [EnvironmentInterface](/nemo-rl/nemo_rl/environments/interfaces#nemo_rl-environments-interfaces-EnvironmentInterface) + +Code execution environment that maintains state between steps. + + + + + + + + + + + + + + +```python +nemo_rl.environments.code_environment.CodeEnvironment.global_post_process_and_metrics( + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict +) -> typing.Tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict, dict] +``` + + + + + + +Compute metrics for the batch. + + + + + + + +```python +nemo_rl.environments.code_environment.CodeEnvironment.shutdown() +``` + + + + + + + + + + + + +```python +nemo_rl.environments.code_environment.CodeEnvironment.step( + message_log_batch: typing.List[nemo_rl.data.interfaces.LLMMessageLogType], + metadata_batch: typing.List[nemo_rl.environments.code_environment.CodeEnvMetadata], + return_extracted_answer: bool = False +) -> nemo_rl.environments.interfaces.EnvironmentReturn +``` + + + + + + +Process a batch of code execution steps. + + + + + + + + + +```python +class nemo_rl.environments.code_environment.CodeExecutionWorker() +``` + + + + + + +Helper class to process individual code execution steps. + + + + + + + + +```python +nemo_rl.environments.code_environment.CodeExecutionWorker.chdir( + dir: str +) +``` + + + + + + +Change to temporary directory for file operations. + + + + + + + +```python +nemo_rl.environments.code_environment.CodeExecutionWorker.execute( + message_batch: str, + metadata_batch: typing.List[nemo_rl.environments.code_environment.CodeEnvMetadata] +) -> typing.Tuple[typing.List[typing.Dict[str, str]], typing.List[bool], typing.List[typing.Any]] +``` + + + + + + +Execute code in a sandboxed environment. + + + + + + + +```python +nemo_rl.environments.code_environment.CodeExecutionWorker.format_result( + result: typing.Any, + code: typing.Optional[str] = None, + lookahead: typing.Optional[str] = None +) -> str +``` + + + + + + + + + + + + +```python +nemo_rl.environments.code_environment.CodeExecutionWorker.safe_import( + name: str, + args = (), + kwargs = {} +) +``` + + + + + + +Safe version of import that blocks risky modules. + + + + + + + +```python +nemo_rl.environments.code_environment.CodeExecutionWorker.safe_open( + file: str, + args = (), + kwargs = {} +) +``` + + + + + + +Safe version of open() that only allows access to temporary directory. + + + + + + + +```python +nemo_rl.environments.code_environment.CodeExecutionWorker.sanitize( + obj: typing.Any +) -> typing.Any +``` + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_jaccard_environment.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_jaccard_environment.mdx new file mode 100644 index 0000000..0bdc82a --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_jaccard_environment.mdx @@ -0,0 +1,268 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/code_jaccard_environment +title: nemo_rl.environments.code_jaccard_environment +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`CodeJaccardEnvConfig`](#nemo_rl-environments-code_jaccard_environment-CodeJaccardEnvConfig) | - | +| [`CodeJaccardEnvironment`](#nemo_rl-environments-code_jaccard_environment-CodeJaccardEnvironment) | Environment for evaluating code responses using Jaccard similarity. | +| [`CodeJaccardEnvironmentMetadata`](#nemo_rl-environments-code_jaccard_environment-CodeJaccardEnvironmentMetadata) | - | +| [`CodeJaccardVerifyWorker`](#nemo_rl-environments-code_jaccard_environment-CodeJaccardVerifyWorker) | Worker for evaluating code responses using Jaccard-based similarity. | + +### API + + + + + +```python +class nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironment( + cfg: nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvConfig +) +``` + + + + + + +**Bases:** [EnvironmentInterface[CodeJaccardEnvironmentMetadata]](/nemo-rl/nemo_rl/environments/interfaces#nemo_rl-environments-interfaces-EnvironmentInterface) + +Environment for evaluating code responses using Jaccard similarity. + + + + + + + + + + + +```python +nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironment.global_post_process_and_metrics( + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] +) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], dict[str, float | int]] +``` + + + + + + +Post-process batch and compute metrics for CodeJaccard. + + + + + + + +```python +nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironment.shutdown() -> None +``` + + + + + + +Shutdown all workers. + + + + + + + +```python +nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironment.step( + message_log_batch: list[nemo_rl.data.interfaces.LLMMessageLogType], + metadata: list[nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironmentMetadata], + return_extracted_answer: bool = False +) -> nemo_rl.environments.interfaces.EnvironmentReturn[nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironmentMetadata] +``` + + + + + + +Runs a step in the CodeJaccard environment. + +**Parameters:** + + +Batch of OpenAI-API-like message logs. + + + +Batch of CodeJaccardEnvironmentMetadata with ground truth. + + + +Whether to return extracted answers. + + +**Returns:** `EnvironmentReturn[CodeJaccardEnvironmentMetadata]` + +Tuple containing observations, metadata, stop strings, rewards, and done flags. + + + + + + + + + +```python +class nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironmentMetadata +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +class nemo_rl.environments.code_jaccard_environment.CodeJaccardVerifyWorker() +``` + + + + + + +Worker for evaluating code responses using Jaccard-based similarity. + + + + + + +```python +nemo_rl.environments.code_jaccard_environment.CodeJaccardVerifyWorker._calculate_preference_score( + response: str, + ground_truth: str +) -> float +``` + + + + + + +Calculate a Jaccard-based alignment score between response and ground truth. + +This is a simplified scoring function. In practice, you might want to use: +- Semantic similarity models +- BLEU/ROUGE scores +- Tokenize both texts into sets A and B (here we use whitespace tokenization). +- Compute intersection size |A ∩ B| and union size |A ∪ B|. +- J(A, B) = |A ∩ B| / |A ∪ B|, with guards for union=0 -> 0.0. +- Optionally combine with a length-ratio penalty to discourage degenerate very short/long matches. + +Complexity: +- Tokenization: O(n + m) +- Set ops: O(n + m) average (hash sets) + +**Parameters:** + + +The model's response + + +**Returns:** `float` + +Score between 0.0 and 1.0 + + + + + + + +```python +nemo_rl.environments.code_jaccard_environment.CodeJaccardVerifyWorker.verify( + pred_responses: list[str], + ground_truths: list[str], + return_extracted_answer: bool = False +) -> typing.Union[list[float], tuple[list[float], list[str | None]]] +``` + + + + + + +Verify code responses against ground-truth solutions using Jaccard-based similarity. + +We use a simple text similarity approach (Jaccard over tokenized words) +to evaluate how well the model's response aligns with the ground truth. + +**Parameters:** + + +list[str]. The predicted responses from the LLM. + + + +list[str]. The ground-truth solutions. + + + +bool. Whether to return extracted answers (here, the full response). + + +**Returns:** `Union[list[float], tuple[list[float], list[str | None]]]` + +Union[list[float], tuple[list[float], list[str | None]]]. + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/dapo_math_verifier.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/dapo_math_verifier.mdx new file mode 100644 index 0000000..ecea315 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/dapo_math_verifier.mdx @@ -0,0 +1,316 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/dapo_math_verifier +title: nemo_rl.environments.dapo_math_verifier +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`compute_score`](#nemo_rl-environments-dapo_math_verifier-compute_score) | Compute the reward score for a solution. | +| [`is_correct_minerva`](#nemo_rl-environments-dapo_math_verifier-is_correct_minerva) | Check if the solution is correct according to Minerva criteria. | +| [`is_correct_strict_box`](#nemo_rl-environments-dapo_math_verifier-is_correct_strict_box) | Check if the prediction is correct using strict boxed answer criteria. | +| [`last_boxed_only_string`](#nemo_rl-environments-dapo_math_verifier-last_boxed_only_string) | Extract the last LaTeX boxed expression from a string. | +| [`normalize_final_answer`](#nemo_rl-environments-dapo_math_verifier-normalize_final_answer) | Normalize a final answer to a quantitative reasoning question. | +| [`remove_boxed`](#nemo_rl-environments-dapo_math_verifier-remove_boxed) | Remove the LaTeX boxed command from a string. | +| [`verify`](#nemo_rl-environments-dapo_math_verifier-verify) | Verify if the solution is correct. | + +### Data + +[`REMOVED_EXPRESSIONS`](#nemo_rl-environments-dapo_math_verifier-REMOVED_EXPRESSIONS) + +[`SUBSTITUTIONS`](#nemo_rl-environments-dapo_math_verifier-SUBSTITUTIONS) + +### API + + + + + +```python +nemo_rl.environments.dapo_math_verifier.compute_score( + solution_str: str, + ground_truth: str, + strict_box_verify: bool = False, + pause_tokens_index: typing.Optional[list[int]] = None +) -> float +``` + + + + + + +Compute the reward score for a solution. + +**Parameters:** + + +The solution string + + + +The ground truth answer + + + +Whether to use strict box verification + + + +Indices of pause tokens + + +**Returns:** `float` + +Reward score (1.0 for correct, 0.0 for incorrect) + + + + + + + + +```python +nemo_rl.environments.dapo_math_verifier.is_correct_minerva( + solution_str: str, + gt: str, + gt_need_extract: bool = False, + answer_pattern: str = '(?i)Answer\\s*:\\s*([^\\n]+)' +) -> tuple[bool, str] +``` + + + + + + +Check if the solution is correct according to Minerva criteria. + +**Parameters:** + + +The solution string to check + + + +The ground truth answer + + + +Whether the ground truth needs extraction + + + +Regex pattern to extract the answer + + +**Returns:** `tuple[bool, str]` + +Tuple of (is_correct, normalized_prediction) + + + + + + + + +```python +nemo_rl.environments.dapo_math_verifier.is_correct_strict_box( + pred: str, + gt: str, + pause_tokens_index: typing.Optional[list[int]] = None +) -> tuple[int, typing.Optional[str]] +``` + + + + + + +Check if the prediction is correct using strict boxed answer criteria. + +**Parameters:** + + +The prediction string + + + +The ground truth answer + + + +Indices of pause tokens + + +**Returns:** `tuple[int, Optional[str]]` + +Tuple of (score, extracted_prediction) + + + + + + + + +```python +nemo_rl.environments.dapo_math_verifier.last_boxed_only_string( + string: str +) -> typing.Optional[str] +``` + + + + + + +Extract the last LaTeX boxed expression from a string. + +**Parameters:** + + +Input string containing LaTeX code + + +**Returns:** `Optional[str]` + +The last boxed expression or None if not found + + + + + + + + +```python +nemo_rl.environments.dapo_math_verifier.normalize_final_answer( + final_answer: str +) -> str +``` + + + + + + +Normalize a final answer to a quantitative reasoning question. + +**Parameters:** + + +The answer string to normalize + + +**Returns:** `str` + +Normalized answer string + + + + + + + + +```python +nemo_rl.environments.dapo_math_verifier.remove_boxed( + s: str +) -> str +``` + + + + + + +Remove the LaTeX boxed command from a string. + +**Parameters:** + + +String with format "\\boxed{content}" + + +**Returns:** `str` + +The content inside the boxed command + + + + + + + + +```python +nemo_rl.environments.dapo_math_verifier.verify( + solution_str: str, + answer: str, + strict_box_verify: bool = False, + pause_tokens_index: typing.Optional[list[int]] = None +) -> bool +``` + + + + + + +Verify if the solution is correct. + +**Parameters:** + + +The solution string to verify + + + +The ground truth answer + + + +Whether to use strict box verification + + + +Indices of pause tokens + + +**Returns:** `bool` + +True if the solution is correct, False otherwise + + + + + + + + +```python +nemo_rl.environments.dapo_math_verifier.REMOVED_EXPRESSIONS = ['square', 'ways', 'integers', 'dollars', 'mph', 'inches', 'hours', 'km', 'units... +``` + + + + + + + + + +```python +nemo_rl.environments.dapo_math_verifier.SUBSTITUTIONS = [('an ', ''), ('a ', ''), ('.$', '$'), ('\\$', ''), ('\\ ', ''), (' ', ''), ('mb... +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/interfaces.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/interfaces.mdx new file mode 100644 index 0000000..487b356 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/interfaces.mdx @@ -0,0 +1,151 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/interfaces +title: nemo_rl.environments.interfaces +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`EnvironmentInterface`](#nemo_rl-environments-interfaces-EnvironmentInterface) | - | +| [`EnvironmentReturn`](#nemo_rl-environments-interfaces-EnvironmentReturn) | Standard batched return type for environment step methods. | + +### Data + +[`MetadataT`](#nemo_rl-environments-interfaces-MetadataT) + +### API + + + + + +```python +class nemo_rl.environments.interfaces.EnvironmentInterface() +``` + + + + + + +Abstract + +**Bases:** `ABC`, `Generic[MetadataT]` + + + + + +```python +nemo_rl.environments.interfaces.EnvironmentInterface.global_post_process_and_metrics( + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict +) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict, dict] +``` + + + + + + +abstract + +Post processing function after all rollouts are done for the batch and returns metrics. + + + + + + + +```python +nemo_rl.environments.interfaces.EnvironmentInterface.step( + message_log_batch: list[nemo_rl.data.interfaces.LLMMessageLogType], + metadata: list[nemo_rl.environments.interfaces.MetadataT] +) -> nemo_rl.environments.interfaces.EnvironmentReturn[nemo_rl.environments.interfaces.MetadataT] +``` + + + + + + +abstract + +Runs a step in the environment. Allows for asynchrony with remote servers, but it's not required (this function is a ray remote). + +metadata: batch of whatever the environment needs to keep track of. I.e. + math solutions, code unit tests, or agent states. Can be None if episode terminated. + +Returns: +- EnvironmentReturn NamedTuple containing observations, metadata, next_stop_strings, rewards, and terminateds flags. + + + + + + + + + +```python +class nemo_rl.environments.interfaces.EnvironmentReturn() +``` + + + + + + +**Bases:** `NamedTuple`, `Generic[MetadataT]` + +Standard batched return type for environment step methods. + +**All elements are batched.** +observations: New observation from the environment. + It's a (batched) 'message' type, which is a dict + with keys 'role' and 'content'. +metadata: Updated metadata from the environment. +next_stop_strings: The stop strings for the next turn. + If your environment is a game or similar, + you may want to return a list of stop strings + that are valid actions for the next turn or + similar. This field lets you control this per turn. +rewards: the rewards for this turn. +terminateds: whether the episode ended this turn. +answers: the answers for this turn. + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.environments.interfaces.MetadataT = TypeVar('MetadataT') +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/math_environment.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/math_environment.mdx new file mode 100644 index 0000000..eccdd26 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/math_environment.mdx @@ -0,0 +1,356 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/math_environment +title: nemo_rl.environments.math_environment +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`EnglishMultichoiceVerifyWorker`](#nemo_rl-environments-math_environment-EnglishMultichoiceVerifyWorker) | - | +| [`HFVerifyWorker`](#nemo_rl-environments-math_environment-HFVerifyWorker) | - | +| [`MathEnvConfig`](#nemo_rl-environments-math_environment-MathEnvConfig) | - | +| [`MathEnvironment`](#nemo_rl-environments-math_environment-MathEnvironment) | - | +| [`MathEnvironmentMetadata`](#nemo_rl-environments-math_environment-MathEnvironmentMetadata) | - | +| [`MultilingualMultichoiceVerifyWorker`](#nemo_rl-environments-math_environment-MultilingualMultichoiceVerifyWorker) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_mute_output`](#nemo_rl-environments-math_environment-_mute_output) | - | + +### API + + + + + +```python +class nemo_rl.environments.math_environment.EnglishMultichoiceVerifyWorker() +``` + + + + + + + + + + +```python +nemo_rl.environments.math_environment.EnglishMultichoiceVerifyWorker.verify( + pred_responses: list[str], + ground_truths: list[str], + return_extracted_answer: bool = False, + kwargs = {} +) -> typing.Union[list[float], tuple[list[float], list[str | None]]] +``` + + + + + + +Verify the correctness of the predicted responses against the ground truth. + +**Parameters:** + + +list[str]. The predicted responses from the LLM. + + + +list[str]. The ground truth responses. + + +**Returns:** `Union[list[float], tuple[list[float], list[str | None]]]` + +Union[list[float], tuple[list[float], list[str | None]]]. + + + + + + + + + +```python +class nemo_rl.environments.math_environment.HFVerifyWorker() +``` + + + + + + + + + + + + +```python +nemo_rl.environments.math_environment.HFVerifyWorker.verify( + pred_responses: list[str], + ground_truths: list[str], + return_extracted_answer: bool = False, + kwargs = {} +) -> typing.Union[list[float], tuple[list[float], list[str | None]]] +``` + + + + + + +Verify the correctness of the predicted responses against the ground truth. + +**Parameters:** + + +list[str]. The predicted responses from the LLM. + + + +list[str]. The ground truth responses. + + +**Returns:** `Union[list[float], tuple[list[float], list[str | None]]]` + +Union[list[float], tuple[list[float], list[str | None]]]. + + + + + + + + + +```python +class nemo_rl.environments.math_environment.MathEnvConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.environments.math_environment.MathEnvironment( + cfg: nemo_rl.environments.math_environment.MathEnvConfig +) +``` + + + + + + +**Bases:** [EnvironmentInterface[MathEnvironmentMetadata]](/nemo-rl/nemo_rl/environments/interfaces#nemo_rl-environments-interfaces-EnvironmentInterface) + + + + + + + + + + +```python +nemo_rl.environments.math_environment.MathEnvironment.global_post_process_and_metrics( + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] +) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], dict[str, float | int]] +``` + + + + + + +Computes metrics for this environment given a global rollout batch. + +Every rank will run this function, so you're free to use distributed +calculations if you'd prefer for heavy metrics. + + + + + + + +```python +nemo_rl.environments.math_environment.MathEnvironment.shutdown() -> None +``` + + + + + + + + + + + + +```python +nemo_rl.environments.math_environment.MathEnvironment.step( + message_log_batch: list[nemo_rl.data.interfaces.LLMMessageLogType], + metadata: list[nemo_rl.environments.math_environment.MathEnvironmentMetadata], + return_extracted_answer: bool = False +) -> nemo_rl.environments.interfaces.EnvironmentReturn[nemo_rl.environments.math_environment.MathEnvironmentMetadata] +``` + + + + + + +Runs a step in the math environment. + +**Parameters:** + + +list[list[dict[str, str]]]. A batch of OpenAI-API-like message logs that represent interactions with the LLM. + + + +list[MathEnvironmentMetadata]. The grader will use the 'ground_truth' key to evaluate correctness. The extracted answer will be stored to caculate cons@k. + + +**Returns:** `EnvironmentReturn[MathEnvironmentMetadata]` + +A tuple containing: +- list[dict[str, str]]: Observations/responses batch +- list[dict]: Updated metadata +- list[str]: Next stop strings for the next turn +- Tensor: Rewards tensor +- Tensor: Done flags tensor + + + + + + + + + +```python +class nemo_rl.environments.math_environment.MathEnvironmentMetadata +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.environments.math_environment.MultilingualMultichoiceVerifyWorker() +``` + + + + + + + + + + +```python +nemo_rl.environments.math_environment.MultilingualMultichoiceVerifyWorker.verify( + pred_responses: list[str], + ground_truths: list[str], + return_extracted_answer: bool = False, + kwargs = {} +) -> typing.Union[list[float], tuple[list[float], list[str | None]]] +``` + + + + + + +Verify the correctness of the predicted responses against the ground truth. + +**Parameters:** + + +list[str]. The predicted responses from the LLM. + + + +list[str]. The ground truth responses. + + +**Returns:** `Union[list[float], tuple[list[float], list[str | None]]]` + +Union[list[float], tuple[list[float], list[str | None]]]. + + + + + + + + + +```python +nemo_rl.environments.math_environment._mute_output() +``` + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/metrics.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/metrics.mdx new file mode 100644 index 0000000..12386b2 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/metrics.mdx @@ -0,0 +1,42 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/metrics +title: nemo_rl.environments.metrics +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`calculate_pass_rate_per_prompt`](#nemo_rl-environments-metrics-calculate_pass_rate_per_prompt) | Function to compute fraction of prompts that have at least one correct answer (reward > 0). | + +### API + + + + + +```python +nemo_rl.environments.metrics.calculate_pass_rate_per_prompt( + prompts: torch.Tensor, + is_correct: torch.Tensor +) -> float +``` + + + + + + +Function to compute fraction of prompts that have at least one correct answer (reward > 0). + +prompts: tensor (b, s) Tensor of prompts the model used. May be on any device +is_correct: tensor (b,) bool-valued label. May be on any device + +Returns: +pass rate: float + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/nemo_gym.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/nemo_gym.mdx new file mode 100644 index 0000000..b8b85fe --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/nemo_gym.mdx @@ -0,0 +1,210 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/nemo_gym +title: nemo_rl.environments.nemo_gym +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`NemoGym`](#nemo_rl-environments-nemo_gym-NemoGym) | This environment class isn't really used for training. It's really meant as an integration wrapper around NeMo-Gym that hooks into the existing NeMo RL resource management via ray. So there is still one source of truth for resource management in NeMo RL. | +| [`NemoGymConfig`](#nemo_rl-environments-nemo_gym-NemoGymConfig) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`setup_nemo_gym_config`](#nemo_rl-environments-nemo_gym-setup_nemo_gym_config) | - | + +### API + + + + + +```python +class nemo_rl.environments.nemo_gym.NemoGym( + cfg: nemo_rl.environments.nemo_gym.NemoGymConfig +) +``` + + + + + + +**Bases:** [EnvironmentInterface](/nemo-rl/nemo_rl/environments/interfaces#nemo_rl-environments-interfaces-EnvironmentInterface) + +This environment class isn't really used for training. It's really meant as an integration wrapper around NeMo-Gym that hooks into the existing NeMo RL resource management via ray. So there is still one source of truth for resource management in NeMo RL. + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.environments.nemo_gym.NemoGym._postprocess_nemo_gym_to_nemo_rl_result( + nemo_gym_result: dict, + tokenizer: transformers.PreTrainedTokenizerBase +) -> dict +``` + + + + + + + + + + + + +```python +nemo_rl.environments.nemo_gym.NemoGym.global_post_process_and_metrics( + batch +) +``` + + + + + + + + + + + + +```python +nemo_rl.environments.nemo_gym.NemoGym.health_check() -> bool +``` + + + + + + + + + + + + +```python +nemo_rl.environments.nemo_gym.NemoGym.run_rollouts( + nemo_gym_examples: list[dict], + tokenizer: transformers.PreTrainedTokenizerBase, + timer_prefix: str +) -> list[dict] +``` + + + + + + +async + + + + + + + +```python +nemo_rl.environments.nemo_gym.NemoGym.shutdown() -> None +``` + + + + + + + + + + + + +```python +nemo_rl.environments.nemo_gym.NemoGym.step( + message_log_batch, + metadata +) +``` + + + + + + + + + + + + + + +```python +class nemo_rl.environments.nemo_gym.NemoGymConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.environments.nemo_gym.setup_nemo_gym_config( + config, + tokenizer +) -> None +``` + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/reward_model_environment.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/reward_model_environment.mdx new file mode 100644 index 0000000..67f23d3 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/reward_model_environment.mdx @@ -0,0 +1,276 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/reward_model_environment +title: nemo_rl.environments.reward_model_environment +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`RewardModelEnvironment`](#nemo_rl-environments-reward_model_environment-RewardModelEnvironment) | Environment that uses a reward model to score conversations. | +| [`RewardModelEnvironmentConfig`](#nemo_rl-environments-reward_model_environment-RewardModelEnvironmentConfig) | Configuration for RewardModelEnvironment. | + +### API + + + + + +```python +class nemo_rl.environments.reward_model_environment.RewardModelEnvironment( + config: typing.Dict[str, typing.Any] +) +``` + + + + + + +**Bases:** [EnvironmentInterface](/nemo-rl/nemo_rl/environments/interfaces#nemo_rl-environments-interfaces-EnvironmentInterface) + +Environment that uses a reward model to score conversations. + +This environment implements a reward model-based scoring system for reinforcement +learning tasks. It takes conversation logs as input and returns rewards based on +the quality of the assistant's responses as judged by a pre-trained reward model. + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.environments.reward_model_environment.RewardModelEnvironment.__del__() +``` + + + + + + +Destructor that ensures proper cleanup when the object is garbage collected. + +This is an extra safety net in case the user forgets to call shutdown() and +the pointer to the object is lost due to leaving a function scope. It's always +recommended that the user calls shutdown() explicitly for better resource +management. + + + + + + + +```python +nemo_rl.environments.reward_model_environment.RewardModelEnvironment.global_post_process_and_metrics( + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict +) -> typing.Tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict, dict] +``` + + + + + + +Post processing function after all rollouts are done for the batch and returns metrics. + +This method computes aggregate statistics and metrics from the processed batch. +It provides insights into reward distribution and processing statistics. + +**Parameters:** + + +The batch data dictionary containing processed conversations and rewards. + + +**Returns:** `BatchedDataDict` + +Tuple of (processed_batch, metrics_dict) where: + + + + + + + +```python +nemo_rl.environments.reward_model_environment.RewardModelEnvironment.preprocess_data( + message_logs: typing.List[nemo_rl.data.interfaces.LLMMessageLogType] +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec] +``` + + + + + + +Preprocess the message logs for the reward model. + +This method tokenizes and formats conversation logs into the format expected +by the reward model. It handles: +- Tokenization of user and assistant messages +- Formatting with proper special tokens +- Batching and padding for efficient processing +- Sequence length validation and truncation + +**Parameters:** + + +List of conversation message logs, where each log contains + a list of messages with 'role' and 'content' fields. + + +**Returns:** `BatchedDataDict[GenerationDatumSpec]` + +BatchedDataDict containing tokenized and formatted data ready for + + + + + + + +```python +nemo_rl.environments.reward_model_environment.RewardModelEnvironment.shutdown() +``` + + + + + + +Shutdown the reward model worker and virtual cluster. + +This method properly cleans up resources by shutting down the reward model +policy and virtual cluster. It should be called when the environment is +no longer needed to prevent resource leaks. + + + + + + + +```python +nemo_rl.environments.reward_model_environment.RewardModelEnvironment.step( + message_logs: typing.List[nemo_rl.data.interfaces.LLMMessageLogType], + env_infos: typing.List[typing.Dict[str, typing.Any]] +) -> nemo_rl.environments.interfaces.EnvironmentReturn +``` + + + + + + +Calculate rewards for the given message logs using the reward model. + +This method processes conversation logs through the reward model to compute +quality scores for each conversation. The rewards are based on the reward +model's assessment of how well the assistant's responses align with human +preferences. + +**Parameters:** + + +List of conversation message logs to be scored. + Each log should contain alternating user and assistant messages. + + + +List of environment info dictionaries (currently unused + but required by the interface). + + +**Returns:** `EnvironmentReturn` + +EnvironmentReturn containing: + + + + + + + + + +```python +class nemo_rl.environments.reward_model_environment.RewardModelEnvironmentConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Configuration for RewardModelEnvironment. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/rewards.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/rewards.mdx new file mode 100644 index 0000000..aaa41ba --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/rewards.mdx @@ -0,0 +1,180 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/rewards +title: nemo_rl.environments.rewards +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`bbox_giou_reward`](#nemo_rl-environments-rewards-bbox_giou_reward) | Given [x1, y1, x2, y2] normalized bounding box coordinates within the <{answer_tag}> tags, compute the GIoU between the ground truth and the response. | +| [`combine_reward_functions`](#nemo_rl-environments-rewards-combine_reward_functions) | Returns a callable function that takes (ground_truth, response) and collects multiple reward functions in sequence. | +| [`exact_answer_alphanumeric_reward`](#nemo_rl-environments-rewards-exact_answer_alphanumeric_reward) | Reward the agent when the answer within the <{answer_tag}> tags is the same as the ground truth (case-insensitive). | +| [`format_reward`](#nemo_rl-environments-rewards-format_reward) | Reward the agent when the response follows the format: (.*) <think> (.*) </think> <answer> (.*) </answer>. | +| [`math_expression_reward`](#nemo_rl-environments-rewards-math_expression_reward) | Reward the agent when the answer within the <{tag}> tags is the same expression as the ground truth. | + +### Data + +[`boxed`](#nemo_rl-environments-rewards-boxed) + +[`math_verify_func`](#nemo_rl-environments-rewards-math_verify_func) + +### API + + + + + +```python +nemo_rl.environments.rewards.bbox_giou_reward( + ground_truth: str, + response: str, + giou_penalty_thres: float = 10.0, + answer_tag: str = 'answer' +) -> tuple[float, bool] +``` + + + + + + +Given [x1, y1, x2, y2] normalized bounding box coordinates within the <{answer_tag}> tags, compute the GIoU between the ground truth and the response. + +The `answer_tag` is customizable and must be specified as part of the user COT prompt text file. + + + + + + + + +```python +nemo_rl.environments.rewards.combine_reward_functions( + reward_functions: list[tuple[typing.Callable[[str, str], tuple[float, bool]], float]] +) -> typing.Callable[[str, str], tuple[float, bool]] +``` + + + + + + +Returns a callable function that takes (ground_truth, response) and collects multiple reward functions in sequence. + +The reward functions are weighted by the second element of the tuple. +This information can be provided in the YAML config file and resolved in the VLMEnvironment class. + +**Parameters:** + + +list[tuple[Callable[[str, str], tuple[float, bool]], float]]. A list of reward functions and their weights. + + +**Returns:** `Callable[[str, str], tuple[float, bool]]` + +Callable[[str, str], tuple[float, bool]]: A callable function that takes (ground_truth, response) and collects multiple reward functions in sequence + + + + + + + + +```python +nemo_rl.environments.rewards.exact_answer_alphanumeric_reward( + ground_truth: str, + response: str, + answer_tag: str = 'answer' +) -> tuple[float, bool] +``` + + + + + + +Reward the agent when the answer within the <{answer_tag}> tags is the same as the ground truth (case-insensitive). + +The `answer_tag` is customizable and must be specified as part of the user COT prompt text file. + + + + + + + + +```python +nemo_rl.environments.rewards.format_reward( + ground_truth: str, + response: str, + think_tag: str = 'think', + answer_tag: str = 'answer' +) -> tuple[float, typing.Optional[bool]] +``` + + + + + + +Reward the agent when the response follows the format: (.*) <think> (.*) </think> <answer> (.*) </answer>. + +The `think_tag` and `answer_tag` are customizable and must be specified as part of the user COT prompt text file. + + + + + + + + +```python +nemo_rl.environments.rewards.math_expression_reward( + ground_truth: str, + response: str, + tag: str = 'answer' +) -> tuple[float, bool] +``` + + + + + + +Reward the agent when the answer within the <{tag}> tags is the same expression as the ground truth. + +The `tag` is customizable and must be specified as part of the user COT prompt text file. + + + + + + + + +```python +nemo_rl.environments.rewards.boxed = lambda x: '\\boxed{' + x + '}' if not x.startswith('\\boxed{') else x +``` + + + + + + + + + +```python +nemo_rl.environments.rewards.math_verify_func = math_metric(gold_extraction_target=(LatexExtractionConfig(),), pred_extraction_t... +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/utils.mdx new file mode 100644 index 0000000..67e63b2 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/utils.mdx @@ -0,0 +1,152 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/utils +title: nemo_rl.environments.utils +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`EnvRegistryEntry`](#nemo_rl-environments-utils-EnvRegistryEntry) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`chunk_list_to_workers`](#nemo_rl-environments-utils-chunk_list_to_workers) | Chunk a list into a list of lists, where each sublist is assigned to a worker. Keeps ordering of elements. | +| [`create_env`](#nemo_rl-environments-utils-create_env) | - | +| [`register_env`](#nemo_rl-environments-utils-register_env) | - | + +### Data + +[`ENV_REGISTRY`](#nemo_rl-environments-utils-ENV_REGISTRY) + +### API + + + + + +```python +class nemo_rl.environments.utils.EnvRegistryEntry +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +nemo_rl.environments.utils.chunk_list_to_workers( + to_chunk: list[typing.Any], + num_workers: int +) -> list[list[typing.Any]] +``` + + + + + + +Chunk a list into a list of lists, where each sublist is assigned to a worker. Keeps ordering of elements. + +If the list is not divisible by the number of workers, the last worker may have fewer elements. +If there are more workers than elements, the first len(list) workers will have a single element each, +and the remaining workers will have empty lists. + +Examples: + + +```python +>>> from nemo_rl.environments.utils import chunk_list_to_workers +>>> chunk_list_to_workers([1, 2, 3, 4, 5], 3) +[[1, 2], [3, 4], [5]] +``` + + + +**Parameters:** + + +The list to be chunked. + + + +The number of workers to distribute the list to. + + +**Returns:** `list[list[Any]]` + +A list of lists, where each sublist contains elements assigned to a worker. + + + + + + + + +```python +nemo_rl.environments.utils.create_env( + env_name: str, + env_config: dict +) -> nemo_rl.environments.interfaces.EnvironmentInterface +``` + + + + + + + + + + + + + +```python +nemo_rl.environments.utils.register_env( + env_name: str, + actor_class_fqn: str +) -> None +``` + + + + + + + + + + + + + +```python +nemo_rl.environments.utils.ENV_REGISTRY: Dict[str, EnvRegistryEntry] = {'math_default': {'actor_class_fqn': 'nemo_rl.environments.math_environment.Math... +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/vlm_environment.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/vlm_environment.mdx new file mode 100644 index 0000000..40b7a2a --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/vlm_environment.mdx @@ -0,0 +1,243 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/vlm_environment +title: nemo_rl.environments.vlm_environment +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`VLMEnvConfig`](#nemo_rl-environments-vlm_environment-VLMEnvConfig) | - | +| [`VLMEnvironment`](#nemo_rl-environments-vlm_environment-VLMEnvironment) | - | +| [`VLMEnvironmentMetadata`](#nemo_rl-environments-vlm_environment-VLMEnvironmentMetadata) | - | +| [`VLMVerifyWorker`](#nemo_rl-environments-vlm_environment-VLMVerifyWorker) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_mute_output`](#nemo_rl-environments-vlm_environment-_mute_output) | - | + +### API + + + + + +```python +class nemo_rl.environments.vlm_environment.VLMEnvConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.environments.vlm_environment.VLMEnvironment( + cfg: nemo_rl.environments.vlm_environment.VLMEnvConfig +) +``` + + + + + + +**Bases:** [EnvironmentInterface](/nemo-rl/nemo_rl/environments/interfaces#nemo_rl-environments-interfaces-EnvironmentInterface) + + + + + + + + + + +```python +nemo_rl.environments.vlm_environment.VLMEnvironment.global_post_process_and_metrics( + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] +) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], dict[str, float | int]] +``` + + + + + + +Computes metrics for this environment given a global rollout batch. + +Every rank will run this function, so you're free to use distributed +calculations if you'd prefer for heavy metrics. + + + + + + + +```python +nemo_rl.environments.vlm_environment.VLMEnvironment.shutdown() -> None +``` + + + + + + + + + + + + +```python +nemo_rl.environments.vlm_environment.VLMEnvironment.step( + message_log_batch: list[list[dict[str, str]]], + metadata: list[nemo_rl.environments.vlm_environment.VLMEnvironmentMetadata] +) -> nemo_rl.environments.interfaces.EnvironmentReturn +``` + + + + + + +Runs a step in the vlm environment. + +**Parameters:** + + +list[list[dict[str, str]]]. A batch of OpenAI-API-like message logs that represent interactions with the VLM. + + + +list[VLMEnvironmentMetadata]. The grader will use the 'ground_truth' key to evaluate correctness. + + +**Returns:** `EnvironmentReturn` + +A tuple containing: +- list[dict[str, str]]: Observations/responses batch +- list[dict]: Updated metadata +- list[str]: Next stop strings for the next turn +- Tensor: Rewards tensor +- Tensor: Done flags tensor + + + + + + + + + +```python +class nemo_rl.environments.vlm_environment.VLMEnvironmentMetadata +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +class nemo_rl.environments.vlm_environment.VLMVerifyWorker( + cfg: nemo_rl.environments.vlm_environment.VLMEnvConfig +) +``` + + + + + + + + + + + + +```python +nemo_rl.environments.vlm_environment.VLMVerifyWorker.verify( + pred_responses: list[str], + ground_truths: list[str] +) -> list[float] +``` + + + + + + +Verify the correctness of the predicted responses against the ground truth. + +**Parameters:** + + +list[str]. The predicted responses from the LLM. + + + +list[str]. The ground truth responses. + + +**Returns:** `list[float]` + +list[float]. The rewards for each predicted response. + + + + + + + + + +```python +nemo_rl.environments.vlm_environment._mute_output() +``` + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals.mdx new file mode 100644 index 0000000..e7d333c --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals.mdx @@ -0,0 +1,10 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/evals +title: nemo_rl.evals +--- + +## Submodules + +- **[`nemo_rl.evals.answer_parsing`](/nemo-rl/nemo_rl/evals/answer_parsing)** +- **[`nemo_rl.evals.eval`](/nemo-rl/nemo_rl/evals/eval)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals/answer_parsing.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals/answer_parsing.mdx new file mode 100644 index 0000000..9126f64 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals/answer_parsing.mdx @@ -0,0 +1,86 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/evals/answer_parsing +title: nemo_rl.evals.answer_parsing +--- + +Contains utility functions for answer parsing. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`normalize_extracted_answer`](#nemo_rl-evals-answer_parsing-normalize_extracted_answer) | - | +| [`normalize_response`](#nemo_rl-evals-answer_parsing-normalize_response) | Normalize the response by removing markdown and LaTeX formatting that may prevent a match. | + +### Data + +[`MULTILINGUAL_ANSWER_PATTERN_TEMPLATE`](#nemo_rl-evals-answer_parsing-MULTILINGUAL_ANSWER_PATTERN_TEMPLATE) + +[`MULTILINGUAL_ANSWER_REGEXES`](#nemo_rl-evals-answer_parsing-MULTILINGUAL_ANSWER_REGEXES) + +### API + + + + + +```python +nemo_rl.evals.answer_parsing.normalize_extracted_answer( + extracted_answer: str +) -> str +``` + + + + + + + + + + + + + +```python +nemo_rl.evals.answer_parsing.normalize_response( + response: str +) -> str +``` + + + + + + +Normalize the response by removing markdown and LaTeX formatting that may prevent a match. + + + + + + + + +```python +nemo_rl.evals.answer_parsing.MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = '(?i){}[ \t]*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])' +``` + + + + + + + + + +```python +nemo_rl.evals.answer_parsing.MULTILINGUAL_ANSWER_REGEXES = ['Answer\\s*:', 'Answer\\s*:\u200b\u200b\u200b\u200b\u200b\u200b', 'উত্তর\\s*:',... +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals/eval.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals/eval.mdx new file mode 100644 index 0000000..f2712a9 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals/eval.mdx @@ -0,0 +1,399 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/evals/eval +title: nemo_rl.evals.eval +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`EvalConfig`](#nemo_rl-evals-eval-EvalConfig) | - | +| [`MasterConfig`](#nemo_rl-evals-eval-MasterConfig) | - | +| [`_PassThroughMathConfig`](#nemo_rl-evals-eval-_PassThroughMathConfig) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_generate_texts`](#nemo_rl-evals-eval-_generate_texts) | Generate texts using either sync or async method. | +| [`_print_results`](#nemo_rl-evals-eval-_print_results) | Print evaluation results. | +| [`_run_env_eval_impl`](#nemo_rl-evals-eval-_run_env_eval_impl) | Unified implementation for both sync and async evaluation. | +| [`_save_evaluation_data_to_json`](#nemo_rl-evals-eval-_save_evaluation_data_to_json) | Save evaluation data to a JSON file. | +| [`eval_cons_k`](#nemo_rl-evals-eval-eval_cons_k) | Evaluate cons@k score using an unbiased estimator. | +| [`eval_pass_k`](#nemo_rl-evals-eval-eval_pass_k) | Evaluate pass@k score using an unbiased estimator. | +| [`run_env_eval`](#nemo_rl-evals-eval-run_env_eval) | Main entry point for running evaluation using environment. | +| [`setup`](#nemo_rl-evals-eval-setup) | Set up components for model evaluation. | + +### API + + + + + +```python +class nemo_rl.evals.eval.EvalConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.evals.eval.MasterConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.evals.eval._PassThroughMathConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +nemo_rl.evals.eval._generate_texts( + vllm_generation, + inputs, + use_async +) +``` + + + + + + +async + +Generate texts using either sync or async method. + + + + + + + + +```python +nemo_rl.evals.eval._print_results( + master_config, + generation_config, + score, + dataset_size, + metric, + k_value, + num_tests_per_prompt +) +``` + + + + + + +Print evaluation results. + + + + + + + + +```python +nemo_rl.evals.eval._run_env_eval_impl( + vllm_generation, + dataloader, + env, + master_config, + use_async = False +) +``` + + + + + + +async + +Unified implementation for both sync and async evaluation. + + + + + + + + +```python +nemo_rl.evals.eval._save_evaluation_data_to_json( + evaluation_data, + master_config, + save_path +) +``` + + + + + + +Save evaluation data to a JSON file. + +**Parameters:** + + +List of evaluation samples + + + +Configuration dictionary + + + +Path to save evaluation results. Set to null to disable saving. + Example: "results/eval_output" or "/path/to/evaluation_results" + + + + + + + + + +```python +nemo_rl.evals.eval.eval_cons_k( + rewards: torch.Tensor, + num_tests_per_prompt: int, + k: int, + extracted_answers: list[str | None] +) -> float +``` + + + + + + +Evaluate cons@k score using an unbiased estimator. + +**Parameters:** + + +Tensor of shape (batch_size * num_tests_per_prompt) + + + +int + + + +int + + + +list[str| None] + + +**Returns:** `float` + +float + + + + + + + + +```python +nemo_rl.evals.eval.eval_pass_k( + rewards: torch.Tensor, + num_tests_per_prompt: int, + k: int +) -> float +``` + + + + + + +Evaluate pass@k score using an unbiased estimator. + +Reference: https://github.com/huggingface/evaluate/blob/32546aafec25cdc2a5d7dd9f941fc5be56ba122f/metrics/code_eval/code_eval.py#L198-L213 +Args: + rewards: Tensor of shape (batch_size * num_tests_per_prompt) + k: int (pass@k value) + +**Returns:** `float` + +float + + + + + + + + +```python +nemo_rl.evals.eval.run_env_eval( + vllm_generation, + dataloader, + env, + master_config +) +``` + + + + + + +Main entry point for running evaluation using environment. + +Generates model responses and evaluates them by env. + +**Parameters:** + + +Model for generating responses. + + + +Data loader with evaluation samples. + + + +Environment that scores responses. + + + +Configuration settings. + + + + + + + + + +```python +nemo_rl.evals.eval.setup( + master_config: nemo_rl.evals.eval.MasterConfig, + tokenizer: transformers.AutoTokenizer, + dataset: nemo_rl.data.datasets.AllTaskProcessedDataset +) -> tuple[nemo_rl.models.generation.vllm.VllmGeneration, torch.utils.data.DataLoader, nemo_rl.evals.eval.MasterConfig] +``` + + + + + + +Set up components for model evaluation. + +Initializes the VLLM model and data loader. + +**Parameters:** + + +Configuration settings. + + + +Dataset to evaluate on. + + +**Returns:** `tuple[VllmGeneration, DataLoader, MasterConfig]` + +VLLM model, data loader, and config. + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/experience.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/experience.mdx new file mode 100644 index 0000000..eacff25 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/experience.mdx @@ -0,0 +1,9 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/experience +title: nemo_rl.experience +--- + +## Submodules + +- **[`nemo_rl.experience.rollouts`](/nemo-rl/nemo_rl/experience/rollouts)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/experience/rollouts.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/experience/rollouts.mdx new file mode 100644 index 0000000..f00431e --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/experience/rollouts.mdx @@ -0,0 +1,469 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/experience/rollouts +title: nemo_rl.experience.rollouts +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AsyncNemoGymRolloutResult`](#nemo_rl-experience-rollouts-AsyncNemoGymRolloutResult) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_calculate_single_metric`](#nemo_rl-experience-rollouts-_calculate_single_metric) | - | +| [`async_generate_response_for_sample_turn`](#nemo_rl-experience-rollouts-async_generate_response_for_sample_turn) | Generate a response for a single sample's turn using async generation. | +| [`calculate_rewards`](#nemo_rl-experience-rollouts-calculate_rewards) | Calculate rewards for generated responses and get environment feedback. | +| [`generate_responses`](#nemo_rl-experience-rollouts-generate_responses) | Generate responses from policy using synchronous generation. | +| [`generate_responses_async`](#nemo_rl-experience-rollouts-generate_responses_async) | Async version of generate_responses that properly calls generate_async. | +| [`run_async_multi_turn_rollout`](#nemo_rl-experience-rollouts-run_async_multi_turn_rollout) | Run multi-turn rollouts with sample-level processing. | +| [`run_async_nemo_gym_rollout`](#nemo_rl-experience-rollouts-run_async_nemo_gym_rollout) | Run multi-turn rollouts with NeMo-Gym. Please refer to the `run_async_multi_turn_rollout` docs for more information on the parameters. | +| [`run_multi_turn_rollout`](#nemo_rl-experience-rollouts-run_multi_turn_rollout) | Runs a multi-turn rollout loop, interacting with the environment. | +| [`run_sample_multi_turn_rollout`](#nemo_rl-experience-rollouts-run_sample_multi_turn_rollout) | Run a multi-turn rollout for a single sample. | + +### Data + +[`TokenizerType`](#nemo_rl-experience-rollouts-TokenizerType) + +### API + + + + + +```python +class nemo_rl.experience.rollouts.AsyncNemoGymRolloutResult( + input_ids: torch.Tensor, + final_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], + rollout_metrics: dict[str, typing.Any] +) +``` + + + + + + +Dataclass + + + + + + + + + + + + + + + +```python +nemo_rl.experience.rollouts._calculate_single_metric( + values: list[float], + batch_size: int, + key_name: str +) -> dict +``` + + + + + + + + + + + + + +```python +nemo_rl.experience.rollouts.async_generate_response_for_sample_turn( + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + sample_message_log: list[dict], + sample_stop_strings: list[str] | None, + tokenizer: nemo_rl.experience.rollouts.TokenizerType, + max_seq_len: int, + greedy: bool = False +) -> tuple[list[dict], torch.Tensor, torch.Tensor, dict[str, float]] +``` + + + + + + +async + +Generate a response for a single sample's turn using async generation. + +**Parameters:** + + +The generation interface to use + + + +Message log for a single sample + + + +Stop strings for this sample + + + +Tokenizer to use + + + +Maximum sequence length + + + +Whether to use greedy decoding + + +**Returns:** `tuple[list[dict], torch.Tensor, torch.Tensor, dict[str, float]]` + +Tuple of (updated_message_log, generated_tokens, input_lengths, generation_metrics) + + + + + + + + +```python +nemo_rl.experience.rollouts.calculate_rewards( + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], + task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface] +) -> nemo_rl.environments.interfaces.EnvironmentReturn +``` + + + + + + +Calculate rewards for generated responses and get environment feedback. + +**Parameters:** + + +Batch containing message_log (LLMMessageLogType) with generated responses + + + +Dictionary mapping task names to their corresponding environments + + +**Returns:** `EnvironmentReturn` + +EnvironmentReturn namedtuple containing: +- observations: List of observations from the environment for the next turn. +- metadata: List of extracted metadata from the environment. +- next_stop_strings: List of stop strings for the next generation step. +- rewards: Tensor of rewards for the last turn. +- terminateds: Tensor of booleans indicating if an episode ended naturally. + + + + + + + + +```python +nemo_rl.experience.rollouts.generate_responses( + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + generation_input_data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], + tokenizer: nemo_rl.experience.rollouts.TokenizerType, + input_lengths: torch.Tensor, + include_logprobs: bool = True, + greedy: bool = False +) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], list[torch.Tensor], dict[str, float | int]] +``` + + + + + + +Generate responses from policy using synchronous generation. + + + + + + + + +```python +nemo_rl.experience.rollouts.generate_responses_async( + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + generation_input_data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], + tokenizer: nemo_rl.experience.rollouts.TokenizerType, + input_lengths: torch.Tensor, + include_logprobs: bool = True, + greedy: bool = False +) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], list[torch.Tensor], dict[str, float | int]] +``` + + + + + + +async + +Async version of generate_responses that properly calls generate_async. + + + + + + + + +```python +nemo_rl.experience.rollouts.run_async_multi_turn_rollout( + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + input_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], + tokenizer: nemo_rl.experience.rollouts.TokenizerType, + task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], + max_seq_len: int, + max_rollout_turns: int = 999999, + greedy: bool = False +) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], dict[str, typing.Any]] +``` + + + + + + +Run multi-turn rollouts with sample-level processing. + +Each sample in the batch proceeds through its interaction independently. +Async generation is used internally when available but the function is synchronous. + +**Parameters:** + + +The generation interface (policy) + + + +The starting batch containing initial message logs + + + +The tokenizer + + + +Dictionary mapping task names to environment instances + + + +Maximum sequence length allowed + + + +Maximum number of agent-environment interaction turns + + + +Whether to use greedy decoding + + +**Returns:** `tuple[BatchedDataDict[DatumSpec], dict[str, Any]]` + +Tuple containing: +- BatchedDataDict with the full interaction history and accumulated rewards +- Dictionary of rollout metrics + + + + + + + + +```python +nemo_rl.experience.rollouts.run_async_nemo_gym_rollout( + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + input_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], + tokenizer: nemo_rl.experience.rollouts.TokenizerType, + task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], + generation_config: nemo_rl.models.generation.interfaces.GenerationConfig, + max_seq_len: typing.Optional[int] = None, + max_rollout_turns: typing.Optional[int] = None, + greedy: bool = False +) -> nemo_rl.experience.rollouts.AsyncNemoGymRolloutResult +``` + + + + + + +Run multi-turn rollouts with NeMo-Gym. Please refer to the `run_async_multi_turn_rollout` docs for more information on the parameters. + + + + + + + + +```python +nemo_rl.experience.rollouts.run_multi_turn_rollout( + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + input_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], + tokenizer: nemo_rl.experience.rollouts.TokenizerType, + task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], + max_seq_len: int, + max_rollout_turns: int = 999999, + greedy: bool = False +) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], dict[str, typing.Any]] +``` + + + + + + +Runs a multi-turn rollout loop, interacting with the environment. + +**Parameters:** + + +The generation interface (policy). + + + +The starting batch containing initial message logs. + + + +The tokenizer. + + + +Dictionary mapping task names to environment instances. + + + +Maximum number of agent-environment interaction turns. + + + +Maximum sequence length allowed. + + + +Whether to use greedy decoding. + + +**Returns:** `tuple[BatchedDataDict[DatumSpec], dict[str, Any]]` + +Tuple containing: +- BatchedDataDict with the full interaction history and accumulated rewards +- Dictionary of rollout metrics + + + + + + + + +```python +nemo_rl.experience.rollouts.run_sample_multi_turn_rollout( + sample_idx: int, + initial_sample_state: dict, + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + tokenizer: nemo_rl.experience.rollouts.TokenizerType, + task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], + max_seq_len: int, + max_rollout_turns: int = 999999, + greedy: bool = False +) -> tuple[dict, dict[str, typing.Any]] +``` + + + + + + +async + +Run a multi-turn rollout for a single sample. + +This function manages the complete lifecycle of one sample's interaction. +Async generation is used internally when available. + +**Parameters:** + + +Index of this sample in the original batch + + + +Initial state containing message_log, extra_env_info, etc. + + + +The generation interface + + + +Tokenizer to use + + + +Environment mapping + + + +Maximum sequence length + + + +Maximum number of turns + + + +Whether to use greedy decoding + + +**Returns:** `tuple[dict, dict[str, Any]]` + +Tuple of (final_sample_state, sample_metrics) + + + + + + + + +```python +nemo_rl.experience.rollouts.TokenizerType = PreTrainedTokenizerBase +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models.mdx new file mode 100644 index 0000000..78aeea0 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models.mdx @@ -0,0 +1,14 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models +title: nemo_rl.models +--- + +## Subpackages + +- **[`nemo_rl.models.automodel`](/nemo-rl/nemo_rl/models/automodel)** +- **[`nemo_rl.models.dtensor`](/nemo-rl/nemo_rl/models/dtensor)** +- **[`nemo_rl.models.generation`](/nemo-rl/nemo_rl/models/generation)** +- **[`nemo_rl.models.huggingface`](/nemo-rl/nemo_rl/models/huggingface)** +- **[`nemo_rl.models.megatron`](/nemo-rl/nemo_rl/models/megatron)** +- **[`nemo_rl.models.policy`](/nemo-rl/nemo_rl/models/policy)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel.mdx new file mode 100644 index 0000000..a8d957d --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel.mdx @@ -0,0 +1,12 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/automodel +title: nemo_rl.models.automodel +--- + +## Submodules + +- **[`nemo_rl.models.automodel.config`](/nemo-rl/nemo_rl/models/automodel/config)** +- **[`nemo_rl.models.automodel.data`](/nemo-rl/nemo_rl/models/automodel/data)** +- **[`nemo_rl.models.automodel.setup`](/nemo-rl/nemo_rl/models/automodel/setup)** +- **[`nemo_rl.models.automodel.train`](/nemo-rl/nemo_rl/models/automodel/train)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/config.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/config.mdx new file mode 100644 index 0000000..c409b87 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/config.mdx @@ -0,0 +1,125 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/automodel/config +title: nemo_rl.models.automodel.config +--- + +Configuration classes for automodel-based training in NeMo RL. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ModelAndOptimizerState`](#nemo_rl-models-automodel-config-ModelAndOptimizerState) | Container for model and optimizer state. | +| [`RuntimeConfig`](#nemo_rl-models-automodel-config-RuntimeConfig) | Runtime configuration for model training and inference. | + +### API + + + + + +```python +class nemo_rl.models.automodel.config.ModelAndOptimizerState() +``` + + + + + + +**Bases:** `NamedTuple` + +Container for model and optimizer state. + +This named tuple holds all model-related state including the model itself, +optimizer, scheduler, and metadata about the model type and configuration. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.automodel.config.RuntimeConfig() +``` + + + + + + +**Bases:** `NamedTuple` + +Runtime configuration for model training and inference. + +This contains all validated runtime settings needed for model initialization, +parallelization, and training. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/data.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/data.mdx new file mode 100644 index 0000000..f777494 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/data.mdx @@ -0,0 +1,374 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/automodel/data +title: nemo_rl.models.automodel.data +--- + +Data processing utilities for automodel training and inference. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ProcessedInputs`](#nemo_rl-models-automodel-data-ProcessedInputs) | Processed microbatch inputs ready for model forward pass. | +| [`ProcessedMicrobatch`](#nemo_rl-models-automodel-data-ProcessedMicrobatch) | Container for a processed microbatch ready for model forward pass. | + +### Functions + +| Name | Description | +|------|-------------| +| [`check_sequence_dim`](#nemo_rl-models-automodel-data-check_sequence_dim) | Check and validate sequence dimension across all tensors. | +| [`get_microbatch_iterator`](#nemo_rl-models-automodel-data-get_microbatch_iterator) | Create processed microbatch iterator based on batching strategy. | +| [`make_processed_microbatch_iterator`](#nemo_rl-models-automodel-data-make_processed_microbatch_iterator) | Wrap a raw microbatch iterator to yield processed microbatches. | +| [`process_global_batch`](#nemo_rl-models-automodel-data-process_global_batch) | Process a global batch and compute normalization factors. | +| [`process_microbatch`](#nemo_rl-models-automodel-data-process_microbatch) | Process a microbatch and prepare inputs for model forward. | + +### API + + + + + +```python +class nemo_rl.models.automodel.data.ProcessedInputs( + input_ids: torch.Tensor, + seq_len: int, + attention_mask: typing.Optional[torch.Tensor] = None, + position_ids: typing.Optional[torch.Tensor] = None, + flash_attn_kwargs: dict[str, typing.Any] = dict(), + vlm_kwargs: dict[str, typing.Any] = dict(), + cp_buffers: list[torch.Tensor] = list(), + seq_index: typing.Optional[torch.Tensor] = None +) +``` + + + + + + +Dataclass + +Processed microbatch inputs ready for model forward pass. + +This structure contains all necessary tensors and metadata for a forward pass, +including context parallel buffers and flash attention configuration. + + + + + + + + + + + + +Check if context parallel is enabled. + + + +Check if flash attention is configured. + +Works for both empty dict {} and dataclass objects like FlashAttnKwargs. + + + + + + +Check if this is a multimodal input. + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.automodel.data.ProcessedMicrobatch( + data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs, + original_batch_size: int, + original_seq_len: int +) +``` + + + + + + +Dataclass + +Container for a processed microbatch ready for model forward pass. + +This dataclass holds both the original data dictionary and the processed +tensors needed for the automodel forward pass. It follows the same pattern +as nemo_rl/models/megatron/data.py ProcessedMicrobatch. + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.models.automodel.data.check_sequence_dim( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] +) -> typing.Tuple[int, int] +``` + + + + + + +Check and validate sequence dimension across all tensors. + +Verifies that dimension 1 is the sequence dimension for all tensors +in the data dictionary that have more than one dimension. + +**Parameters:** + + +BatchedDataDict to validate + + +**Returns:** `Tuple[int, int]` + +Tuple of (sequence_dim, seq_dim_size) + +**Raises:** + +- `AssertionError`: If any tensor has inconsistent sequence dimension + + + + + + + + +```python +nemo_rl.models.automodel.data.get_microbatch_iterator( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + cfg: dict[str, typing.Any], + mbs: int, + dp_mesh: typing.Any, + tokenizer: transformers.AutoTokenizer, + cp_size: int = 1 +) -> tuple[typing.Iterator[nemo_rl.models.automodel.data.ProcessedMicrobatch], int] +``` + + + + + + +Create processed microbatch iterator based on batching strategy. + +**Parameters:** + + +Full dataset to iterate over + + + +Configuration dictionary (enable_seq_packing is inferred from cfg["sequence_packing"]["enabled"]) + + + +Microbatch size + + + +Data parallel mesh + + + +Tokenizer for processing + + + +Context parallel size + + +**Returns:** `tuple[Iterator[ProcessedMicrobatch], int]` + +Tuple of (processed_microbatch_iterator, iterator_length) + + + + + + + + +```python +nemo_rl.models.automodel.data.make_processed_microbatch_iterator( + raw_iterator: typing.Iterator[nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any]], + tokenizer: transformers.AutoTokenizer, + cfg: dict[str, typing.Any], + cp_size: int +) -> typing.Iterator[nemo_rl.models.automodel.data.ProcessedMicrobatch] +``` + + + + + + +Wrap a raw microbatch iterator to yield processed microbatches. + +This function takes a raw iterator that yields BatchedDataDict objects and +wraps it to yield ProcessedMicrobatch objects that contain both the original +data and the processed tensors ready for model forward pass. + +**Parameters:** + + +Iterator yielding raw BatchedDataDict microbatches + + + +Tokenizer for processing + + + +Configuration dictionary (enable_seq_packing is inferred from cfg["sequence_packing"]["enabled"]) + + + +Context parallel size + + + + + + + + + +```python +nemo_rl.models.automodel.data.process_global_batch( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + dp_group: torch.distributed.ProcessGroup, + batch_idx: int, + batch_size: int +) -> dict[str, typing.Any] +``` + + + + + + +Process a global batch and compute normalization factors. + +**Parameters:** + + +Full dataset + + + +Loss function (used to check loss type) + + + +Data parallel process group (for consistency with Megatron naming) + + + +Index of batch to extract + + + +Size of batch to extract + + +**Returns:** `dict[str, Any]` + +Dictionary containing: + + + + + + + + +```python +nemo_rl.models.automodel.data.process_microbatch( + mb: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + tokenizer: transformers.AutoTokenizer, + enable_seq_packing: bool, + cfg: dict[str, typing.Any], + cp_size: int +) -> nemo_rl.models.automodel.data.ProcessedInputs +``` + + + + + + +Process a microbatch and prepare inputs for model forward. + +**Parameters:** + + +Microbatch data + + + +Tokenizer for padding value + + + +Whether sequence packing is enabled + + + +Configuration dictionary + + + +Context parallel size + + +**Returns:** `ProcessedInputs` + +ProcessedInputs containing all tensors and metadata for forward pass + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/setup.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/setup.mdx new file mode 100644 index 0000000..2b2a115 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/setup.mdx @@ -0,0 +1,229 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/automodel/setup +title: nemo_rl.models.automodel.setup +--- + +Setup utilities for automodel-based training in NeMo RL. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`setup_distributed`](#nemo_rl-models-automodel-setup-setup_distributed) | Set up distributed training environment and create FSDP2Manager. | +| [`setup_model_and_optimizer`](#nemo_rl-models-automodel-setup-setup_model_and_optimizer) | Set up model, parallelization, and optimizer. | +| [`setup_reference_model_state`](#nemo_rl-models-automodel-setup-setup_reference_model_state) | Set up reference model state dict by creating a CPU copy of the model's state dict. | +| [`validate_and_prepare_config`](#nemo_rl-models-automodel-setup-validate_and_prepare_config) | Validate configuration and prepare runtime settings. | + +### Data + +[`STRING_TO_DTYPE`](#nemo_rl-models-automodel-setup-STRING_TO_DTYPE) + +### API + + + + + +```python +nemo_rl.models.automodel.setup.setup_distributed( + config: nemo_rl.models.policy.PolicyConfig, + runtime_config: nemo_rl.models.automodel.config.RuntimeConfig +) -> nemo_automodel.components.distributed.fsdp2.FSDP2Manager +``` + + + + + + +Set up distributed training environment and create FSDP2Manager. + +Initializes torch.distributed process group and creates an FSDP2Manager +with the appropriate parallelization and precision settings. + +**Parameters:** + + +Policy configuration dictionary + + + +RuntimeConfig named tuple from validate_and_prepare_config + + +**Returns:** `FSDP2Manager` + +FSDP2Manager instance with all distributed configuration + + + + + + + + +```python +nemo_rl.models.automodel.setup.setup_model_and_optimizer( + config: nemo_rl.models.policy.PolicyConfig, + tokenizer: transformers.AutoTokenizer, + runtime_config: nemo_rl.models.automodel.config.RuntimeConfig, + distributed_manager: nemo_automodel.components.distributed.fsdp2.FSDP2Manager, + checkpoint_manager: typing.Any, + is_vlm: bool = False, + init_optimizer: bool = True, + weights_path: typing.Optional[str] = None, + optimizer_path: typing.Optional[str] = None +) -> nemo_rl.models.automodel.config.ModelAndOptimizerState +``` + + + + + + +Set up model, parallelization, and optimizer. + +Creates the model from config, applies parallelization strategies (FSDP2, TP, CP), +loads base weights, and optionally initializes optimizer and scheduler. + +**Parameters:** + + +Policy configuration dictionary + + + +Tokenizer for the model + + + +RuntimeConfig named tuple from validate_and_prepare_config + + + +FSDP2Manager from setup_distributed + + + +Checkpoint manager for loading/saving weights + + + +Whether this is a vision-language model + + + +Whether to initialize optimizer + + + +Optional path to checkpoint weights to load + + + +Optional path to optimizer state to load + + +**Returns:** `ModelAndOptimizerState` + +ModelAndOptimizerState containing model, optimizer, scheduler, and metadata + + + + + + + + +```python +nemo_rl.models.automodel.setup.setup_reference_model_state( + model: torch.nn.Module +) -> dict[str, torch.Tensor] +``` + + + + + + +Set up reference model state dict by creating a CPU copy of the model's state dict. + +This creates a reference copy of the model weights on CPU with pinned memory +for efficient CPU-GPU transfers. The reference model is typically used to +compute reference log probabilities during RL training. + +**Parameters:** + + +The model to create a reference copy from + + +**Returns:** `dict[str, torch.Tensor]` + +Dictionary mapping parameter names to CPU tensors with pinned memory + + + + + + + + +```python +nemo_rl.models.automodel.setup.validate_and_prepare_config( + config: nemo_rl.models.policy.PolicyConfig, + processor: typing.Optional[transformers.AutoProcessor], + rank: int +) -> nemo_rl.models.automodel.config.RuntimeConfig +``` + + + + + + +Validate configuration and prepare runtime settings. + +This function validates the policy configuration, sets environment variables, +determines model configuration, and returns runtime settings as a named tuple. + +**Parameters:** + + +Policy configuration dictionary + + + +Optional processor for multimodal models + + + +Current process rank + + +**Returns:** `RuntimeConfig` + +RuntimeConfig named tuple containing validated configuration values + +**Raises:** + +- `ValueError`: If configuration is invalid +- `RuntimeError`: If incompatible settings are detected + + + + + + + + +```python +nemo_rl.models.automodel.setup.STRING_TO_DTYPE = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16} +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/train.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/train.mdx new file mode 100644 index 0000000..7126780 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/train.mdx @@ -0,0 +1,841 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/automodel/train +title: nemo_rl.models.automodel.train +--- + +Training utilities for automodel (DTensor-based) policy workers. + +This module provides post-processor classes and forward/backward functions +that follow the same pattern as nemo_rl/models/megatron/train.py. + +Key differences from megatron approach: +- Post-processors compute results directly (no callable return pattern) +- forward_with_post_processing_fn calls post-processor directly +- automodel_forward_backward uses PyTorch autograd instead of Megatron's pipeline + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`LogprobsPostProcessor`](#nemo_rl-models-automodel-train-LogprobsPostProcessor) | Post-processor for computing log probabilities from model outputs. | +| [`LossPostProcessor`](#nemo_rl-models-automodel-train-LossPostProcessor) | Post-processor for computing training loss from model outputs. | +| [`ScorePostProcessor`](#nemo_rl-models-automodel-train-ScorePostProcessor) | Post-processor for computing reward model scores from model outputs. | +| [`TopkLogitsPostProcessor`](#nemo_rl-models-automodel-train-TopkLogitsPostProcessor) | Post-processor for computing top-k logits from model outputs. | + +### Functions + +| Name | Description | +|------|-------------| +| [`aggregate_training_statistics`](#nemo_rl-models-automodel-train-aggregate_training_statistics) | Aggregate training statistics across microbatches and ranks. | +| [`apply_temperature_scaling`](#nemo_rl-models-automodel-train-apply_temperature_scaling) | Apply temperature scaling to logits. | +| [`automodel_forward_backward`](#nemo_rl-models-automodel-train-automodel_forward_backward) | Execute forward and backward passes for automodel. | +| [`extract_logits`](#nemo_rl-models-automodel-train-extract_logits) | Extract logits from model outputs. | +| [`forward_with_post_processing_fn`](#nemo_rl-models-automodel-train-forward_with_post_processing_fn) | Perform forward pass with pre-processed microbatch and apply post-processing. | +| [`model_forward`](#nemo_rl-models-automodel-train-model_forward) | Perform a single forward pass through the model. | +| [`prepare_data_for_cp`](#nemo_rl-models-automodel-train-prepare_data_for_cp) | Prepare data for context parallel processing. | +| [`redistribute_logits_for_cp`](#nemo_rl-models-automodel-train-redistribute_logits_for_cp) | Redistribute logits for context parallel processing. | + +### Data + +[`PostProcessingFunction`](#nemo_rl-models-automodel-train-PostProcessingFunction) + +### API + + + + + +```python +class nemo_rl.models.automodel.train.LogprobsPostProcessor( + cfg: nemo_rl.models.policy.PolicyConfig, + device_mesh: typing.Any, + cp_mesh: typing.Any, + tp_mesh: typing.Any, + cp_size: int, + enable_seq_packing: bool = False +) +``` + + + + + + +Post-processor for computing log probabilities from model outputs. + + + + + + + + +```python +nemo_rl.models.automodel.train.LogprobsPostProcessor.__call__( + logits: torch.Tensor, + processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs, + input_lengths: torch.Tensor, + original_batch_size: int, + original_seq_len: int, + sequence_dim: int = 1 +) -> torch.Tensor +``` + + + + + + +Compute token log probabilities from logits. + +**Parameters:** + + +Model output logits + + + +Processed inputs + + + +Sequence lengths + + + +Original batch size before packing + + + +Original sequence length before packing + + + +Sequence dimension + + +**Returns:** `torch.Tensor` + +Token log probabilities tensor [batch_size, seq_length] + + + + + + + +```python +nemo_rl.models.automodel.train.LogprobsPostProcessor._compute_local_logprobs( + logits: torch.Tensor, + input_ids: torch.Tensor +) -> torch.Tensor +``` + + + + + + +Compute logprobs locally without distributed processing. + +**Parameters:** + + +Model output logits + + + +Input token IDs + + +**Returns:** `torch.Tensor` + +Token log probabilities + + + + + + + + + +```python +class nemo_rl.models.automodel.train.LossPostProcessor( + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + cfg: nemo_rl.models.policy.PolicyConfig, + device_mesh: typing.Any, + cp_mesh: typing.Any, + tp_mesh: typing.Any, + cp_size: int, + dp_size: int, + enable_seq_packing: bool = False +) +``` + + + + + + +Post-processor for computing training loss from model outputs. + + + + + + +```python +nemo_rl.models.automodel.train.LossPostProcessor.__call__( + logits: torch.Tensor, + mb: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs, + global_valid_seqs: torch.Tensor, + global_valid_toks: torch.Tensor, + sequence_dim: int = 1 +) -> tuple[torch.Tensor, dict[str, typing.Any]] +``` + + + + + + +Compute loss from logits. + +**Parameters:** + + +Model output logits + + + +Microbatch data + + + +Processed inputs + + + +Global valid sequence count + + + +Global valid token count + + + +Sequence dimension + + +**Returns:** `tuple[torch.Tensor, dict[str, Any]]` + +Tuple of (loss, metrics) + + + + + + + + + +```python +class nemo_rl.models.automodel.train.ScorePostProcessor( + cfg: nemo_rl.models.policy.PolicyConfig +) +``` + + + + + + +Post-processor for computing reward model scores from model outputs. + + + + + + +```python +nemo_rl.models.automodel.train.ScorePostProcessor.__call__( + logits: torch.Tensor +) -> torch.Tensor +``` + + + + + + +Extract scores from reward model outputs. + +**Parameters:** + + +Model output logits + + +**Returns:** `torch.Tensor` + +Scores tensor + + + + + + + + + +```python +class nemo_rl.models.automodel.train.TopkLogitsPostProcessor( + cfg: nemo_rl.models.policy.PolicyConfig, + device_mesh: typing.Any, + cp_mesh: typing.Any, + tp_mesh: typing.Any, + cp_size: int, + k: int, + enable_seq_packing: bool = False +) +``` + + + + + + +Post-processor for computing top-k logits from model outputs. + + + + + + +```python +nemo_rl.models.automodel.train.TopkLogitsPostProcessor.__call__( + logits: torch.Tensor, + processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs, + input_lengths: torch.Tensor, + original_batch_size: int, + original_seq_len: int, + sequence_dim: int = 1 +) -> tuple[torch.Tensor, torch.Tensor] +``` + + + + + + +Compute top-k logits and indices from model outputs. + +**Parameters:** + + +Model output logits + + + +Processed inputs + + + +Sequence lengths + + + +Original batch size before packing + + + +Original sequence length before packing + + + +Sequence dimension + + +**Returns:** `tuple[torch.Tensor, torch.Tensor]` + +Tuple of (top-k values, top-k indices) tensors + + + + + + + + + +```python +nemo_rl.models.automodel.train.aggregate_training_statistics( + losses: list[float], + all_mb_metrics: list[dict[str, typing.Any]], + grad_norm: typing.Optional[torch.Tensor], + dp_group: typing.Any, + dtype: torch.dtype +) -> dict[str, typing.Any] +``` + + + + + + +Aggregate training statistics across microbatches and ranks. + +**Parameters:** + + +List of loss values from each microbatch + + + +List of metrics dictionaries from each microbatch + + + +Gradient norm tensor (or None if eval mode) + + + +Data parallel process group for all-reduce + + + +Model dtype for metrics + + +**Returns:** `dict[str, Any]` + +Dictionary containing aggregated metrics including global_loss, grad_norm, etc. + + + + + + + + +```python +nemo_rl.models.automodel.train.apply_temperature_scaling( + logits: torch.Tensor, + cfg: nemo_rl.models.policy.PolicyConfig +) -> torch.Tensor +``` + + + + + + +Apply temperature scaling to logits. + +**Parameters:** + + +Logits tensor to scale + + + +Configuration dictionary containing generation settings + + +**Returns:** `torch.Tensor` + +torch.Tensor: Temperature-scaled logits + + + + + + + + +```python +nemo_rl.models.automodel.train.automodel_forward_backward( + model: torch.nn.Module, + cfg: nemo_rl.models.policy.PolicyConfig, + data_iterator: typing.Iterator[nemo_rl.models.automodel.data.ProcessedMicrobatch], + post_processing_fn: nemo_rl.models.automodel.train.PostProcessingFunction, + forward_only: bool = False, + is_reward_model: bool = False, + allow_flash_attn_args: bool = True, + global_valid_seqs: typing.Optional[torch.Tensor] = None, + global_valid_toks: typing.Optional[torch.Tensor] = None, + sequence_dim: int = 1, + dp_size: int = 1, + cp_size: int = 1, + num_global_batches: int = 1, + train_context_fn: typing.Optional[typing.Callable[[ProcessedInputs], typing.Any]] = None, + num_valid_microbatches: typing.Optional[int] = None, + on_microbatch_start: typing.Optional[typing.Callable[[int], None]] = None +) -> list[typing.Tuple[typing.Any, dict[str, typing.Any]]] +``` + + + + + + +Execute forward and backward passes for automodel. + +This is the main training loop function that coordinates forward and backward +passes across multiple microbatches using PyTorch autograd. + +Unlike megatron_forward_backward which uses Megatron's pipeline parallel +framework, this uses standard PyTorch operations. + +**Parameters:** + + +The model to train + + + +Configuration dictionary + + + +Iterator yielding ProcessedMicrobatch objects (already processed) + + + +Number of microbatches to process + + + +Post-processing function to apply to the logits + + + +If True, skip backward pass + + + +Whether this is a reward model + + + +Whether to pass flash_attn_kwargs to model + + + +Global valid sequence count for loss normalization + + + +Global valid token count for loss normalization + + + +Sequence dimension + + + +Data parallel size + + + +Context parallel size + + + +Number of global batches (for metric scaling) + + + +Optional callable that takes ProcessedInputs and returns +a context manager for the forward/backward pass. If None, no context is used. + + + +Number of valid (non-dummy) microbatches. If provided, +microbatches beyond this index are treated as dummy batches (loss *= 0). +If None, all microbatches are considered valid. + + + +Optional callback called at the start of each microbatch +with the microbatch index. Useful for cache clearing, etc. + + +**Returns:** `list[Tuple[Any, dict[str, Any]]]` + +List of (result, metrics) tuples from each microbatch + + + + + + + + +```python +nemo_rl.models.automodel.train.extract_logits( + model: torch.nn.Module, + outputs: typing.Any +) -> torch.Tensor +``` + + + + + + +Extract logits from model outputs. + +**Parameters:** + + +The model (used for lm_head if needed) + + + +Model outputs (can be tensor, DTensor, or object with logits attribute) + + +**Returns:** `torch.Tensor` + +torch.Tensor: Logits tensor + + + + + + + + +```python +nemo_rl.models.automodel.train.forward_with_post_processing_fn( + model: torch.nn.Module, + cfg: nemo_rl.models.policy.PolicyConfig, + post_processing_fn: nemo_rl.models.automodel.train.PostProcessingFunction, + processed_mb: nemo_rl.models.automodel.data.ProcessedMicrobatch, + is_reward_model: bool = False, + allow_flash_attn_args: bool = True, + global_valid_seqs: typing.Optional[torch.Tensor] = None, + global_valid_toks: typing.Optional[torch.Tensor] = None, + sequence_dim: int = 1 +) -> typing.Tuple[typing.Any, dict[str, typing.Any], nemo_rl.models.automodel.data.ProcessedMicrobatch] +``` + + + + + + +Perform forward pass with pre-processed microbatch and apply post-processing. + +This function takes a pre-processed microbatch (with sequence packing already handled), +runs the forward step through the model, and applies the post-processing function +to compute the result. + +Unlike the megatron approach which returns a callable, this directly computes +and returns the result since automodel uses PyTorch autograd. + +**Parameters:** + + +The model to run forward pass on + + + +Configuration dictionary + + + +Post-processing function to apply to the logits + + + +Pre-fetched ProcessedMicrobatch containing data and processed inputs + + + +Whether this is a reward model + + + +Whether to pass flash_attn_kwargs to model + + + +Global valid sequence count for loss normalization + + + +Global valid token count for loss normalization + + + +Sequence dimension + + +**Returns:** `Tuple[Any, dict[str, Any], ProcessedMicrobatch]` + +(result, metrics, processed_microbatch) +- result: Output from post-processing (loss, logprobs, topk, or scores) +- metrics: Dictionary of metrics from post-processing +- processed_microbatch: The ProcessedMicrobatch that was processed + + + + + + + + +```python +nemo_rl.models.automodel.train.model_forward( + model: torch.nn.Module, + processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs, + is_reward_model: bool = False, + allow_flash_attn_args: bool = True +) -> torch.Tensor +``` + + + + + + +Perform a single forward pass through the model. + +**Parameters:** + + +The model to run forward pass on + + + +ProcessedInputs containing all tensors for forward pass + + + +Whether this is a reward model + + + +Whether to pass flash_attn_kwargs to model + + +**Returns:** `torch.Tensor` + +torch.Tensor: Output tensor from the model (logits) + + + + + + + + +```python +nemo_rl.models.automodel.train.prepare_data_for_cp( + mb: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs, + cp_mesh: typing.Any, + sequence_dim: int = 1 +) -> tuple[torch.Tensor, nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any]] +``` + + + + + + +Prepare data for context parallel processing. + +Converts seq_index to full tensor and wraps CP-sharded tensors in DTensor. + +**Parameters:** + + +Microbatch data dictionary + + + +Processed inputs containing CP buffers + + + +Context parallel mesh + + + +Dimension for sequence sharding + + +**Returns:** `tuple[torch.Tensor, BatchedDataDict[Any]]` + +Tuple of (seq_index_dtensor, updated_mb) + + + + + + + + +```python +nemo_rl.models.automodel.train.redistribute_logits_for_cp( + logits: torch.Tensor, + device_mesh: typing.Any, + cp_mesh: typing.Any, + sequence_dim: int = 1 +) -> torch.distributed.tensor.DTensor +``` + + + + + + +Redistribute logits for context parallel processing. + +Handles the case where logits may be TP-sharded DTensor or regular tensor, +and converts them to CP+TP sharded DTensor. + +**Parameters:** + + +Logits tensor (may be DTensor or regular tensor) + + + +Full device mesh + + + +Context parallel mesh (kept for signature compatibility) + + + +Dimension for sequence sharding + + +**Returns:** `DTensor` + +DTensor sharded on both CP and TP dimensions + + + + + + + + +```python +nemo_rl.models.automodel.train.PostProcessingFunction = Union['LossPostProcessor', 'LogprobsPostProcessor', 'TopkLogitsPostProcessor', '... +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor.mdx new file mode 100644 index 0000000..b29df46 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor.mdx @@ -0,0 +1,9 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/dtensor +title: nemo_rl.models.dtensor +--- + +## Submodules + +- **[`nemo_rl.models.dtensor.parallelize`](/nemo-rl/nemo_rl/models/dtensor/parallelize)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor/parallelize.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor/parallelize.mdx new file mode 100644 index 0000000..72877d4 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor/parallelize.mdx @@ -0,0 +1,454 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/dtensor/parallelize +title: nemo_rl.models.dtensor.parallelize +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`RotaryEmbedParallel`](#nemo_rl-models-dtensor-parallelize-RotaryEmbedParallel) | Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings because the input is a tuple. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_parallelize_gemma3`](#nemo_rl-models-dtensor-parallelize-_parallelize_gemma3) | Parallelizes a Gemma3ForCausalLM model across data and tensor parallel dimensions. | +| [`_parallelize_llama`](#nemo_rl-models-dtensor-parallelize-_parallelize_llama) | Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions. | +| [`_parallelize_model`](#nemo_rl-models-dtensor-parallelize-_parallelize_model) | Parallelize a model using DTensor. | +| [`_parallelize_nm5_h`](#nemo_rl-models-dtensor-parallelize-_parallelize_nm5_h) | Parallelize a NemotronHForCausalLM model across data and tensor parallel dimensions. | +| [`_parallelize_qwen`](#nemo_rl-models-dtensor-parallelize-_parallelize_qwen) | Parallelizes a Qwen2ForCausalLM model across data and tensor parallel dimensions. | +| [`clip_grad_by_total_norm_`](#nemo_rl-models-dtensor-parallelize-clip_grad_by_total_norm_) | Clips gradient of an iterable of parameters by total norm. | +| [`get_grad_norm`](#nemo_rl-models-dtensor-parallelize-get_grad_norm) | Calculate the norm of gradients. | +| [`get_hf_tp_plan`](#nemo_rl-models-dtensor-parallelize-get_hf_tp_plan) | Get the Hugging Face tensor parallel plan from the model. | +| [`to_local_if_dtensor`](#nemo_rl-models-dtensor-parallelize-to_local_if_dtensor) | Returns the local shard of the given tensor if it is a DTensor. | +| [`translate_parallel_style`](#nemo_rl-models-dtensor-parallelize-translate_parallel_style) | Translate parallel style str to parallel type. | + +### Data + +[`PARALLIZE_FUNCTIONS`](#nemo_rl-models-dtensor-parallelize-PARALLIZE_FUNCTIONS) + +### API + + + + + +```python +class nemo_rl.models.dtensor.parallelize.RotaryEmbedParallel() +``` + + + + + + +**Bases:** `SequenceParallel` + +Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings because the input is a tuple. + + + + + + +```python +nemo_rl.models.dtensor.parallelize.RotaryEmbedParallel._prepare_input_fn( + sequence_sharding, + mod, + inputs, + device_mesh +) +``` + + + + + + +staticmethod + + + + + + + +```python +nemo_rl.models.dtensor.parallelize.RotaryEmbedParallel._prepare_output_fn( + use_local_output, + mod, + outputs, + device_mesh +) +``` + + + + + + +staticmethod + + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize._parallelize_gemma3( + model: typing.Union[transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM, transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration], + sequence_parallel: bool = False +) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle] +``` + + + + + + +Parallelizes a Gemma3ForCausalLM model across data and tensor parallel dimensions. + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize._parallelize_llama( + model: transformers.models.llama.modeling_llama.LlamaForCausalLM, + sequence_parallel: bool = False +) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle] +``` + + + + + + +Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions. + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize._parallelize_model( + model: typing.Union[transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM, transformers.models.qwen3.modeling_qwen3.Qwen3ForCausalLM, transformers.models.llama.modeling_llama.LlamaForCausalLM, transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM, transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration], + dp_mesh: torch.distributed.device_mesh.DeviceMesh, + tp_mesh: torch.distributed.device_mesh.DeviceMesh, + param_dtype: torch.dtype, + sequence_parallel: bool = False, + activation_checkpointing: bool = False, + cpu_offload: bool = False, + custom_parallel_plan: typing.Optional[typing.Union[dict, str]] = None +) +``` + + + + + + +Parallelize a model using DTensor. + +**Parameters:** + + +The model to parallelize. + + + +Device mesh for data parallelism. + + + +Device mesh for tensor parallelism. + + + +Data type for model parameters. + + + +Whether to use sequence parallelism. Defaults to False. + + + +Whether to use activation checkpointing. Defaults to False. + + + +Whether to enable cpu offloading for FSDP. Defaults to False. + + + +Custom parallel plan for the model. Defaults to None. +If it's a dict, it will be used as the parallel plan directly. +If it's a string, it must be a path that points to a dict or a function that returns a dict. +The usage example can refer to `docs/design-docs/fsdp2-parallel-plan.md`. + + +**Returns:** + +The parallelized model. + +**Raises:** + +- `ValueError`: If the model type is not supported for parallelization. + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize._parallelize_nm5_h( + model, + dp_mesh: torch.distributed.device_mesh.DeviceMesh, + tp_mesh: torch.distributed.device_mesh.DeviceMesh, + param_dtype: torch.dtype, + sequence_parallel: bool = False, + activation_checkpointing: bool = False, + cpu_offload: bool = False, + custom_parallel_plan: typing.Optional[typing.Union[dict, str]] = None +) -> torch.distributed.fsdp.FSDPModule +``` + + + + + + +Parallelize a NemotronHForCausalLM model across data and tensor parallel dimensions. + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize._parallelize_qwen( + model: typing.Union[transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM, transformers.models.qwen3.modeling_qwen3.Qwen3ForCausalLM], + sequence_parallel: bool = False +) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle] +``` + + + + + + +Parallelizes a Qwen2ForCausalLM model across data and tensor parallel dimensions. + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize.clip_grad_by_total_norm_( + parameters: typing.Union[list[typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]], typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]], + max_grad_norm: typing.Union[int, float], + total_norm: float +) +``` + + + + + + +Clips gradient of an iterable of parameters by total norm. + +Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/a695b2bd2a0ca9ca63385a48c41a1c5a033cdd1e/megatron/core/optimizer/clip_grads.py#L138 + +Note that the gradients are modified in place. + +**Parameters:** + + + +An iterable of Tensors or DTensors, or a single Tensor or DTensor +that will have gradients normalized. + + + +Maximum norm of the gradients. + + + +The pre-computed total norm of the gradients to use for scaling. + + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize.get_grad_norm( + parameters: typing.Union[list[typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]], typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]], + dp_cp_group: torch.distributed.ProcessGroup, + tp_group: torch.distributed.ProcessGroup, + norm_type: typing.Union[int, float] = 2, + dtype: torch.dtype = torch.float32 +) -> float +``` + + + + + + +Calculate the norm of gradients. + +Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/a695b2bd2a0ca9ca63385a48c41a1c5a033cdd1e/megatron/core/optimizer/clip_grads.py#L51 + +**Parameters:** + + + +An iterable of Tensors or DTensors, or a single Tensor or DTensor +that will have gradient norm calculated. + + + +Process group for data parallel communication. + + + +Process group for context parallel communication. + + + +Process group for tensor parallel communication. + + + +Type of the used p-norm. Can be ``'inf'`` for +infinity norm. + + +**Returns:** `float` + +Total norm of the gradients (viewed as a single vector) + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize.get_hf_tp_plan( + model: transformers.modeling_utils.PreTrainedModel +) +``` + + + + + + +Get the Hugging Face tensor parallel plan from the model. + +This function: +- Retrieves TP strategies from model class, instance, and inner model levels. +- Handles special cases for `embed_tokens` and `lm_head` for speed up. +- Converts string-based parallel styles to DTensor parallelization strategies. + +Taken and modified from: https://github.com/NVIDIA/NeMo/blob/6c6169db01bcca73ae8ad3ac35242fadbb9a78ba/nemo/lightning/pytorch/strategies/utils.py#L532 + +**Parameters:** + + +A Hugging Face model instance + + +**Returns:** + +A dictionary mapping model component paths to their parallelization strategies + +**Raises:** + +- `AssertionError`: If no TP plan is found + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize.to_local_if_dtensor( + tensor: typing.Union[torch.Tensor, torch.distributed.tensor.DTensor] +) -> torch.Tensor +``` + + + + + + +Returns the local shard of the given tensor if it is a DTensor. + +Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/605f618f237cda8fa80132bc2ccff933512d5a0d/megatron/core/utils.py#L746 + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize.translate_parallel_style( + style: str +) +``` + + + + + + +Translate parallel style str to parallel type. + +Taken and modified from: https://github.com/NVIDIA/NeMo/blob/6c6169db01bcca73ae8ad3ac35242fadbb9a78ba/nemo/lightning/pytorch/strategies/utils.py#L547 + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize.PARALLIZE_FUNCTIONS: dict[type[Module], Callable[..., dict[str, ParallelStyle]]] = {Qwen2ForCausalLM: _parallelize_qwen, Qwen3ForCausalLM: _parallelize_qwen, Llama... +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation.mdx new file mode 100644 index 0000000..ff3114c --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation.mdx @@ -0,0 +1,62 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation +title: nemo_rl.models.generation +--- + +## Subpackages + +- **[`nemo_rl.models.generation.sglang`](/nemo-rl/nemo_rl/models/generation/sglang)** +- **[`nemo_rl.models.generation.vllm`](/nemo-rl/nemo_rl/models/generation/vllm)** + +## Submodules + +- **[`nemo_rl.models.generation.interfaces`](/nemo-rl/nemo_rl/models/generation/interfaces)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`configure_generation_config`](#nemo_rl-models-generation-configure_generation_config) | Apply specific configurations to generation config. | + +### Data + +[`TokenizerType`](#nemo_rl-models-generation-TokenizerType) + +### API + + + + + +```python +nemo_rl.models.generation.configure_generation_config( + config: nemo_rl.models.generation.interfaces.GenerationConfig, + tokenizer: nemo_rl.models.generation.TokenizerType, + is_eval = False +) -> nemo_rl.models.generation.interfaces.GenerationConfig +``` + + + + + + +Apply specific configurations to generation config. + + + + + + + + +```python +nemo_rl.models.generation.TokenizerType = PreTrainedTokenizerBase +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/interfaces.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/interfaces.mdx new file mode 100644 index 0000000..886ccc8 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/interfaces.mdx @@ -0,0 +1,569 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/interfaces +title: nemo_rl.models.generation.interfaces +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ColocationConfig`](#nemo_rl-models-generation-interfaces-ColocationConfig) | - | +| [`GenerationConfig`](#nemo_rl-models-generation-interfaces-GenerationConfig) | Configuration for generation. | +| [`GenerationDatumSpec`](#nemo_rl-models-generation-interfaces-GenerationDatumSpec) | Specification for input data required by generation models. | +| [`GenerationInterface`](#nemo_rl-models-generation-interfaces-GenerationInterface) | Abstract base class defining the interface for RL policies. | +| [`GenerationOutputSpec`](#nemo_rl-models-generation-interfaces-GenerationOutputSpec) | Specification for output data returned by generation models. | +| [`OptionalResourcesConfig`](#nemo_rl-models-generation-interfaces-OptionalResourcesConfig) | - | +| [`ResourcesConfig`](#nemo_rl-models-generation-interfaces-ResourcesConfig) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`verify_right_padding`](#nemo_rl-models-generation-interfaces-verify_right_padding) | Verify that a tensor is right-padded according to the provided lengths. | + +### API + + + + + +```python +class nemo_rl.models.generation.interfaces.ColocationConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.models.generation.interfaces.GenerationConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Configuration for generation. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.generation.interfaces.GenerationDatumSpec +``` + + + + + + +**Bases:** `typing.TypedDict` + +Specification for input data required by generation models. + +- input_ids: Tensor of token IDs representing the input sequences (right padded) +- input_lengths: Tensor containing the actual length of each sequence (without padding) +- stop_strings: Optional list of strings to stop generation (per sample) +- __extra__: Additional model-specific data fields + +Example of a batch with 4 entries with different sequence lengths: + + +```python +# Batch of 4 sequences with lengths [3, 5, 2, 4] + +input_ids (padded): +[ + [101, 2054, 2003, 0, 0], # Length 3 + [101, 2054, 2003, 2001, 1996], # Length 5 + [101, 2054, 0, 0, 0], # Length 2 + [101, 2054, 2003, 2001, 0], # Length 4 +] + +input_lengths: +[3, 5, 2, 4] +``` + + + +All functions receiving or returning GenerationDatumSpec should ensure +right padding is maintained. Use verify_right_padding() to check. + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.generation.interfaces.GenerationInterface() +``` + + + + + + +Abstract + +Abstract base class defining the interface for RL policies. + + + +Whether the generation backend requires KV cache scales synchronization. + + + + + +```python +nemo_rl.models.generation.interfaces.GenerationInterface.clear_logger_metrics() -> None +``` + + + + + + +Clear logger metrics for performance reporting. + +This is an optional method that backends can implement to clear +telemetry metrics. Default implementation does nothing. + + + + + + + +```python +nemo_rl.models.generation.interfaces.GenerationInterface.finish_generation( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> bool +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.generation.interfaces.GenerationInterface.generate( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.generation.interfaces.GenerationInterface.get_logger_metrics() -> dict[str, typing.Any] +``` + + + + + + +Get logger metrics for performance reporting. + +This is an optional method that backends can implement to collect +telemetry metrics. Default implementation returns empty dict. + +**Returns:** `dict[str, Any]` + +Dictionary of metrics. Format may vary by backend. + + + + + + + +```python +nemo_rl.models.generation.interfaces.GenerationInterface.init_collective( + ip: str, + port: int, + world_size: int +) -> list[ray.ObjectRef] +``` + + + + + + +abstract + +Initialize the collective communication. + + + + + + + +```python +nemo_rl.models.generation.interfaces.GenerationInterface.invalidate_kv_cache() -> bool +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.interfaces.GenerationInterface.prepare_for_generation( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> bool +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.generation.interfaces.GenerationInterface.prepare_refit_info( + state_dict_info: dict[str, typing.Any] +) -> None +``` + + + + + + +Prepare the info for refit. + + + + + + + +```python +nemo_rl.models.generation.interfaces.GenerationInterface.update_weights_from_collective() -> list[ray.ObjectRef] +``` + + + + + + +Update the model weights from collective communication. + + + + + + + +```python +nemo_rl.models.generation.interfaces.GenerationInterface.update_weights_via_ipc_zmq() -> list[ray.ObjectRef] +``` + + + + + + +Update the model weights from the given IPC handles. + + + + + + + + + +```python +class nemo_rl.models.generation.interfaces.GenerationOutputSpec +``` + + + + + + +**Bases:** `typing.TypedDict` + +Specification for output data returned by generation models. + +- output_ids: Tensor of token IDs representing the generated sequences (right padded) +- generation_lengths: Tensor containing the actual length of each generated sequence +- unpadded_sequence_lengths: Tensor containing the actual length of each input + generated sequence (without padding) +- logprobs: Tensor of log probabilities for each generated token (right padded with zeros) +- truncated: Boolean tensor indicating if each sequence was truncated (hit max_tokens limit) +- __extra__: Additional model-specific data fields + +Example of a batch with 2 sequences: + + +```python +# Sample batch with 2 examples +# - Example 1: Input length 3, generated response length 4 +# - Example 2: Input length 5, generated response length 2 + +output_ids (right-padded): +[ + [101, 2054, 2003, 2023, 2003, 1037, 2200, 0], # 7 valid tokens (3 input + 4 output) + [101, 2054, 2003, 2001, 1996, 3014, 2005, 0], # 7 valid tokens (5 input + 2 output) +] + +generation_lengths: +[4, 2] # Length of just the generated response part + +unpadded_sequence_lengths: +[7, 7] # Length of full valid sequence (input + generated response) + +logprobs (right-padded with zeros): +[ + [0.0, 0.0, 0.0, -1.2, -0.8, -2.1, -1.5, 0.0], # First 3 are 0 (input tokens), next 4 are actual logprobs + [0.0, 0.0, 0.0, 0.0, 0.0, -0.9, -1.7, 0.0], # First 5 are 0 (input tokens), next 2 are actual logprobs +] + +truncated: +[False, True] # Example 2 was truncated (hit max_tokens limit without EOS) +``` + + + +All functions receiving or returning GenerationOutputSpec should ensure +right padding is maintained. Use verify_right_padding() to check. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.generation.interfaces.OptionalResourcesConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.models.generation.interfaces.ResourcesConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +nemo_rl.models.generation.interfaces.verify_right_padding( + data: typing.Union[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], + pad_value: int = 0, + raise_error: bool = True +) -> tuple[bool, typing.Union[str, None]] +``` + + + + + + +Verify that a tensor is right-padded according to the provided lengths. + +**Parameters:** + + +The BatchedDataDict to check, containing either: +- For GenerationDatumSpec: input_ids and input_lengths +- For GenerationOutputSpec: output_ids and unpadded_sequence_lengths + + + +The expected padding value (default: 0) + + + +Whether to raise an error if wrong padding is detected + + +**Returns:** `bool` + +Tuple of (is_right_padded, error_message) + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang.mdx new file mode 100644 index 0000000..78b60ba --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang.mdx @@ -0,0 +1,33 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/sglang +title: nemo_rl.models.generation.sglang +--- + +## Submodules + +- **[`nemo_rl.models.generation.sglang.config`](/nemo-rl/nemo_rl/models/generation/sglang/config)** +- **[`nemo_rl.models.generation.sglang.sglang_copied_utils`](/nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils)** +- **[`nemo_rl.models.generation.sglang.sglang_generation`](/nemo-rl/nemo_rl/models/generation/sglang/sglang_generation)** +- **[`nemo_rl.models.generation.sglang.sglang_worker`](/nemo-rl/nemo_rl/models/generation/sglang/sglang_worker)** +- **[`nemo_rl.models.generation.sglang.utils`](/nemo-rl/nemo_rl/models/generation/sglang/utils)** + +## Package Contents + +### Data + +[`__all__`](#nemo_rl-models-generation-sglang-__all__) + +### API + + + + + +```python +nemo_rl.models.generation.sglang.__all__ = ['SGLangConfig', 'SGLangGeneration'] +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/config.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/config.mdx new file mode 100644 index 0000000..f86bc80 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/config.mdx @@ -0,0 +1,299 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/sglang/config +title: nemo_rl.models.generation.sglang.config +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`SGLangConfig`](#nemo_rl-models-generation-sglang-config-SGLangConfig) | Configuration for SGLang runtime. | +| [`SglangSpecificArgs`](#nemo_rl-models-generation-sglang-config-SglangSpecificArgs) | SGLang-specific configuration arguments. | + +### API + + + + + +```python +class nemo_rl.models.generation.sglang.config.SGLangConfig() +``` + + + + + + +**Bases:** [GenerationConfig](/nemo-rl/nemo_rl/models/generation/interfaces#nemo_rl-models-generation-interfaces-GenerationConfig) + +Configuration for SGLang runtime. + + + + + + + + + + + + + +```python +class nemo_rl.models.generation.sglang.config.SglangSpecificArgs +``` + + + + + + +**Bases:** `typing.TypedDict` + +SGLang-specific configuration arguments. + +Most fields below map directly to SGLang's ServerArgs (see: +https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py). + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils.mdx new file mode 100644 index 0000000..c940dea --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils.mdx @@ -0,0 +1,307 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils +title: nemo_rl.models.generation.sglang.sglang_copied_utils +--- + +Standalone utility functions copied from the SGLang project. + +This module contains utility functions that were originally part of the SGLang +repository (https://github.com/sgl-project/sglang). They have been copied here +to avoid requiring sglang as a runtime dependency for weight refitting functionality. + +IMPORTANT: This module should NOT contain any imports from the sglang package. +All functions are standalone and self-contained. + +Each function includes a permalink to its original source in the SGLang repository. +These functions were copied from sglang version 0.5.2. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MultiprocessingSerializer`](#nemo_rl-models-generation-sglang-sglang_copied_utils-MultiprocessingSerializer) | Serialize/deserialize Python objects using ForkingPickler for IPC. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_device_from_maybe_uuid`](#nemo_rl-models-generation-sglang-sglang_copied_utils-_device_from_maybe_uuid) | Convert a device UUID string or index to a device index. | +| [`_device_to_uuid`](#nemo_rl-models-generation-sglang-sglang_copied_utils-_device_to_uuid) | Convert a device index to its UUID string. | +| [`_modify_tuple`](#nemo_rl-models-generation-sglang-sglang_copied_utils-_modify_tuple) | Create a new tuple with one element modified by a function. | +| [`_rebuild_cuda_tensor_modified`](#nemo_rl-models-generation-sglang-sglang_copied_utils-_rebuild_cuda_tensor_modified) | Modified rebuild_cuda_tensor that accepts GPU UUID or device index. | +| [`_reduce_tensor_modified`](#nemo_rl-models-generation-sglang-sglang_copied_utils-_reduce_tensor_modified) | Modified reduce_tensor that stores GPU UUID instead of device index. | +| [`monkey_patch_torch_reductions`](#nemo_rl-models-generation-sglang-sglang_copied_utils-monkey_patch_torch_reductions) | Monkey patch torch multiprocessing reductions to use GPU UUIDs. | + +### Data + +[`_REDUCE_TENSOR_ARG_DEVICE_INDEX`](#nemo_rl-models-generation-sglang-sglang_copied_utils-_REDUCE_TENSOR_ARG_DEVICE_INDEX) + +### API + + + + + +```python +class nemo_rl.models.generation.sglang.sglang_copied_utils.MultiprocessingSerializer() +``` + + + + + + +Serialize/deserialize Python objects using ForkingPickler for IPC. + +This class enables serialization of objects (including CUDA tensors with IPC +handles) for transfer between processes via HTTP or other mechanisms. + +Original source (sglang v0.5.2): +https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/utils.py#L589-L623 + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_copied_utils.MultiprocessingSerializer.deserialize( + data +) +``` + + + + + + +staticmethod + +Deserialize a previously serialized object. + +**Parameters:** + + +The serialized data, optionally base64-encoded. + + +**Returns:** + +The deserialized Python object. + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_copied_utils.MultiprocessingSerializer.serialize( + obj, + output_str: bool = False +) +``` + + + + + + +staticmethod + +Serialize a Python object using ForkingPickler. + +**Parameters:** + + +The object to serialize. + + + +If True, return a base64-encoded string instead of raw bytes. + + +**Returns:** + +bytes or str: The serialized object. + + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_copied_utils._device_from_maybe_uuid( + device_maybe_uuid: typing.Union[int, str] +) -> int +``` + + + + + + +Convert a device UUID string or index to a device index. + +Original source (sglang v0.5.2): +https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L55-L65 + +**Parameters:** + + +Either an integer device index or a UUID string. + + +**Returns:** `int` + +The integer device index. + +**Raises:** + +- `Exception`: If the UUID doesn't match any available device. + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_copied_utils._device_to_uuid( + device: int +) -> str +``` + + + + + + +Convert a device index to its UUID string. + +Original source (sglang v0.5.2): +https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L51-L52 + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_copied_utils._modify_tuple( + t, + index: int, + modifier: typing.Callable +) +``` + + + + + + +Create a new tuple with one element modified by a function. + +Original source (sglang v0.5.2): +https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L68-L69 + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_copied_utils._rebuild_cuda_tensor_modified( + args = () +) +``` + + + + + + +Modified rebuild_cuda_tensor that accepts GPU UUID or device index. + +Original source (sglang v0.5.2): +https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L46-L48 + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_copied_utils._reduce_tensor_modified( + args = (), + kwargs = {} +) +``` + + + + + + +Modified reduce_tensor that stores GPU UUID instead of device index. + +Original source (sglang v0.5.2): +https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L39-L43 + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_copied_utils.monkey_patch_torch_reductions() +``` + + + + + + +Monkey patch torch multiprocessing reductions to use GPU UUIDs. + +This patch modifies PyTorch's CUDA tensor IPC mechanism to use GPU UUIDs +instead of device indices. This enables proper weight transfer between +processes that may have different CUDA_VISIBLE_DEVICES configurations. + +The patch is idempotent - calling it multiple times is safe. + +This is a workaround before PyTorch https://github.com/pytorch/pytorch/pull/149248 +is merged and released. + +Original source (sglang v0.5.2): +https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L20-L33 + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_copied_utils._REDUCE_TENSOR_ARG_DEVICE_INDEX = 6 +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_generation.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_generation.mdx new file mode 100644 index 0000000..c8393bd --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_generation.mdx @@ -0,0 +1,369 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_generation +title: nemo_rl.models.generation.sglang.sglang_generation +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`SGLangGeneration`](#nemo_rl-models-generation-sglang-sglang_generation-SGLangGeneration) | - | + +### Data + +[`TOP_K_THRESHOLD`](#nemo_rl-models-generation-sglang-sglang_generation-TOP_K_THRESHOLD) + +[`TOP_P_THRESHOLD`](#nemo_rl-models-generation-sglang-sglang_generation-TOP_P_THRESHOLD) + +[`logger`](#nemo_rl-models-generation-sglang-sglang_generation-logger) + +### API + + + + + +```python +class nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration( + cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster, + config: nemo_rl.models.generation.sglang.config.SGLangConfig, + name_prefix: str = 'sglang_policy', + workers_per_node: typing.Optional[typing.Union[int, list[int]]] = None +) +``` + + + + + + +**Bases:** [GenerationInterface](/nemo-rl/nemo_rl/models/generation/interfaces#nemo_rl-models-generation-interfaces-GenerationInterface) + + + + + + + + + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.__del__() -> None +``` + + + + + + +Shuts down the worker groups when the object is deleted or is garbage collected. + +This is an extra safety net in case the user forgets to call shutdown() and the pointer to +the object is lost due to leaving a function scope. It's always recommended that the +user calls shutdown(). + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration._allocate_bundles_for_servers( + cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster, + num_servers: int, + gpus_per_server: int +) -> list[tuple[int, list[int]]] +``` + + + + + + +Allocate GPU bundles to each SGLang server. + +Each server gets consecutive bundles within the same placement group (node). +Ray will automatically set CUDA_VISIBLE_DEVICES so each server sees logical GPUs 0, 1, 2, ..., gpus_per_server-1. + +**Parameters:** + + +The Ray virtual cluster + + + +Total number of SGLang servers to create + + + +Number of GPUs each server needs + + +**Returns:** `list[tuple[int, list[int]]]` + +List of (node_idx, [bundle_indices]) tuples for each server + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.finish_generation( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> bool +``` + + + + + + +Sleep workers and reset prefix cache. + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.generate( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] +``` + + + + + + +Generate a batch of data using SGLang. + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.get_sglang_server_urls() -> list[str] +``` + + + + + + +Get base URLs of all SGLang servers. + +**Returns:** `list[str]` + +List of base URLs (e.g., ["http://localhost:30000", "http://localhost:30001"]) + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.get_sglang_url_to_gpu_uuids() -> dict[str, list[str]] +``` + + + + + + +Get mapping from SGLang server URL to list of GPU UUIDs it uses. + +**Returns:** `dict[str, list[str]]` + +Dict mapping server URL to list of GPU UUIDs + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.init_collective( + ip: str, + port: int, + world_size: int, + train_world_size: int +) -> list[ray.ObjectRef] +``` + + + + + + +Initialize the collective communication. + +TODO: if weight updates via NCCL are needed in the future. + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.invalidate_kv_cache() -> bool +``` + + + + + + +Invalidate KV cache before weight updates (Megatron-style). + +This flushes the cache before weight updates to clear stale cache. +Only primary workers (TP rank 0, model owners) will flush their cache. + +**Returns:** `bool` + +True if all caches were flushed successfully, False otherwise + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.prepare_for_generation( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> bool +``` + + + + + + +Wake workers up for colocated inference. + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.prepare_refit_info( + state_dict_info: dict[str, typing.Any] +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.shutdown() -> bool +``` + + + + + + +Shut down all SGLang workers and clean up resources. + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.update_weights_from_collective() -> list[ray.ObjectRef] +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.update_weights_via_ipc_zmq() -> list[ray.ObjectRef] +``` + + + + + + + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.TOP_K_THRESHOLD = 8000 +``` + + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.TOP_P_THRESHOLD = 0.99 +``` + + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_worker.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_worker.mdx new file mode 100644 index 0000000..a74da3c --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_worker.mdx @@ -0,0 +1,529 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_worker +title: nemo_rl.models.generation.sglang.sglang_worker +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`SGLangGenerationWorker`](#nemo_rl-models-generation-sglang-sglang_worker-SGLangGenerationWorker) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_require_sglang`](#nemo_rl-models-generation-sglang-sglang_worker-_require_sglang) | Import `sglang` lazily so test collection works without the optional extra. | + +### Data + +[`logger`](#nemo_rl-models-generation-sglang-sglang_worker-logger) + +### API + + + + + +```python +class nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker( + config: nemo_rl.models.generation.sglang.config.SGLangConfig, + bundle_indices: typing.Optional[list[int]] = None, + fraction_of_gpus: float = 1.0, + seed: typing.Optional[int] = None +) +``` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.__repr__() -> str +``` + + + + + + +Customizes the actor's prefix in the Ray logs. + +This makes it easier to identify which worker is producing specific log messages. + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._build_sampling_params( + greedy: bool, + stop_strings, + max_new_tokens: typing.Optional[int] = None, + input_len: typing.Optional[int] = None, + context_length: typing.Optional[int] = None, + sample_index: typing.Optional[int] = None +) -> dict[str, typing.Any] +``` + + + + + + +Build sampling parameters dictionary for SGLang API. + +**Parameters:** + + +Whether to use greedy decoding (temperature=0.0) + + + +Merged stop strings (not used here, handled per sample) + + + +Override max_new_tokens from config if provided + + + +Input length for this sample (used for context_length adjustment) + + + +Maximum context length (if provided, adjusts max_new_tokens) + + + +Sample index (used for warning messages, 0-indexed) + + +**Returns:** `dict[str, Any]` + +Dictionary of sampling parameters compatible with SGLang API + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._ensure_session() +``` + + + + + + +async + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._generate_async( + tasks +) +``` + + + + + + +async + +Execute generation tasks with concurrency control. + +TEMP: Uses a semaphore to limit the number of concurrent requests per server, preventing server overload. +A router based solution is preffered in the future. + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._generate_single_sample( + input_ids: list[int], + sampling_params: dict[str, typing.Any], + stop_string: typing.Optional[str] = None +) -> tuple[list[int], list[float]] +``` + + + + + + +async + +Generate a single sample using SGLang API (async function). + +**Parameters:** + + +List of input token IDs (without padding) + + + +Dictionary of sampling parameters (temperature, top_p, max_new_tokens, etc.) + + + +Optional stop string for this sample + + +**Returns:** `tuple[list[int], list[float]]` + +Tuple of (generated_tokens, logprobs): +- generated_tokens: List of generated token IDs +- logprobs: List of log probabilities for generated tokens + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._launch_server_process( + server_args: typing.Any +) -> multiprocessing.Process +``` + + + + + + +Launch the SGLang server process and wait for it to be ready. + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._make_request( + endpoint: str, + payload: typing.Optional[dict] = None +) +``` + + + + + + +Make a POST request to the specified endpoint with the given payload. + +**Parameters:** + + +The API endpoint to call + + + +The JSON payload to send (default: empty dict) + + +**Returns:** + +The JSON response from the server + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._merge_stop_strings( + batch_stop_strings +) +``` + + + + + + +Merge stop strings from config and batch. + +**Parameters:** + + +List of stop strings from batch (one per sample) + + +**Returns:** + +List of merged stop strings (one per sample) + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.configure_worker( + num_gpus: int | float, + bundle_indices: typing.Optional[tuple[int, list[int]]] = None +) -> tuple[dict[str, typing.Any], dict[str, str], dict[str, typing.Any]] +``` + + + + + + +staticmethod + +Provides complete worker configuration for SGLang server. + +This method configures the worker based on bundle_indices which tells us +how many GPUs this server should use. + +**Parameters:** + + +Original GPU allocation for this worker based on the placement group + + + +Tuple of (node_idx, local_bundle_indices) for this server + + +**Returns:** `tuple[dict[str, Any], dict[str, str], dict[str, Any]]` + +tuple with complete worker configuration: +- 'resources': Resource allocation (e.g., num_gpus) +- 'env_vars': Environment variables for this worker +- 'init_kwargs': Parameters to pass to __init__ of the worker + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.generate( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] +``` + + + + + + +Generate a batch of data using SGLang generation. + +**Parameters:** + + +BatchedDataDict containing input_ids and input_lengths tensors + + + +Whether to use greedy decoding instead of sampling + + +**Returns:** `BatchedDataDict[GenerationOutputSpec]` + +BatchedDataDict conforming to GenerationOutputSpec: +- output_ids: input + generated token IDs with proper padding +- logprobs: Log probabilities for tokens +- generation_lengths: Lengths of each response +- unpadded_sequence_lengths: Lengths of each input + generated sequence + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.get_base_url() -> str +``` + + + + + + +Get the base URL of this SGLang server. + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.get_gpu_uuids() -> list[str] +``` + + + + + + +Get list of GPU UUIDs used by this SGLang server. + +**Returns:** `list[str]` + +List of GPU UUIDs (e.g., ["GPU-xxxxx", "GPU-yyyyy"]) + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.invalidate_kv_cache() -> bool +``` + + + + + + +Invalidate KV cache before weight updates (Megatron-style). + +This flushes the cache before weight updates to clear stale cache. +Uses retry logic to handle cases where there are pending requests. + +**Returns:** `bool` + +True if flush was successful, False otherwise + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.shutdown() -> bool +``` + + + + + + +Shutdown the SGLang server process and cleanup async resources. + +**Returns:** `bool` + +True if shutdown was successful, False otherwise + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.sleep() +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.wake_up( + kwargs = {} +) +``` + + + + + + + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker._require_sglang() +``` + + + + + + +Import `sglang` lazily so test collection works without the optional extra. + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/utils.mdx new file mode 100644 index 0000000..ace8dcd --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/utils.mdx @@ -0,0 +1,109 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/sglang/utils +title: nemo_rl.models.generation.sglang.utils +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AsyncLoopThread`](#nemo_rl-models-generation-sglang-utils-AsyncLoopThread) | A background event loop thread for running async operations in Ray actors. | + +### API + + + + + +```python +class nemo_rl.models.generation.sglang.utils.AsyncLoopThread() +``` + + + + + + +A background event loop thread for running async operations in Ray actors. + +This class creates a dedicated thread with its own event loop, allowing +synchronous Ray actor methods to execute async coroutines without blocking +the main actor thread. This is necessary because run_coroutine_threadsafe +requires the event loop to be in a different thread. + + + + + + + + + + + + + + +```python +nemo_rl.models.generation.sglang.utils.AsyncLoopThread._start_loop() +``` + + + + + + +Run the event loop in the background thread. + + + + + + + +```python +nemo_rl.models.generation.sglang.utils.AsyncLoopThread.run( + coro +) +``` + + + + + + +Schedule a coroutine onto the loop and block until it's done. + +**Parameters:** + + +The coroutine to execute + + +**Returns:** + +The result of the coroutine + + + + + + + +```python +nemo_rl.models.generation.sglang.utils.AsyncLoopThread.shutdown() +``` + + + + + + +Shutdown the event loop and wait for the thread to finish. + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm.mdx new file mode 100644 index 0000000..c5f278e --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm.mdx @@ -0,0 +1,34 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/vllm +title: nemo_rl.models.generation.vllm +--- + +## Submodules + +- **[`nemo_rl.models.generation.vllm.config`](/nemo-rl/nemo_rl/models/generation/vllm/config)** +- **[`nemo_rl.models.generation.vllm.utils`](/nemo-rl/nemo_rl/models/generation/vllm/utils)** +- **[`nemo_rl.models.generation.vllm.vllm_backend`](/nemo-rl/nemo_rl/models/generation/vllm/vllm_backend)** +- **[`nemo_rl.models.generation.vllm.vllm_generation`](/nemo-rl/nemo_rl/models/generation/vllm/vllm_generation)** +- **[`nemo_rl.models.generation.vllm.vllm_worker`](/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker)** +- **[`nemo_rl.models.generation.vllm.vllm_worker_async`](/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async)** + +## Package Contents + +### Data + +[`__all__`](#nemo_rl-models-generation-vllm-__all__) + +### API + + + + + +```python +nemo_rl.models.generation.vllm.__all__ = ['VllmConfig', 'VllmGeneration', 'VllmGenerationWorker', 'VllmAsyncGenerationWor... +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/config.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/config.mdx new file mode 100644 index 0000000..6ca4574 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/config.mdx @@ -0,0 +1,111 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/vllm/config +title: nemo_rl.models.generation.vllm.config +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`VllmConfig`](#nemo_rl-models-generation-vllm-config-VllmConfig) | - | +| [`VllmSpecificArgs`](#nemo_rl-models-generation-vllm-config-VllmSpecificArgs) | - | + +### API + + + + + +```python +class nemo_rl.models.generation.vllm.config.VllmConfig() +``` + + + + + + +**Bases:** [GenerationConfig](/nemo-rl/nemo_rl/models/generation/interfaces#nemo_rl-models-generation-interfaces-GenerationConfig) + + + + + + + + + + + + +```python +class nemo_rl.models.generation.vllm.config.VllmSpecificArgs +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/utils.mdx new file mode 100644 index 0000000..5bfcfad --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/utils.mdx @@ -0,0 +1,113 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/vllm/utils +title: nemo_rl.models.generation.vllm.utils +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`aggregate_spec_decode_counters`](#nemo_rl-models-generation-vllm-utils-aggregate_spec_decode_counters) | Aggregate speculative decoding counters from multiple workers. | +| [`compute_spec_decode_metrics`](#nemo_rl-models-generation-vllm-utils-compute_spec_decode_metrics) | Compute delta and derived metrics for speculative decoding. | +| [`format_prompt_for_vllm_generation`](#nemo_rl-models-generation-vllm-utils-format_prompt_for_vllm_generation) | Format a list of prompts for vllm generation (which requires a specific format for its own `generate` method). | + +### API + + + + + +```python +nemo_rl.models.generation.vllm.utils.aggregate_spec_decode_counters( + worker_metrics: list[dict[str, float | list[float]]] +) -> dict[str | tuple[str, int], float] +``` + + + + + + +Aggregate speculative decoding counters from multiple workers. + +Combines spec decode metrics collected from DP leader workers into +a single aggregated counter dictionary. + +**Parameters:** + + +List of metric dictionaries from each worker. +Each dict maps metric names to float values or lists of floats +(for per-position metrics). + + +**Returns:** `dict[str | tuple[str, int], float]` + +Dictionary mapping metric names to their aggregated float values. + + + + + + + + +```python +nemo_rl.models.generation.vllm.utils.compute_spec_decode_metrics( + start_counters: dict[str | tuple[str, int], float], + end_counters: dict[str | tuple[str, int], float] +) -> dict[str, float] +``` + + + + + + +Compute delta and derived metrics for speculative decoding. + +Calculates the difference between two counter snapshots and derives +acceptance rate and acceptance length metrics for logging. + +**Parameters:** + + +Counter snapshot taken before generation. + + + +Counter snapshot taken after generation. + + +**Returns:** `dict[str, float]` + +Dictionary of metrics suitable for logging to wandb/tensorboard. + + + + + + + + +```python +nemo_rl.models.generation.vllm.utils.format_prompt_for_vllm_generation( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + sample_idx: typing.Optional[int] = None +) -> list[dict[str, typing.Any]] +``` + + + + + + +Format a list of prompts for vllm generation (which requires a specific format for its own `generate` method). + +See https://docs.vllm.ai/en/v0.9.1/features/multimodal_inputs.html for prompt format for multimodal inputs. + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_backend.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_backend.mdx new file mode 100644 index 0000000..c06de54 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_backend.mdx @@ -0,0 +1,236 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_backend +title: nemo_rl.models.generation.vllm.vllm_backend +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`VllmInternalWorkerExtension`](#nemo_rl-models-generation-vllm-vllm_backend-VllmInternalWorkerExtension) | - | + +### API + + + + + +```python +class nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension() +``` + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension._maybe_process_fp8_kv_cache() -> None +``` + + + + + + +Process weights after loading for FP8 KV cache (static scales). + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.cleanup() -> None +``` + + + + + + +Shutdown and cleanup resources. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.get_zmq_address() +``` + + + + + + +Get the ZMQ address for the current device. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.init_collective( + rank_prefix: int, + ip: str, + port: int, + world_size: int, + train_world_size: int +) -> None +``` + + + + + + +Initialize the collective communication. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.maybe_init_zmq() +``` + + + + + + +Initialize the ZMQ socket if it doesn't exist. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.prepare_refit_info( + state_dict_info: dict[str, typing.Any] +) -> None +``` + + + + + + +Prepare state dict metadata for weight refitting and IPC streaming. + +**Parameters:** + + +A dictionary containing the info for refit. +e.g. {tensor_name: (shape, dtype)} + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.report_device_id() -> str +``` + + + + + + +Retrieve the UUID of the current CUDA device. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.start_gpu_profiling() -> None +``` + + + + + + +Start GPU profiling. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.stop_gpu_profiling() -> None +``` + + + + + + +Stop GPU profiling. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.update_weights_from_collective() -> bool +``` + + + + + + +Update the model weights from collective communication. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.update_weights_via_ipc_zmq() -> bool +``` + + + + + + +Receive and update model weights via ZMQ IPC socket. + +**Returns:** `bool` + +True if weights were successfully updated. + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_generation.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_generation.mdx new file mode 100644 index 0000000..e4cdee2 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_generation.mdx @@ -0,0 +1,656 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_generation +title: nemo_rl.models.generation.vllm.vllm_generation +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`VllmGeneration`](#nemo_rl-models-generation-vllm-vllm_generation-VllmGeneration) | - | + +### Data + +[`TOP_K_THRESHOLD`](#nemo_rl-models-generation-vllm-vllm_generation-TOP_K_THRESHOLD) + +[`TOP_P_THRESHOLD`](#nemo_rl-models-generation-vllm-vllm_generation-TOP_P_THRESHOLD) + +### API + + + + + +```python +class nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration( + cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster, + config: nemo_rl.models.generation.vllm.config.VllmConfig, + name_prefix: str = 'vllm_policy', + workers_per_node: typing.Optional[typing.Union[int, list[int]]] = None +) +``` + + + + + + +**Bases:** [GenerationInterface](/nemo-rl/nemo_rl/models/generation/interfaces#nemo_rl-models-generation-interfaces-GenerationInterface) + + + + + + + + + + + + + + + + + + + + + + + + + + +Check if KV cache scales should be synchronized during refit. + +Returns True if kv_cache_dtype is fp8/fp8_e4m3. + + + + + + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.__del__() -> None +``` + + + + + + +Shuts down the worker groups when the object is deleted or is garbage collected. + +This is an extra safety net in case the user forgets to call shutdown() and the pointer to +the object is lost due to leaving a function scope. It's always recommended that the +user calls shutdown(). + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration._async_generate_base( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + method_name: str, + data_validation_fn, + greedy: bool = False +) -> typing.AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None] +``` + + + + + + +async + +Base async generation method that handles common worker management logic. + +**Parameters:** + + +Input data for generation + + + +Name of the worker method to call ('generate_async' or 'generate_text_async') + + + +Function to validate input data + + + +Whether to use greedy decoding + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration._get_raw_spec_counters() -> dict[str | tuple[str, int], float] +``` + + + + + + +Collect raw spec decode counters from workers. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration._get_tied_worker_bundle_indices( + cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster +) -> list[tuple[int, list[int]]] +``` + + + + + + +Calculate bundle indices for tensor and pipeline parallel workers. + +Handles both unified placement groups (for cross-node model parallelism) and +per-node placement groups (for node-local model parallelism). + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration._post_init() +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration._report_device_id() -> list[list[str]] +``` + + + + + + +Report the device ID of vllm workers. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration._report_dp_openai_server_base_urls() -> list[typing.Optional[str]] +``` + + + + + + +Report the data parallel OpenAI server base URLs of vLLM workers, only populated if it is async vLLM engine and the HTTP server is active. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.clear_logger_metrics() -> None +``` + + + + + + +Clear logger metrics for performance reporting. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.clear_vllm_logger_metrics() -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.finish_generation( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> bool +``` + + + + + + +Sleep workers and reset prefix cache. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.generate( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] +``` + + + + + + +Generate a batch of data using vLLM. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.generate_async( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> typing.AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None] +``` + + + + + + +async + +Generate responses asynchronously, yielding individual samples as they complete. + +This method provides per-sample streaming across all workers, yielding each +sample result as soon as it's ready, regardless of which worker processed it. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.generate_text( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] +``` + + + + + + +Generate text responses using vLLM. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.generate_text_async( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> typing.AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None] +``` + + + + + + +async + +Generate text responses asynchronously, yielding results as they are ready. + +**Parameters:** + + +BatchedDataDict containing prompts with text strings + + + +Whether to use greedy decoding instead of sampling + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.get_logger_metrics() -> dict[str, typing.Any] +``` + + + + + + +Get logger metrics for performance reporting. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.get_step_metrics() -> dict[str, float] +``` + + + + + + +Get speculative decoding metrics delta since snapshot_step_metrics(). + +**Returns:** `dict[str, float]` + +Dictionary of delta metrics with 'vllm/' prefix. + +**Raises:** + +- `RuntimeWarning`: If called without snapshot_step_metrics() first. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.get_vllm_logger_metrics() -> dict[str, typing.Any] +``` + + + + + + +Collect vLLM logger metrics from vLLM workers (model-owner actors only). + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.init_collective( + ip: str, + port: int, + world_size: int, + train_world_size: int +) -> list[ray.ObjectRef] +``` + + + + + + +Initialize the collective communication. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.invalidate_kv_cache() -> bool +``` + + + + + + +Invalidate reusable caches in vLLM (e.g., prefix/KV cache) after weight updates. + +For async_engine, calls reset_prefix_cache_async on workers. For sync, calls reset_prefix_cache. +Returns True if all workers report success. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.prepare_for_generation( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> bool +``` + + + + + + +Wake workers up for colocated inference. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.prepare_refit_info( + state_dict_info: dict[str, typing.Any] +) -> None +``` + + + + + + +Prepare the info for refit. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.shutdown() -> bool +``` + + + + + + +Shut down all vLLM workers and clean up resources. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.snapshot_step_metrics() -> None +``` + + + + + + +Snapshot current spec decode counters to begin tracking a training step. + +Call this before generation to establish a baseline for metrics delta. + +**Raises:** + +- `RuntimeWarning`: If called twice without get_step_metrics() in between. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.start_gpu_profiling() -> None +``` + + + + + + +Start GPU profiling. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.stop_gpu_profiling() -> None +``` + + + + + + +Stop GPU profiling. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.update_weights_from_collective() -> list[ray.ObjectRef] +``` + + + + + + +Update weights of the policy using collective communication. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.update_weights_via_ipc_zmq() -> list[ray.ObjectRef] +``` + + + + + + +Update weights of the policy using IPC handles via ZMQ socket. + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.TOP_K_THRESHOLD = 8000 +``` + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.TOP_P_THRESHOLD = 0.99 +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker.mdx new file mode 100644 index 0000000..080c9c1 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker.mdx @@ -0,0 +1,545 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker +title: nemo_rl.models.generation.vllm.vllm_worker +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseVllmGenerationWorker`](#nemo_rl-models-generation-vllm-vllm_worker-BaseVllmGenerationWorker) | - | +| [`VllmGenerationWorker`](#nemo_rl-models-generation-vllm-vllm_worker-VllmGenerationWorker) | - | + +### API + + + + + +```python +class nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker( + config: nemo_rl.models.generation.vllm.config.VllmConfig, + bundle_indices: typing.Optional[list[int]] = None, + fraction_of_gpus: float = 1.0, + seed: typing.Optional[int] = None +) +``` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker.__repr__() -> str +``` + + + + + + +Customizes the actor's prefix in the Ray logs. + +This makes it easier to identify which worker is producing specific log messages. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker._build_sampling_params( + greedy: bool, + stop_strings, + max_new_tokens: typing.Optional[int] = None +) +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker._get_raw_spec_counters() -> dict[str, float | list[float]] +``` + + + + + + +Get speculative decoding metrics from the vLLM engine. + +Collects spec decode counters including number of drafts, +draft tokens, and accepted tokens for monitoring acceptance rates. + +**Returns:** `dict[str, float | list[float]]` + +Dictionary mapping metric names to their values. + +**Raises:** + +- `AssertionError`: If called before vLLM engine is initialized. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker._merge_stop_strings( + batch_stop_strings +) +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker.configure_worker( + num_gpus: int | float, + bundle_indices: typing.Optional[tuple[int, list[int]]] = None +) -> tuple[dict[str, typing.Any], dict[str, str], dict[str, typing.Any]] +``` + + + + + + +staticmethod + +Provides complete worker configuration for vLLM tensor and pipeline parallelism. + +This method configures the worker based on its role in tensor and pipeline parallelism, +which is determined directly from the bundle_indices parameter. + +**Parameters:** + + +Original GPU allocation for this worker based on the placement group + + + +Tuple of (node_idx, local_bundle_indices) for parallelism (if applicable) + + +**Returns:** `tuple[dict[str, Any], dict[str, str], dict[str, Any]]` + +tuple with complete worker configuration: +- 'resources': Resource allocation (e.g., num_gpus) +- 'env_vars': Environment variables for this worker +- 'init_kwargs': Parameters to pass to __init__ of the worker + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker.is_alive() +``` + + + + + + +Check if the worker is alive. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker.llm() +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker.start_gpu_profiling() -> None +``` + + + + + + +Start GPU profiling. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker.stop_gpu_profiling() -> None +``` + + + + + + +Stop GPU profiling. + + + + + + + + + +```python +class nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker() +``` + + + + + + +**Bases:** [BaseVllmGenerationWorker](#nemo_rl-models-generation-vllm-vllm_worker-BaseVllmGenerationWorker) + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker._create_engine( + llm_kwargs: dict[str, typing.Any] +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.generate( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] +``` + + + + + + +Generate a batch of data using vLLM generation. + +**Parameters:** + + +BatchedDataDict containing input_ids and input_lengths tensors + + + +Whether to use greedy decoding instead of sampling + + +**Returns:** `BatchedDataDict[GenerationOutputSpec]` + +BatchedDataDict conforming to GenerationOutputSpec: +- output_ids: input + generated token IDs with proper padding +- logprobs: Log probabilities for tokens +- generation_lengths: Lengths of each response +- unpadded_sequence_lengths: Lengths of each input + generated sequence + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.generate_text( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] +``` + + + + + + +Generate text responses using vLLM generation. + +**Parameters:** + + +BatchedDataDict containing prompts with text strings + + + +Whether to use greedy decoding instead of sampling + + +**Returns:** `BatchedDataDict[GenerationOutputSpec]` + +BatchedDataDict containing: +- texts: List of generated text responses + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.init_collective( + rank_prefix: int, + ip: str, + port: int, + world_size: int, + train_world_size: int +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.post_init() +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.prepare_refit_info( + state_dict_info: dict[str, typing.Any] +) -> None +``` + + + + + + +Prepare the info for refit. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.report_device_id() -> list[str] +``` + + + + + + +Report device ID from the vLLM worker. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.reset_prefix_cache() +``` + + + + + + +Reset the prefix cache of vLLM engine. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.shutdown() -> bool +``` + + + + + + +Clean up vLLM resources. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.sleep() +``` + + + + + + +Put the vLLM engine to sleep. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.update_weights_from_collective() -> bool +``` + + + + + + +Update the model weights from collective communication. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.update_weights_via_ipc_zmq() -> bool +``` + + + + + + +Update weights from IPC handles via ZMQ socket. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.wake_up( + kwargs = {} +) +``` + + + + + + +Wake up the vLLM engine. + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async.mdx new file mode 100644 index 0000000..b9b731a --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async.mdx @@ -0,0 +1,485 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async +title: nemo_rl.models.generation.vllm.vllm_worker_async +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`VllmAsyncGenerationWorker`](#nemo_rl-models-generation-vllm-vllm_worker_async-VllmAsyncGenerationWorker) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_replace_prefix_tokens`](#nemo_rl-models-generation-vllm-vllm_worker_async-_replace_prefix_tokens) | This is a subroutine used inside the vLLM Chat Completion server. | + +### API + + + + + +```python +class nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker() +``` + + + + + + +**Bases:** [BaseVllmGenerationWorker](/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker#nemo_rl-models-generation-vllm-vllm_worker-BaseVllmGenerationWorker) + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker._create_engine( + llm_kwargs: dict[str, typing.Any] +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker._setup_vllm_openai_api_server( + app: fastapi.FastAPI +) -> fastapi.FastAPI +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker._setup_vllm_server() -> tuple[threading.Thread, str, uvicorn.Server] +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker._start_vllm_metrics_logger() -> None +``` + + + + + + +Start a background thread that periodically collects vLLM logger metrics. + +Controlled by vllm_metrics_logger_interval (default: 0.5) in vllm_cfg. +Runs only on the model-owner actor. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.clear_vllm_logger_metrics() -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.generate_async( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> typing.AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None] +``` + + + + + + +async + +Generate a batch of data using vLLM's AsyncLLMEngine, yielding results as they are ready. + +**Parameters:** + + +BatchedDataDict with input_ids and input_lengths + + + +Whether to use greedy decoding instead of sampling + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.generate_text_async( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> typing.AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None] +``` + + + + + + +async + +Generate text responses asynchronously, yielding results as they are ready. + +**Parameters:** + + +BatchedDataDict containing prompts with text strings + + + +Whether to use greedy decoding instead of sampling + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.get_vllm_logger_metrics() -> dict[str, typing.Any] +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.init_collective_async( + rank_prefix: int, + ip: str, + port: int, + world_size: int, + train_world_size: int +) -> None +``` + + + + + + +async + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.post_init_async() +``` + + + + + + +async + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.prepare_refit_info_async( + state_dict_info: dict[str, typing.Any] +) -> None +``` + + + + + + +async + +Async version of prepare_refit_info. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.report_device_id_async() -> list[str] +``` + + + + + + +async + +Async version of report_device_id. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.report_dp_openai_server_base_url() -> typing.Optional[str] +``` + + + + + + +async + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.reset_prefix_cache_async() +``` + + + + + + +async + +Async version of reset_prefix_cache. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.shutdown() -> bool +``` + + + + + + +async + +Clean up vLLM resources. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.sleep_async() +``` + + + + + + +async + +Async version of sleep. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.update_weights_from_collective_async() -> bool +``` + + + + + + +async + +Async version of update_weights_from_collective. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.update_weights_via_ipc_zmq_async() -> bool +``` + + + + + + +async + +Async version of update_weights_via_ipc_zmq. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.wake_up_async( + kwargs = {} +) +``` + + + + + + +async + +Async version of wake_up. + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async._replace_prefix_tokens( + tokenizer, + model_prefix_token_ids: list[int], + template_prefix_token_ids: list[int], + template_token_ids: list[int] +) -> list[int] +``` + + + + + + +This is a subroutine used inside the vLLM Chat Completion server. + +This function is for fixing up the chat template-tokenized messages history +to match the model output tokenization up to the last assistant turn, +in order to preserve the monotonic tokens property for optimized multi-turn +training. + +Some environments (namely NeMo-Gym) require an OpenAI compatible server +endpoint rather than an inference engine handle. This is fine for the most +part, but it may cause issues when the environment is used as a part of +training. + +RL training frameworks train models on token IDs, but the OpenAI compatible +server communicates in what is basically de-tokenized text. When multiple +model calls are made to the OpenAI compatible server in a single trajectory, +model generations in previous model calls may be re-tokenized to something +that is different than what was generated. This is not too big of an issue +(that we know of) at inference time, but the log probs the model produces +are different enough for the differently re-tokenized generation result that +it causes the training to be off policy. Off policy isn't necessarily a bad +thing in isolation, but this source of off-policyness may cause unexpected +issues if not properly accounted for. It also mis-aligns the token ID +sequences across model calls, which feels very strange during training. + +There are real cases where the model output string _does not match_ the chat +template tokenization of the parsed model output. A concrete example is +inconsistent whitespace tokens around tool call special tokens. + +TODO When NeMo RL supports training image generation models, we want to +revisit and possibly update this function. This issue occurs when the model +generates tokens that are de-tokenized into text or images, and then +re-tokenized into tokens. So if there is a situation like that with images +and image tokenization is non-unique, then we will need to uppdate this +function. + +Example (turn-by-turn, concise; eos_token_id = 2): + Turn 1: + - prefill_T1 (template prefill) = [11,12,13,40,41] + - model output = [220,17,2] # decodes to " 4" + EOS + - model_prefix_token_ids = prefill_T1 + model output + => [11,12,13,40,41,220,17,2] + + Turn 2 (template retokenizes prior assistant text differently): + - template_prefix_token_ids = [11,12,13,40,41,1001,2] # 1001 decodes to " 4" + - template_token_ids = [11,12,13,40,41,1001,2,21,22,40,41] + + _replace_prefix_tokens keeps the exact prior model tokens up to EOS and + resumes from the template after that EOS: + output => [11,12,13,40,41,220,17,2,21,22,40,41] + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface.mdx new file mode 100644 index 0000000..6095398 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface.mdx @@ -0,0 +1,9 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/huggingface +title: nemo_rl.models.huggingface +--- + +## Submodules + +- **[`nemo_rl.models.huggingface.common`](/nemo-rl/nemo_rl/models/huggingface/common)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface/common.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface/common.mdx new file mode 100644 index 0000000..f626968 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface/common.mdx @@ -0,0 +1,303 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/huggingface/common +title: nemo_rl.models.huggingface.common +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`FlashAttentionKwargs`](#nemo_rl-models-huggingface-common-FlashAttentionKwargs) | Dataclass to hold FlashAttention v2 kwargs. | +| [`ModelFlag`](#nemo_rl-models-huggingface-common-ModelFlag) | Enum that defines special flags for model-specific behaviors. | + +### Functions + +| Name | Description | +|------|-------------| +| [`get_flash_attention_kwargs`](#nemo_rl-models-huggingface-common-get_flash_attention_kwargs) | Returns kwargs required for FlashAttention v2 forward functions. | +| [`group_and_cat_tensors`](#nemo_rl-models-huggingface-common-group_and_cat_tensors) | Groups and concatenates tensors according to group_sizes, then pads them to form a 2D tensor. | +| [`is_gemma_model`](#nemo_rl-models-huggingface-common-is_gemma_model) | - | +| [`pack_sequences`](#nemo_rl-models-huggingface-common-pack_sequences) | Packs sequences into rows where each row concatenates multiple sequences. | +| [`unpack_tensor`](#nemo_rl-models-huggingface-common-unpack_tensor) | Unpacks a packed tensor into individual sequences padded to the same length. | + +### Data + +[`Tensor`](#nemo_rl-models-huggingface-common-Tensor) + +### API + + + + + +```python +class nemo_rl.models.huggingface.common.FlashAttentionKwargs( + cu_seqlens_q: nemo_rl.models.huggingface.common.Tensor, + cu_seqlens_k: nemo_rl.models.huggingface.common.Tensor, + max_seqlen_q: int, + max_seqlen_k: int +) +``` + + + + + + +Dataclass + +Dataclass to hold FlashAttention v2 kwargs. + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.huggingface.common.ModelFlag +``` + + + + + + +**Bases:** `enum.Enum` + +Enum that defines special flags for model-specific behaviors. + +This enum provides a way to identify models that require special handling or +configuration in different parts of the NeMo RL codebase. + +Each flag has a `matches` method that determines if the flag applies to a given model_name. + + + + + + + + + + +```python +nemo_rl.models.huggingface.common.get_flash_attention_kwargs( + input_lengths: torch.Tensor +) -> nemo_rl.models.huggingface.common.FlashAttentionKwargs +``` + + + + + + +Returns kwargs required for FlashAttention v2 forward functions. + +**Parameters:** + + +[batch_size] containing lengths of each sequence + + +**Returns:** `FlashAttentionKwargs` + +Dict[str, torch.Tensor | int]: +{ + "cu_seqlens_q": Tensor[int32], + "cu_seqlens_k": Tensor[int32], + "max_seqlen_q": int, + "max_seqlen_k": int +} + + + + + + + + +```python +nemo_rl.models.huggingface.common.group_and_cat_tensors( + tensors: list[torch.Tensor], + group_sizes: list[int], + padding_value: int = 0, + min_seq_len: int = 0 +) -> torch.Tensor +``` + + + + + + +Groups and concatenates tensors according to group_sizes, then pads them to form a 2D tensor. + +Each group of 1D tensors is concatenated into a single 1D tensor, and all resulting +group tensors are padded to the same length and stacked into a 2D tensor. + +**Parameters:** + + +List of 1D tensors of varying lengths. + + + +List of integers. Each integer specifies how many tensors to group. + + + +Integer used to pad shorter sequences. + + + +Minimum sequence length. + + +**Returns:** `torch.Tensor` + +A 2D tensor where each row is a padded concatenation of the grouped tensors. + + + + + + + + +```python +nemo_rl.models.huggingface.common.is_gemma_model( + model_name: str +) -> bool +``` + + + + + + + + + + + + + +```python +nemo_rl.models.huggingface.common.pack_sequences( + input_ids: torch.Tensor, + input_lengths: torch.Tensor, + packed_sequence_size: list[int], + padding_value: int = 0, + return_attention_mask: bool = True, + min_seq_len: int = 0 +) -> typing.Tuple[torch.Tensor, torch.Tensor, typing.Optional[torch.Tensor]] +``` + + + + + + +Packs sequences into rows where each row concatenates multiple sequences. + +Useful for sequence packing in transformer models (e.g. for SFT training). Returns: +packed input_ids, packed position_ids, and optional attention_mask. + +**Parameters:** + + +Tensor of shape [num_sequences, max_seq_len] + + + +Tensor of shape [num_sequences], containing true lengths + + + +How many sequences to pack per row + + + +Pad value for input_ids + + + +Whether to return per-row causal attention mask + + + +Minimum sequence length. + + +**Returns:** `Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]` + + +input_ids_packed (torch.Tensor): [batch_size, max_packed_seq_len] +position_ids_packed (torch.Tensor): [batch_size, max_packed_seq_len] +attention_mask (Optional[torch.Tensor]): [batch_size, max_len, max_len] if requested + + + + + + + + +```python +nemo_rl.models.huggingface.common.unpack_tensor( + tensor, + input_lengths +) +``` + + + + + + +Unpacks a packed tensor into individual sequences padded to the same length. + +**Parameters:** + + +Packed tensor of shape [batch_size, packed_seq_len]. + + + +Original sequence lengths in the order they were packed. + + +**Returns:** + +torch.Tensor: [num_sequences, max_seq_len], each row is one unpacked and padded sequence. + + + + + + + + +```python +nemo_rl.models.huggingface.common.Tensor = TypeVar('Tensor', bound=(torch.Tensor)) +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron.mdx new file mode 100644 index 0000000..c37e90e --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron.mdx @@ -0,0 +1,13 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/megatron +title: nemo_rl.models.megatron +--- + +## Submodules + +- **[`nemo_rl.models.megatron.common`](/nemo-rl/nemo_rl/models/megatron/common)** +- **[`nemo_rl.models.megatron.community_import`](/nemo-rl/nemo_rl/models/megatron/community_import)** +- **[`nemo_rl.models.megatron.config`](/nemo-rl/nemo_rl/models/megatron/config)** +- **[`nemo_rl.models.megatron.data`](/nemo-rl/nemo_rl/models/megatron/data)** +- **[`nemo_rl.models.megatron.setup`](/nemo-rl/nemo_rl/models/megatron/setup)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/common.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/common.mdx new file mode 100644 index 0000000..de40812 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/common.mdx @@ -0,0 +1,212 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/megatron/common +title: nemo_rl.models.megatron.common +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_round_up_to_multiple`](#nemo_rl-models-megatron-common-_round_up_to_multiple) | - | +| [`broadcast_tensor`](#nemo_rl-models-megatron-common-broadcast_tensor) | Broadcasts a tensor from src_rank to all ranks in the group using broadcast_object_list for metadata. | +| [`forward_step_arbitrary_loss`](#nemo_rl-models-megatron-common-forward_step_arbitrary_loss) | Forward training step with support for packed sequences and context parallelism. | +| [`get_moe_metrics`](#nemo_rl-models-megatron-common-get_moe_metrics) | Returns Mixture of Experts (MoE) auxiliary-loss metrics. | + +### API + + + + + +```python +nemo_rl.models.megatron.common._round_up_to_multiple( + value: int, + multiple: int +) -> int +``` + + + + + + + + + + + + + +```python +nemo_rl.models.megatron.common.broadcast_tensor( + tensor: torch.Tensor | None, + src_rank: int, + group: torch.distributed.ProcessGroup +) -> torch.Tensor +``` + + + + + + +Broadcasts a tensor from src_rank to all ranks in the group using broadcast_object_list for metadata. + +Handles the case where the input tensor might be None on non-source ranks. +If the input tensor is provided on non-source ranks, it must have the +correct shape and dtype matching the tensor on the source rank. + +**Parameters:** + + +The tensor to broadcast on the source rank. Can be None on + non-source ranks (will be created with correct shape/dtype). + If not None on non-source ranks, it's used as the buffer + for the broadcast and must match the source tensor's metadata. + + + +The global rank of the source process. + + + +The process group for communication. + + +**Returns:** `torch.Tensor` + +torch.Tensor: The broadcasted tensor. On non-source ranks, this will + be the tensor received from the source. + +**Raises:** + +- `ValueError`: If the tensor is None on the source rank, or if a tensor + provided on a non-source rank has mismatched shape/dtype/device. +- `TypeError`: If broadcasting metadata fails (e.g., due to pickling issues). + + + + + + + + +```python +nemo_rl.models.megatron.common.forward_step_arbitrary_loss( + state: megatron.bridge.training.state.GlobalState, + global_valid_seqs: torch.Tensor, + global_valid_toks: torch.Tensor, + data_iterator: typing.Iterator[nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any]], + model: megatron.core.models.gpt.GPTModel, + loss_fn: nemo_rl.algorithms.loss_functions.LossFunction, + pack_sequences: bool = False, + defer_fp32_logits: typing.Optional[bool] = None, + cp_normalize: bool = True, + policy_cfg: typing.Optional[dict] = None +) +``` + + + + + + +Forward training step with support for packed sequences and context parallelism. + +Notes on packed sequences with context parallelism (CP): + - When CP > 1, each sequence is padded to a multiple of (cp_size * 2) + - The factor of 2 ensures load balancing for causal attention + - cu_seqlens tracks actual sequence boundaries + - cu_seqlens_padded tracks padded sequence boundaries for CP + - Requires TransformerEngine >= 1.10 for CP support + +**Parameters:** + + +Global state for the run + + + +Global count of valid sequences + + + +Global count of valid tokens + + + +Input data iterator + + + +The GPT Model + + + +Loss function to apply + + + +Whether to pack sequences for efficiency + + + +Whether to skip the conversion of logits to fp32 + + + +Whether to normalize the loss by the cp_size + + + +Policy configuration containing generation parameters + + + + + + + + + +```python +nemo_rl.models.megatron.common.get_moe_metrics( + loss_scale: float, + total_loss_dict: typing.Optional[dict] = None, + per_layer_logging: bool = False +) -> dict[str, typing.Any] +``` + + + + + + +Returns Mixture of Experts (MoE) auxiliary-loss metrics. + +This function reduces MoE auxiliary losses across ranks, aggregates them, and +returns a dictionary of metrics. + +**Parameters:** + + +Scale factor to apply to each auxiliary loss (e.g., 1/num_microbatches). + + + +If provided, accumulate means into this dict (by name). + + + +If True, include per-layer values in the returned dict. + + +**Returns:** `dict[str, Any]` + +dict[str, Any]: A flat dict of aggregated metrics. For each aux loss name, + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/community_import.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/community_import.mdx new file mode 100644 index 0000000..a0f53a4 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/community_import.mdx @@ -0,0 +1,76 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/megatron/community_import +title: nemo_rl.models.megatron.community_import +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`export_model_from_megatron`](#nemo_rl-models-megatron-community_import-export_model_from_megatron) | - | +| [`import_model_from_hf_name`](#nemo_rl-models-megatron-community_import-import_model_from_hf_name) | Import a Hugging Face model into Megatron checkpoint format and save the Megatron checkpoint to the output path. | + +### API + + + + + +```python +nemo_rl.models.megatron.community_import.export_model_from_megatron( + hf_model_name: str, + input_path: str, + output_path: str, + hf_tokenizer_path: str, + overwrite: bool = False, + hf_overrides: typing.Optional[dict[str, typing.Any]] = {} +) +``` + + + + + + + + + + + + + +```python +nemo_rl.models.megatron.community_import.import_model_from_hf_name( + hf_model_name: str, + output_path: str, + megatron_config: typing.Optional[nemo_rl.models.policy.MegatronConfig] = None, + config_overrides: typing.Any = {} +) +``` + + + + + + +Import a Hugging Face model into Megatron checkpoint format and save the Megatron checkpoint to the output path. + +**Parameters:** + + +Hugging Face model ID or local path (e.g., 'meta-llama/Llama-3.1-8B-Instruct'). + + + +Directory to write the Megatron checkpoint (e.g., /tmp/megatron_ckpt). + + + +Optional megatron config with paralellism settings for distributed megatron model import. + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/config.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/config.mdx new file mode 100644 index 0000000..7dda8b0 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/config.mdx @@ -0,0 +1,146 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/megatron/config +title: nemo_rl.models.megatron.config +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MegatronGenerationConfig`](#nemo_rl-models-megatron-config-MegatronGenerationConfig) | - | +| [`ModelAndOptimizerState`](#nemo_rl-models-megatron-config-ModelAndOptimizerState) | Container for model and optimizer state. | +| [`RuntimeConfig`](#nemo_rl-models-megatron-config-RuntimeConfig) | Runtime configuration for model training and inference. | + +### API + + + + + +```python +class nemo_rl.models.megatron.config.MegatronGenerationConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.megatron.config.ModelAndOptimizerState() +``` + + + + + + +**Bases:** `NamedTuple` + +Container for model and optimizer state. + +This named tuple holds all model-related state including the model itself, +optimizer, scheduler, and metadata about the model type and configuration. + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.megatron.config.RuntimeConfig() +``` + + + + + + +**Bases:** `NamedTuple` + +Runtime configuration for model training and inference. + +This contains all validated runtime settings needed for model initialization, +parallelization, and training. + + + + + + + + + + + + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/data.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/data.mdx new file mode 100644 index 0000000..30fa12b --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/data.mdx @@ -0,0 +1,471 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/megatron/data +title: nemo_rl.models.megatron.data +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ProcessedInputs`](#nemo_rl-models-megatron-data-ProcessedInputs) | Processed microbatch inputs used for model forward pass. | +| [`ProcessedMicrobatch`](#nemo_rl-models-megatron-data-ProcessedMicrobatch) | Container for a processed microbatch ready for model forward pass. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_pack_sequence_parameters_for_megatron`](#nemo_rl-models-megatron-data-_get_pack_sequence_parameters_for_megatron) | Get pack sequence parameters for Megatron model processing with optional context parallelism. | +| [`_pack_sequences_for_megatron`](#nemo_rl-models-megatron-data-_pack_sequences_for_megatron) | Pack sequences for Megatron model processing with optional context parallelism. | +| [`_unpack_sequences_from_megatron`](#nemo_rl-models-megatron-data-_unpack_sequences_from_megatron) | Unpack sequences from Megatron output format. | +| [`get_and_validate_seqlen`](#nemo_rl-models-megatron-data-get_and_validate_seqlen) | - | +| [`get_microbatch_iterator`](#nemo_rl-models-megatron-data-get_microbatch_iterator) | Create a processed microbatch iterator from a batch of data. | +| [`make_processed_microbatch_iterator`](#nemo_rl-models-megatron-data-make_processed_microbatch_iterator) | Wrap a raw microbatch iterator to yield processed microbatches. | +| [`process_global_batch`](#nemo_rl-models-megatron-data-process_global_batch) | Process a global batch and compute normalization factors. | +| [`process_microbatch`](#nemo_rl-models-megatron-data-process_microbatch) | Process a microbatch for Megatron model forward pass. | + +### API + + + + + +```python +class nemo_rl.models.megatron.data.ProcessedInputs( + input_ids: torch.Tensor, + input_ids_cp_sharded: torch.Tensor, + attention_mask: typing.Optional[torch.Tensor], + position_ids: typing.Optional[torch.Tensor], + packed_seq_params: typing.Optional[megatron.core.packed_seq_params.PackedSeqParams], + cu_seqlens_padded: typing.Optional[torch.Tensor] +) +``` + + + + + + +Dataclass + +Processed microbatch inputs used for model forward pass. + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.megatron.data.ProcessedMicrobatch( + data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + input_ids: torch.Tensor, + input_ids_cp_sharded: torch.Tensor, + attention_mask: typing.Optional[torch.Tensor], + position_ids: typing.Optional[torch.Tensor], + packed_seq_params: typing.Optional[megatron.core.packed_seq_params.PackedSeqParams], + cu_seqlens_padded: typing.Optional[torch.Tensor] +) +``` + + + + + + +Dataclass + +Container for a processed microbatch ready for model forward pass. + +This dataclass holds both the original data dictionary and the processed +tensors needed for the Megatron model forward pass. + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.models.megatron.data._get_pack_sequence_parameters_for_megatron( + megatron_cfg: dict, + max_seq_len_in_batch: int +) +``` + + + + + + +Get pack sequence parameters for Megatron model processing with optional context parallelism. + +**Parameters:** + + +Megatron configuration + + + +Maximum sequence length in batch + + +**Returns:** + +Tuple of: + + + + + + + + +```python +nemo_rl.models.megatron.data._pack_sequences_for_megatron( + input_ids: torch.Tensor, + seq_lengths: torch.Tensor, + pad_individual_seqs_to_multiple_of: int = 1, + pad_packed_seq_to_multiple_of: int = 1, + pad_packed_seq_to: typing.Optional[int] = None, + cp_rank: int = 0, + cp_size: int = 1 +) -> tuple[torch.Tensor, megatron.core.packed_seq_params.PackedSeqParams, torch.Tensor, typing.Optional[torch.Tensor]] +``` + + + + + + +Pack sequences for Megatron model processing with optional context parallelism. + +**Parameters:** + + +Input token IDs [batch_size, seq_length] + + + +Actual sequence lengths for each sample [batch_size] + + + +Pad individual sequences to a multiple of this value + + + +Pad packed sequences to a multiple of this value + + + +Pad packed sequences to this value (before CP) +- The three parameters above can be calculated using _get_pack_sequence_parameters_for_megatron, we do not recommend users to set these parameters manually. + + + +Context parallelism size + + +**Returns:** `torch.Tensor` + +Tuple of: + + + + + + + + +```python +nemo_rl.models.megatron.data._unpack_sequences_from_megatron( + output_tensor: torch.Tensor, + seq_lengths: torch.Tensor, + cu_seqlens: torch.Tensor, + cu_seqlens_padded: typing.Optional[torch.Tensor], + original_batch_size: int, + original_seq_length: int +) -> torch.Tensor +``` + + + + + + +Unpack sequences from Megatron output format. + +**Parameters:** + + +Packed output tensor [1, T, vocab_size] + + + +Actual sequence lengths for each sample + + + +Cumulative sequence lengths + + + +Padded cumulative sequence lengths (if CP was used) + + + +Original batch size + + + +Original maximum sequence length + + +**Returns:** `torch.Tensor` + +Unpacked output tensor [batch_size, seq_length, vocab_size] + + + + + + + + +```python +nemo_rl.models.megatron.data.get_and_validate_seqlen( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] +) +``` + + + + + + + + + + + + + +```python +nemo_rl.models.megatron.data.get_microbatch_iterator( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + cfg: dict[str, typing.Any], + mbs: int, + straggler_timer: megatron.core.utils.StragglerDetector, + seq_length_key: typing.Optional[str] = None +) -> typing.Tuple[typing.Iterator[nemo_rl.models.megatron.data.ProcessedMicrobatch], int, int, int, int] +``` + + + + + + +Create a processed microbatch iterator from a batch of data. + +This function creates an iterator that yields ProcessedMicrobatch objects, +which contain both the original data dictionary and the processed tensors +ready for model forward pass. + +**Parameters:** + + +The batch data to create microbatches from + + + +Configuration dictionary + + + +Microbatch size + + + +Key for sequence lengths in data dict (auto-detected if None) + + +**Returns:** `Iterator[ProcessedMicrobatch]` + +Tuple containing the iterator and metadata + + + + + + + + +```python +nemo_rl.models.megatron.data.make_processed_microbatch_iterator( + raw_iterator: typing.Iterator[nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any]], + cfg: dict[str, typing.Any], + seq_length_key: typing.Optional[str], + pad_individual_seqs_to_multiple_of: int, + pad_packed_seq_to_multiple_of: int, + straggler_timer: megatron.core.utils.StragglerDetector, + pad_full_seq_to: typing.Optional[int] +) -> typing.Iterator[nemo_rl.models.megatron.data.ProcessedMicrobatch] +``` + + + + + + +Wrap a raw microbatch iterator to yield processed microbatches. + +This function takes a raw iterator that yields BatchedDataDict objects and +wraps it to yield ProcessedMicrobatch objects that contain both the original +data and the processed tensors ready for model forward pass. + +**Parameters:** + + +Iterator yielding raw BatchedDataDict microbatches + + + +Configuration dictionary containing sequence_packing settings + + + +Key for sequence length in data dict (required for packing) + + + +Padding multiple for individual sequences + + + +Padding multiple for packed sequences + + + +Target length for full sequence padding (optional) + + + + + + + + + +```python +nemo_rl.models.megatron.data.process_global_batch( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + dp_group: torch.distributed.ProcessGroup, + batch_idx: int, + batch_size: int +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] +``` + + + + + + +Process a global batch and compute normalization factors. + +**Parameters:** + + +Full dataset + + + +Index of batch to extract + + + +Size of batch to extract + + + +Loss function (used to check loss type) + + + +Data parallel mesh + + +**Returns:** `torch.Tensor` + +Dictionary containing: + + + + + + + + +```python +nemo_rl.models.megatron.data.process_microbatch( + data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + seq_length_key: typing.Optional[str] = None, + pad_individual_seqs_to_multiple_of: int = 1, + pad_packed_seq_to_multiple_of: int = 1, + pad_full_seq_to: typing.Optional[int] = None, + pack_sequences: bool = False, + straggler_timer: megatron.core.utils.StragglerDetector = None +) -> tuple[torch.Tensor, torch.Tensor, typing.Optional[torch.Tensor], typing.Optional[torch.Tensor], typing.Optional[megatron.core.packed_seq_params.PackedSeqParams], typing.Optional[torch.Tensor]] +``` + + + + + + +Process a microbatch for Megatron model forward pass. + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/setup.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/setup.mdx new file mode 100644 index 0000000..485915a --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/setup.mdx @@ -0,0 +1,535 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/megatron/setup +title: nemo_rl.models.megatron.setup +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MoEFloat16Module`](#nemo_rl-models-megatron-setup-MoEFloat16Module) | Float 16 Module with the ability to keep the expert bias in float32. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_apply_moe_config`](#nemo_rl-models-megatron-setup-_apply_moe_config) | Apply Mixture of Experts configuration. | +| [`_apply_parallelism_config`](#nemo_rl-models-megatron-setup-_apply_parallelism_config) | Apply tensor/pipeline/context parallelism configuration. | +| [`_apply_performance_config`](#nemo_rl-models-megatron-setup-_apply_performance_config) | Apply performance optimization configuration. | +| [`_apply_precision_config`](#nemo_rl-models-megatron-setup-_apply_precision_config) | Apply precision and dtype configuration. | +| [`_create_checkpoint_config`](#nemo_rl-models-megatron-setup-_create_checkpoint_config) | Create checkpoint configurations. | +| [`_create_megatron_config`](#nemo_rl-models-megatron-setup-_create_megatron_config) | Create the final Megatron configuration container. | +| [`_validate_chunking_config`](#nemo_rl-models-megatron-setup-_validate_chunking_config) | Validate chunking configuration. | +| [`_validate_dtype_config`](#nemo_rl-models-megatron-setup-_validate_dtype_config) | - | +| [`_validate_optimizer_config`](#nemo_rl-models-megatron-setup-_validate_optimizer_config) | Validate optimizer configuration. | +| [`_validate_training_config`](#nemo_rl-models-megatron-setup-_validate_training_config) | Validate training configuration. | +| [`destroy_parallel_state`](#nemo_rl-models-megatron-setup-destroy_parallel_state) | Safely destroy parallel state and reset async call tracking. | +| [`finalize_megatron_setup`](#nemo_rl-models-megatron-setup-finalize_megatron_setup) | Finalize the setup with remaining configurations. | +| [`handle_model_import`](#nemo_rl-models-megatron-setup-handle_model_import) | Handle HF model import if checkpoint doesn't exist. | +| [`setup_distributed`](#nemo_rl-models-megatron-setup-setup_distributed) | Handle NCCL settings, dtype mapping, and basic config setup. | +| [`setup_model_and_optimizer`](#nemo_rl-models-megatron-setup-setup_model_and_optimizer) | - | +| [`setup_model_config`](#nemo_rl-models-megatron-setup-setup_model_config) | Handle all the model configuration logic. | +| [`setup_reference_model_state`](#nemo_rl-models-megatron-setup-setup_reference_model_state) | Setup the reference model for inference and return its state dict. | +| [`validate_and_set_config`](#nemo_rl-models-megatron-setup-validate_and_set_config) | - | +| [`validate_model_paths`](#nemo_rl-models-megatron-setup-validate_model_paths) | Validate and setup model paths. | + +### Data + +[`HAVE_FSDP2`](#nemo_rl-models-megatron-setup-HAVE_FSDP2) + +[`TokenizerType`](#nemo_rl-models-megatron-setup-TokenizerType) + +### API + + + + + +```python +class nemo_rl.models.megatron.setup.MoEFloat16Module( + config: megatron.core.transformer.transformer_config.TransformerConfig, + module: torch.nn.Module +) +``` + + + + + + +**Bases:** `Float16Module` + +Float 16 Module with the ability to keep the expert bias in float32. + +**Parameters:** + + +The transformer config used to initalize the model + + + + + + + +```python +nemo_rl.models.megatron.setup.MoEFloat16Module.re_enable_float32_expert_bias() -> None +``` + + + + + + +Ensure MoE router expert bias stays in float32 for numerical stability. + +Walks the wrapped module to find MoE routers and invokes the +`_maintain_float32_expert_bias()` helper which recreates or casts the +expert bias tensors to float32 as required by Megatron-LM. + + + + + + + + + +```python +nemo_rl.models.megatron.setup._apply_moe_config( + model_cfg: typing.Any, + config: nemo_rl.models.policy.PolicyConfig +) -> None +``` + + + + + + +Apply Mixture of Experts configuration. + + + + + + + + +```python +nemo_rl.models.megatron.setup._apply_parallelism_config( + model_cfg: typing.Any, + config: nemo_rl.models.policy.PolicyConfig +) -> None +``` + + + + + + +Apply tensor/pipeline/context parallelism configuration. + + + + + + + + +```python +nemo_rl.models.megatron.setup._apply_performance_config( + model_cfg: typing.Any, + config: nemo_rl.models.policy.PolicyConfig +) -> None +``` + + + + + + +Apply performance optimization configuration. + + + + + + + + +```python +nemo_rl.models.megatron.setup._apply_precision_config( + model_cfg: typing.Any, + config: nemo_rl.models.policy.PolicyConfig, + dtype: torch.dtype +) -> None +``` + + + + + + +Apply precision and dtype configuration. + + + + + + + + +```python +nemo_rl.models.megatron.setup._create_checkpoint_config( + pretrained_path: str, + weights_path: typing.Optional[str] +) -> megatron.bridge.training.config.CheckpointConfig +``` + + + + + + +Create checkpoint configurations. + + + + + + + + +```python +nemo_rl.models.megatron.setup._create_megatron_config( + model_cfg: typing.Any, + checkpoint_config: megatron.bridge.training.config.CheckpointConfig, + config: nemo_rl.models.policy.PolicyConfig, + hf_model_name: str, + dtype: torch.dtype +) -> megatron.bridge.training.config.ConfigContainer +``` + + + + + + +Create the final Megatron configuration container. + + + + + + + + +```python +nemo_rl.models.megatron.setup._validate_chunking_config( + config: nemo_rl.models.policy.PolicyConfig +) -> None +``` + + + + + + +Validate chunking configuration. + + + + + + + + +```python +nemo_rl.models.megatron.setup._validate_dtype_config( + dtype: torch.dtype, + model_cfg: typing.Any, + optimizer_cfg: typing.Any +) -> None +``` + + + + + + + + + + + + + +```python +nemo_rl.models.megatron.setup._validate_optimizer_config( + config: nemo_rl.models.policy.PolicyConfig +) -> None +``` + + + + + + +Validate optimizer configuration. + + + + + + + + +```python +nemo_rl.models.megatron.setup._validate_training_config( + config: nemo_rl.models.policy.PolicyConfig, + model_cfg: typing.Any +) -> None +``` + + + + + + +Validate training configuration. + + + + + + + + +```python +nemo_rl.models.megatron.setup.destroy_parallel_state() +``` + + + + + + +Safely destroy parallel state and reset async call tracking. + +This function is called during initialization to clean up temporary distributed +state from model import operations. Resetting async call tracking ensures that +when the main Megatron distributed context is created, all ranks start with +consistent call_idx values for async checkpointing. + + + + + + + + +```python +nemo_rl.models.megatron.setup.finalize_megatron_setup( + config: nemo_rl.models.policy.PolicyConfig, + megatron_cfg: megatron.bridge.training.config.ConfigContainer, + hf_model_name: str, + worker_sharding_annotations: nemo_rl.distributed.named_sharding.NamedSharding, + model, + optimizer +) -> tuple +``` + + + + + + +Finalize the setup with remaining configurations. + +**Returns:** `tuple` + +Tuple of (megatron_tokenizer, megatron_bridge, should_disable_forward_pre_hook, dp_size) + + + + + + + + +```python +nemo_rl.models.megatron.setup.handle_model_import( + config: nemo_rl.models.policy.PolicyConfig, + hf_model_name: str, + pretrained_path: str, + pt_checkpoint_exists: bool +) -> None +``` + + + + + + +Handle HF model import if checkpoint doesn't exist. + + + + + + + + +```python +nemo_rl.models.megatron.setup.setup_distributed() -> None +``` + + + + + + +Handle NCCL settings, dtype mapping, and basic config setup. + + + + + + + + +```python +nemo_rl.models.megatron.setup.setup_model_and_optimizer( + policy_cfg: nemo_rl.models.policy.PolicyConfig, + megatron_cfg: megatron.bridge.training.config.ConfigContainer, + load_optimizer: bool = True, + get_embedding_ranks = None, + get_position_embedding_ranks = None +) +``` + + + + + + + + + + + + + +```python +nemo_rl.models.megatron.setup.setup_model_config( + config: nemo_rl.models.policy.PolicyConfig, + rank, + dtype, + hf_model_name: str, + pretrained_path: str, + weights_path: typing.Optional[str] = None +) -> tuple[megatron.bridge.training.config.ConfigContainer, typing.Any] +``` + + + + + + +Handle all the model configuration logic. + + + + + + + + +```python +nemo_rl.models.megatron.setup.setup_reference_model_state( + config: nemo_rl.models.policy.PolicyConfig, + megatron_cfg: megatron.bridge.training.config.ConfigContainer, + pretrained_path: str +) -> dict +``` + + + + + + +Setup the reference model for inference and return its state dict. + + + + + + + + +```python +nemo_rl.models.megatron.setup.validate_and_set_config( + config, + rank, + hf_model_name, + pretrained_path, + weights_path, + tokenizer +) +``` + + + + + + + + + + + + + +```python +nemo_rl.models.megatron.setup.validate_model_paths( + config: nemo_rl.models.policy.PolicyConfig +) -> tuple[str, str, bool] +``` + + + + + + +Validate and setup model paths. + + + + + + + + +```python +nemo_rl.models.megatron.setup.HAVE_FSDP2 = True +``` + + + + + + + + + +```python +nemo_rl.models.megatron.setup.TokenizerType = TypeVar('TokenizerType', bound=PreTrainedTokenizerBase) +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy.mdx new file mode 100644 index 0000000..1b3ebde --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy.mdx @@ -0,0 +1,948 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/policy +title: nemo_rl.models.policy +--- + +## Subpackages + +- **[`nemo_rl.models.policy.workers`](/nemo-rl/nemo_rl/models/policy/workers)** + +## Submodules + +- **[`nemo_rl.models.policy.interfaces`](/nemo-rl/nemo_rl/models/policy/interfaces)** +- **[`nemo_rl.models.policy.lm_policy`](/nemo-rl/nemo_rl/models/policy/lm_policy)** +- **[`nemo_rl.models.policy.utils`](/nemo-rl/nemo_rl/models/policy/utils)** + +## Package Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AutomodelBackendConfig`](#nemo_rl-models-policy-AutomodelBackendConfig) | Configuration for custom MoE implementation backend in Automodel. | +| [`AutomodelKwargs`](#nemo_rl-models-policy-AutomodelKwargs) | - | +| [`DTensorConfig`](#nemo_rl-models-policy-DTensorConfig) | - | +| [`DTensorConfigDisabled`](#nemo_rl-models-policy-DTensorConfigDisabled) | - | +| [`DynamicBatchingConfig`](#nemo_rl-models-policy-DynamicBatchingConfig) | - | +| [`DynamicBatchingConfigDisabled`](#nemo_rl-models-policy-DynamicBatchingConfigDisabled) | - | +| [`LoRAConfig`](#nemo_rl-models-policy-LoRAConfig) | - | +| [`LoRAConfigDisabled`](#nemo_rl-models-policy-LoRAConfigDisabled) | - | +| [`MegatronConfig`](#nemo_rl-models-policy-MegatronConfig) | - | +| [`MegatronConfigDisabled`](#nemo_rl-models-policy-MegatronConfigDisabled) | - | +| [`MegatronDDPConfig`](#nemo_rl-models-policy-MegatronDDPConfig) | - | +| [`MegatronOptimizerConfig`](#nemo_rl-models-policy-MegatronOptimizerConfig) | - | +| [`MegatronSchedulerConfig`](#nemo_rl-models-policy-MegatronSchedulerConfig) | - | +| [`PolicyConfig`](#nemo_rl-models-policy-PolicyConfig) | - | +| [`PytorchOptimizerConfig`](#nemo_rl-models-policy-PytorchOptimizerConfig) | - | +| [`RewardModelConfig`](#nemo_rl-models-policy-RewardModelConfig) | - | +| [`SequencePackingConfig`](#nemo_rl-models-policy-SequencePackingConfig) | - | +| [`SequencePackingConfigDisabled`](#nemo_rl-models-policy-SequencePackingConfigDisabled) | - | +| [`SinglePytorchMilestonesConfig`](#nemo_rl-models-policy-SinglePytorchMilestonesConfig) | - | +| [`SinglePytorchSchedulerConfig`](#nemo_rl-models-policy-SinglePytorchSchedulerConfig) | - | +| [`TokenizerConfig`](#nemo_rl-models-policy-TokenizerConfig) | - | + +### Data + +[`SchedulerMilestones`](#nemo_rl-models-policy-SchedulerMilestones) + +### API + + + + + +```python +class nemo_rl.models.policy.AutomodelBackendConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Configuration for custom MoE implementation backend in Automodel. + +Used when setting the backend in automodel_kwargs in your config. +Alternatively, pass `force_hf: true` in automodel_kwargs to fall back +to the HuggingFace implementation. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.AutomodelKwargs +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.DTensorConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.DTensorConfigDisabled +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +class nemo_rl.models.policy.DynamicBatchingConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.DynamicBatchingConfigDisabled +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +class nemo_rl.models.policy.LoRAConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.LoRAConfigDisabled +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +class nemo_rl.models.policy.MegatronConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.MegatronConfigDisabled +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +class nemo_rl.models.policy.MegatronDDPConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.MegatronOptimizerConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.MegatronSchedulerConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.PolicyConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.PytorchOptimizerConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.RewardModelConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.SequencePackingConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.SequencePackingConfigDisabled +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +class nemo_rl.models.policy.SinglePytorchMilestonesConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +class nemo_rl.models.policy.SinglePytorchSchedulerConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.TokenizerConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.models.policy.SchedulerMilestones = dict[str, list[int]] +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/interfaces.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/interfaces.mdx new file mode 100644 index 0000000..8cbb649 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/interfaces.mdx @@ -0,0 +1,574 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/policy/interfaces +title: nemo_rl.models.policy.interfaces +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ColocatablePolicyInterface`](#nemo_rl-models-policy-interfaces-ColocatablePolicyInterface) | - | +| [`LogprobOutputSpec`](#nemo_rl-models-policy-interfaces-LogprobOutputSpec) | logprobs: Tensor of log probabilities. | +| [`PolicyInterface`](#nemo_rl-models-policy-interfaces-PolicyInterface) | Abstract base class defining the interface for RL policies. | +| [`ReferenceLogprobOutputSpec`](#nemo_rl-models-policy-interfaces-ReferenceLogprobOutputSpec) | logprobs: Tensor of log probabilities. | +| [`ScoreOutputSpec`](#nemo_rl-models-policy-interfaces-ScoreOutputSpec) | scores: Tensor of scores. | +| [`TopkLogitsOutputSpec`](#nemo_rl-models-policy-interfaces-TopkLogitsOutputSpec) | Per-position top-k logits and corresponding global token indices. | + +### API + + + + + +```python +class nemo_rl.models.policy.interfaces.ColocatablePolicyInterface() +``` + + + + + + +**Bases:** [PolicyInterface](#nemo_rl-models-policy-interfaces-PolicyInterface) + + + + + +```python +nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.broadcast_weights_for_collective( + kv_scales: typing.Optional[dict[str, float]] = None +) -> list[ray.ObjectRef] +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.init_collective( + ip: str, + port: int, + world_size: int, + train_world_size: int +) -> list[ray.ObjectRef] +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.offload_after_refit() -> None +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.offload_before_refit() -> None +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.prepare_for_lp_inference() -> None +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.prepare_refit_info() -> typing.Optional[dict[str, typing.Any]] +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.stream_weights_via_http( + sglang_url_to_gpu_uuids: dict[str, list[str]] +) -> list[ray.ObjectRef] +``` + + + + + + +Stream model weights to SGLang servers via HTTP API. + +**Parameters:** + + +Dict mapping SGLang server URL to list of GPU UUIDs it uses + + + + + + + + +```python +nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.stream_weights_via_ipc_zmq( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> list[ray.ObjectRef] +``` + + + + + + +abstract + + + + + + + + + +```python +class nemo_rl.models.policy.interfaces.LogprobOutputSpec +``` + + + + + + +**Bases:** `typing.TypedDict` + +logprobs: Tensor of log probabilities. + + + + + + + + + + + +```python +class nemo_rl.models.policy.interfaces.PolicyInterface() +``` + + + + + + +Abstract + +Abstract base class defining the interface for RL policies. + + + + + + +```python +nemo_rl.models.policy.interfaces.PolicyInterface.calibrate_qkv_fp8_scales( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + micro_batch_size: typing.Optional[int] = None, + percentile: float = 99.9, + margin: float = 1.05, + include_q: bool = False +) -> dict[str, typing.Any] +``` + + + + + + +abstract + +Calibrate FP8 scales for Q/K/V activations used by KV cache. + +**Parameters:** + + +BatchedDataDict containing input_ids and input_lengths. + + + +Optional override for micro batch size during calibration. + + + +Percentile for per-tensor amax estimation. + + + +Safety margin multiplier applied to amax. + + + +Whether to also compute scale for Q in addition to K/V. + + +**Returns:** `dict[str, Any]` + +Dict with overall configuration and per-layer scales. + + + + + + + +```python +nemo_rl.models.policy.interfaces.PolicyInterface.finish_training( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> None +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.policy.interfaces.PolicyInterface.get_logprobs( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + timer: typing.Optional[nemo_rl.utils.timer.Timer] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec] +``` + + + + + + +abstract + +Get logprobs of actions from observations. + +**Parameters:** + + +BatchedDataDict containing rollouts (tokens) + + +**Returns:** `BatchedDataDict[LogprobOutputSpec]` + +BatchedDataDict containing: +- logprobs: Tensor of logprobs of actions + + + + + + + +```python +nemo_rl.models.policy.interfaces.PolicyInterface.get_reference_policy_logprobs( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + micro_batch_size: typing.Optional[int] = None, + timer: typing.Optional[nemo_rl.utils.timer.Timer] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec] +``` + + + + + + +abstract + +Get logprobs of actions from observations. + +**Parameters:** + + +BatchedDataDict containing rollouts (tokens) + + +**Returns:** `BatchedDataDict[ReferenceLogprobOutputSpec]` + +BatchedDataDict containing: +- logprobs: Tensor of logprobs of actions + + + + + + + +```python +nemo_rl.models.policy.interfaces.PolicyInterface.get_topk_logits( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + k: int, + micro_batch_size: typing.Optional[int] = None, + timer: typing.Optional[nemo_rl.utils.timer.Timer] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.TopkLogitsOutputSpec] +``` + + + + + + +abstract + +Get per-position top-k logits and global indices for a batch of inputs. + + + + + + + +```python +nemo_rl.models.policy.interfaces.PolicyInterface.prepare_for_training( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> None +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.policy.interfaces.PolicyInterface.save_checkpoint( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> None +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.policy.interfaces.PolicyInterface.shutdown() -> bool +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.policy.interfaces.PolicyInterface.train( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict, + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + eval_mode: bool = False, + gbs: typing.Optional[int] = None, + mbs: typing.Optional[int] = None, + timer: typing.Optional[nemo_rl.utils.timer.Timer] = None +) -> dict[str, typing.Any] +``` + + + + + + +abstract + +Train the policy on a global batch of data. + +**Parameters:** + + +BatchedDataDict containing rollouts (tokens) + + + +Loss function to use for training + + + +Whether to run in evaluation mode (no gradient updates) + + + +Global batch size override (if None, uses config default) + + + +Micro batch size override (if None, uses config default) + + + + + + + + + + +```python +class nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec +``` + + + + + + +**Bases:** `typing.TypedDict` + +logprobs: Tensor of log probabilities. + + + + + + + + + + + +```python +class nemo_rl.models.policy.interfaces.ScoreOutputSpec +``` + + + + + + +**Bases:** `typing.TypedDict` + +scores: Tensor of scores. + + + + + + + + + + + +```python +class nemo_rl.models.policy.interfaces.TopkLogitsOutputSpec +``` + + + + + + +**Bases:** `typing.TypedDict` + +Per-position top-k logits and corresponding global token indices. + + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/lm_policy.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/lm_policy.mdx new file mode 100644 index 0000000..7636f66 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/lm_policy.mdx @@ -0,0 +1,609 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/policy/lm_policy +title: nemo_rl.models.policy.lm_policy +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`Policy`](#nemo_rl-models-policy-lm_policy-Policy) | - | + +### Data + +[`PathLike`](#nemo_rl-models-policy-lm_policy-PathLike) + +### API + + + + + +```python +class nemo_rl.models.policy.lm_policy.Policy( + cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster, + config: nemo_rl.models.policy.PolicyConfig, + tokenizer: transformers.PreTrainedTokenizerBase, + name_prefix: str = 'lm_policy', + workers_per_node: typing.Optional[typing.Union[int, list[int]]] = None, + init_optimizer: bool = True, + weights_path: typing.Optional[nemo_rl.models.policy.lm_policy.PathLike] = None, + optimizer_path: typing.Optional[nemo_rl.models.policy.lm_policy.PathLike] = None, + init_reference_model: bool = True, + processor: typing.Optional[transformers.AutoProcessor] = None +) +``` + + + + + + +**Bases:** [ColocatablePolicyInterface](/nemo-rl/nemo_rl/models/policy/interfaces#nemo_rl-models-policy-interfaces-ColocatablePolicyInterface), [GenerationInterface](/nemo-rl/nemo_rl/models/generation/interfaces#nemo_rl-models-generation-interfaces-GenerationInterface) + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.__del__() -> None +``` + + + + + + +Shuts down the worker groups when the object is deleted or is garbage collected. + +This is an extra safety net in case the user forgets to call worker_group.shutdown() and the pointer to +the object is lost due to leaving a function scope. It's always recommended that the +user calls worker_group.shutdown(). + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.broadcast_weights_for_collective( + kv_scales: typing.Optional[dict[str, float]] = None +) -> list[ray.ObjectRef] +``` + + + + + + +Broadcast the weights for collective communication. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.calibrate_qkv_fp8_scales( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + micro_batch_size: typing.Optional[int] = None, + percentile: float = 99.9, + margin: float = 1.05, + include_q: bool = False +) -> dict[str, typing.Any] +``` + + + + + + +Trigger KV-cache FP8 scale calibration across Megatron workers and return results. + +Note: The backend `MegatronPolicyWorker.calibrate_qkv_fp8_scales` already implements +distributed reduction, returning results merged across ranks. Therefore, we shard the +input by DP and call in parallel, then take the result from the first worker. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.finish_generation( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> bool +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.finish_training( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.generate( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] +``` + + + + + + +Generate a batch of data using the policy. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.get_free_memory_bytes() -> int +``` + + + + + + +Get the available free memory. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.get_logprobs( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + timer: typing.Optional[nemo_rl.utils.timer.Timer] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec] +``` + + + + + + +Get the logprobs of the model for a data dict. + +**Returns:** `BatchedDataDict[LogprobOutputSpec]` + +a BatchedDataDict with key "logprobs" and shape [batch_size, sequence_length]. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.get_reference_policy_logprobs( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + micro_batch_size: typing.Optional[int] = None, + timer: typing.Optional[nemo_rl.utils.timer.Timer] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec] +``` + + + + + + +Get the logprobs of the reference policy for a data dict. + +Returns: Identical to get_logprobs. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.get_topk_logits( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + k: int, + micro_batch_size: typing.Optional[int] = None, + timer: typing.Optional[nemo_rl.utils.timer.Timer] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.TopkLogitsOutputSpec] +``` + + + + + + +Dispatch get_topk_logits to workers (no CP/packed support initially). + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.init_collective( + ip: str, + port: int, + world_size: int, + train_world_size: int +) -> list[ray.ObjectRef] +``` + + + + + + +Initialize the collective communication. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.invalidate_kv_cache( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> bool +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.offload_after_refit() -> None +``` + + + + + + +Offload the optimizer and buffers to the CPU. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.offload_before_refit() -> None +``` + + + + + + +Offload the optimizer and buffers to the CPU. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.prepare_for_generation( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> bool +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.prepare_for_lp_inference( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.prepare_for_training( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.prepare_refit_info() -> typing.Optional[dict[str, typing.Any]] +``` + + + + + + +Prepare the info for refit. + +**Returns:** `Optional[dict[str, Any]]` + +A dictionary containing the info for refit. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.print_node_ip_and_gpu_id() -> list[tuple[str, int]] +``` + + + + + + +Print the node IP and GPU ID of the current worker. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.save_checkpoint( + weights_path: str, + optimizer_path: typing.Optional[str] = None, + tokenizer_path: typing.Optional[str] = None, + checkpointing_cfg: typing.Optional[nemo_rl.utils.checkpoint.CheckpointingConfig] = None +) -> None +``` + + + + + + +Save a checkpoint of the model. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.score( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec] +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ScoreOutputSpec] +``` + + + + + + +Score a batch of data using the policy. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.shutdown() -> bool +``` + + + + + + +Shut down all HF workers and clean up resources. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.start_gpu_profiling() -> None +``` + + + + + + +Start GPU profiling. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.stop_gpu_profiling() -> None +``` + + + + + + +Stop GPU profiling. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.stream_weights_via_http( + sglang_url_to_gpu_uuids: dict[str, list[str]] +) -> list[ray.ObjectRef] +``` + + + + + + +Send the weights to SGLang servers via HTTP API. + +**Parameters:** + + +Dict mapping SGLang server URL to list of GPU UUIDs it uses + + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.stream_weights_via_ipc_zmq( + buffer_size_bytes: int, + kv_scales: typing.Optional[dict[str, float]] = None +) -> list[ray.ObjectRef] +``` + + + + + + +Send the weights for IPC handles via ZMQ socket. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.train( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + eval_mode: bool = False, + gbs: typing.Optional[int] = None, + mbs: typing.Optional[int] = None, + timer: typing.Optional[nemo_rl.utils.timer.Timer] = None +) -> dict[str, typing.Any] +``` + + + + + + +Train the policy on a batch of data with a given loss function. + + + + + + + + + +```python +nemo_rl.models.policy.lm_policy.PathLike = Union[str, 'os.PathLike[Any]'] +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/utils.mdx new file mode 100644 index 0000000..63b8675 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/utils.mdx @@ -0,0 +1,624 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/policy/utils +title: nemo_rl.models.policy.utils +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`IPCProtocol`](#nemo_rl-models-policy-utils-IPCProtocol) | IPC protocol constants for ZMQ weight streaming. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_gather_ipc_handlers`](#nemo_rl-models-policy-utils-_gather_ipc_handlers) | Gather IPC handlers from all ranks in the default FSDP group, then filter by server. | +| [`_send_tensor_to_sglang`](#nemo_rl-models-policy-utils-_send_tensor_to_sglang) | Send gathered IPC handlers to SGLang server via HTTP. | +| [`_setup_ipc_gather_group`](#nemo_rl-models-policy-utils-_setup_ipc_gather_group) | Setup gather configuration for IPC handlers. | +| [`apply_top_k_only`](#nemo_rl-models-policy-utils-apply_top_k_only) | Apply top-k mask to the logits. | +| [`apply_top_k_top_p`](#nemo_rl-models-policy-utils-apply_top_k_top_p) | Apply top-k and top-p masks to the logits. | +| [`calculate_aligned_size`](#nemo_rl-models-policy-utils-calculate_aligned_size) | Calculate aligned size for memory alignment. | +| [`configure_dynamo_cache`](#nemo_rl-models-policy-utils-configure_dynamo_cache) | Disable dynamo autotune_local_cache. | +| [`get_gpu_info`](#nemo_rl-models-policy-utils-get_gpu_info) | Return information about the GPU being used by this worker. | +| [`get_handle_from_tensor`](#nemo_rl-models-policy-utils-get_handle_from_tensor) | Get IPC handle from a tensor. | +| [`get_megatron_checkpoint_dir`](#nemo_rl-models-policy-utils-get_megatron_checkpoint_dir) | Gets the default megatron checkpoint directory for initial HF -> Mcore conversion. | +| [`get_runtime_env_for_policy_worker`](#nemo_rl-models-policy-utils-get_runtime_env_for_policy_worker) | Get runtime environment configuration for policy workers. | +| [`is_vllm_v1_engine_enabled`](#nemo_rl-models-policy-utils-is_vllm_v1_engine_enabled) | Check if vLLM V1 engine is enabled. | +| [`rebuild_cuda_tensor_from_ipc`](#nemo_rl-models-policy-utils-rebuild_cuda_tensor_from_ipc) | Rebuild a CUDA tensor from an IPC handle. | +| [`resolve_model_class`](#nemo_rl-models-policy-utils-resolve_model_class) | Resolve the appropriate model class for a given model name. | +| [`stream_weights_via_http_impl`](#nemo_rl-models-policy-utils-stream_weights_via_http_impl) | Stream weights to SGLang servers via HTTP API (update_weights_from_tensor). | +| [`stream_weights_via_ipc_zmq_impl`](#nemo_rl-models-policy-utils-stream_weights_via_ipc_zmq_impl) | Shared implementation for streaming weights via IPC ZMQ with improved memory management. | + +### Data + +[`AUTOMODEL_FACTORY`](#nemo_rl-models-policy-utils-AUTOMODEL_FACTORY) + +[`NEMO_AUTOMODEL_AVAILABLE`](#nemo_rl-models-policy-utils-NEMO_AUTOMODEL_AVAILABLE) + +### API + + + + + +```python +class nemo_rl.models.policy.utils.IPCProtocol +``` + + + + + + +**Bases:** `enum.Enum` + +IPC protocol constants for ZMQ weight streaming. + + + + + + + + + + + + + +```python +nemo_rl.models.policy.utils._gather_ipc_handlers( + serialized_handler: str, + gather_group: typing.Optional[torch.distributed.ProcessGroup], + gather_src: typing.Optional[int], + rank: int, + matching_ranks: typing.Optional[list[int]] = None +) -> typing.Optional[list[str]] +``` + + + + + + +Gather IPC handlers from all ranks in the default FSDP group, then filter by server. + +**Parameters:** + + +Serialized IPC handler from this rank + + + +Process group (None means use default FSDP group) + + + +Rank that will collect and filter handlers + + + +Current rank + + + +List of ranks that belong to the same SGLang server + + +**Returns:** `Optional[list[str]]` + +List of serialized handlers in rank order (only on gather_src rank), None otherwise + + + + + + + + +```python +nemo_rl.models.policy.utils._send_tensor_to_sglang( + url: str, + tensor_name: str, + gathered_handlers: list[str], + shape: torch.Size, + dtype: str, + flush_cache: bool = False +) -> None +``` + + + + + + +Send gathered IPC handlers to SGLang server via HTTP. + +Key: gathered_handlers are in rank order [rank0, rank1, ...] +SGLang will automatically match: handler = serialized_handlers[tp_rank] + +**Parameters:** + + +SGLang server URL + + + +Name of the tensor + + + +List of serialized IPC handlers in rank order + + + +Tensor shape + + + +Tensor dtype + + + +Whether to flush cache after this tensor (for last tensor) + + + + + + + + + +```python +nemo_rl.models.policy.utils._setup_ipc_gather_group( + rank: int, + current_device_uuid: str, + sglang_gpu_uuids: list[str], + sglang_url_to_gpu_uuids: dict[str, list[str]] +) -> tuple[typing.Optional[torch.distributed.ProcessGroup], typing.Optional[int], typing.Optional[list[int]]] +``` + + + + + + +Setup gather configuration for IPC handlers. + +**Returns:** `Optional[dist.ProcessGroup]` + +Tuple of (gather_group, gather_src_rank, matching_ranks) + + + + + + + + +```python +nemo_rl.models.policy.utils.apply_top_k_only( + logits: torch.Tensor, + top_k: int +) -> torch.Tensor +``` + + + + + + +Apply top-k mask to the logits. + +Simplified version of VLLM's implementation for scalar parameters. +This implementation doesn't involve sorting the entire vocab. + +Based on VLLM's implementation: +https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py +SPDX-License-Identifier: Apache-2.0 +Copyright contributors to the vLLM project + +**Parameters:** + + +Input logits tensor of shape [batch_size, seq_len, vocab_size] + + + +Top-k sampling parameter. + + +**Returns:** `torch.Tensor` + +Filtered logits with top-k applied + + + + + + + + +```python +nemo_rl.models.policy.utils.apply_top_k_top_p( + logits: torch.Tensor, + top_k: typing.Optional[int] = None, + top_p: typing.Optional[float] = None +) -> torch.Tensor +``` + + + + + + +Apply top-k and top-p masks to the logits. + +Simplified version of VLLM's implementation for scalar parameters. + +Based on VLLM's implementation: +https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py +SPDX-License-Identifier: Apache-2.0 +Copyright contributors to the vLLM project + +**Parameters:** + + +Input logits tensor of shape [batch_size, seq_len, vocab_size] + + + +Top-k sampling parameter. Set to -1 to consider all tokens. + + + +Top-p (nucleus) sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens. + + +**Returns:** `torch.Tensor` + +Filtered logits with sampling parameters applied + + + + + + + + +```python +nemo_rl.models.policy.utils.calculate_aligned_size( + size_bytes: int, + alignment: int = 512 +) -> int +``` + + + + + + +Calculate aligned size for memory alignment. + +**Parameters:** + + +Size in bytes to align + + + +Alignment boundary in bytes (default 512) + + +**Returns:** `int` + +Aligned size in bytes(int). + + + + + + + + +```python +nemo_rl.models.policy.utils.configure_dynamo_cache() -> None +``` + + + + + + +Disable dynamo autotune_local_cache. + +Dynamo may fail at cached_autotune when there's already a cache with different order of node_bundles. +Disable autotune_local_cache as a workaround. +See https://github.com/pytorch/pytorch/issues/153791 for more details. + + + + + + + + +```python +nemo_rl.models.policy.utils.get_gpu_info( + model: torch.nn.Module +) -> dict[str, typing.Any] +``` + + + + + + +Return information about the GPU being used by this worker. + + + + + + + + +```python +nemo_rl.models.policy.utils.get_handle_from_tensor( + tensor: torch.Tensor +) -> tuple[typing.Any] +``` + + + + + + +Get IPC handle from a tensor. + + + + + + + + +```python +nemo_rl.models.policy.utils.get_megatron_checkpoint_dir() -> str +``` + + + + + + +Gets the default megatron checkpoint directory for initial HF -> Mcore conversion. + +Megatron initial checkpoint should be saved to a path available on all nodes. The directory used will take this order of precendence: +1. $NRL_MEGATRON_CHECKPOINT_DIR (if set) +2. $HF_HOME/nemo_rl (if HF_HOME is set) +3. ~/.cache/huggingface/nemo_rl + +HF_HOME is preferred since many users will also have that path mounted and it means one less directory +to mount into your runtime environment. + + + + + + + + +```python +nemo_rl.models.policy.utils.get_runtime_env_for_policy_worker( + policy_worker_name: str +) -> dict[str, typing.Any] +``` + + + + + + +Get runtime environment configuration for policy workers. + +Note: expandable_segments configuration is handled directly in the worker init methods +to ensure proper GPU detection after CUDA initialization. + + + + + + + + +```python +nemo_rl.models.policy.utils.is_vllm_v1_engine_enabled() -> bool +``` + + + + + + +Check if vLLM V1 engine is enabled. + +**Returns:** `bool` + +True if V1 engine is enabled, False otherwise (defaults to True if not set) + + + + + + + + +```python +nemo_rl.models.policy.utils.rebuild_cuda_tensor_from_ipc( + cuda_ipc_handle: tuple, + device_id: int +) -> torch.Tensor +``` + + + + + + +Rebuild a CUDA tensor from an IPC handle. + + + + + + + + +```python +nemo_rl.models.policy.utils.resolve_model_class( + model_name: str +) -> typing.Any +``` + + + + + + +Resolve the appropriate model class for a given model name. + + + + + + + + +```python +nemo_rl.models.policy.utils.stream_weights_via_http_impl( + params_generator, + sglang_url_to_gpu_uuids: dict[str, list[str]], + rank: int, + worker_name: str, + current_device_uuid: str +) -> None +``` + + + + + + +Stream weights to SGLang servers via HTTP API (update_weights_from_tensor). + +Flow: Each rank creates IPC handler → gather handlers in rank order → send list → SGLang matches by tp_rank index + +Key points: +- Each rank creates handler on its own GPU +- Handlers are gathered in rank order: [rank0_handler, rank1_handler, ...] +- List index = rank = GPU ID +- SGLang automatically matches: handler = serialized_handlers[tp_rank] + +**Parameters:** + + +Generator yielding (name, tensor) pairs + + + +Dict mapping SGLang server URL to list of GPU UUIDs it uses + + + +Worker rank for logging + + + +Name of the worker for logging + + + +UUID of the current training worker's GPU + + + + + + + + + +```python +nemo_rl.models.policy.utils.stream_weights_via_ipc_zmq_impl( + params_generator, + buffer_size_bytes: int, + zmq_socket, + rank: int, + worker_name: str +) -> None +``` + + + + + + +Shared implementation for streaming weights via IPC ZMQ with improved memory management. + +Uses ping-pong double buffering to enable overlapping communication while reusing buffers +to reduce memory allocation overhead and improve stability. + +**Parameters:** + + +Generator yielding (name, tensor) pairs + + + +total size of buffer in bytes for batching parameters + + + +ZMQ socket for communication + + + +Worker rank for logging + + + +Name of the worker for logging + + + + + + + + + +```python +nemo_rl.models.policy.utils.AUTOMODEL_FACTORY: Dict[str, Any] = {'qwen2_5_vl': AutoModelForImageTextToText, 'qwen2_vl': AutoModelForImageTextToT... +``` + + + + + + + + + +```python +nemo_rl.models.policy.utils.NEMO_AUTOMODEL_AVAILABLE = True +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers.mdx new file mode 100644 index 0000000..3becc39 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers.mdx @@ -0,0 +1,13 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/policy/workers +title: nemo_rl.models.policy.workers +--- + +## Submodules + +- **[`nemo_rl.models.policy.workers.base_policy_worker`](/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker)** +- **[`nemo_rl.models.policy.workers.dtensor_policy_worker`](/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker)** +- **[`nemo_rl.models.policy.workers.dtensor_policy_worker_v2`](/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2)** +- **[`nemo_rl.models.policy.workers.megatron_policy_worker`](/nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker)** +- **[`nemo_rl.models.policy.workers.patches`](/nemo-rl/nemo_rl/models/policy/workers/patches)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker.mdx new file mode 100644 index 0000000..0983a40 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker.mdx @@ -0,0 +1,309 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/policy/workers/base_policy_worker +title: nemo_rl.models.policy.workers.base_policy_worker +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AbstractPolicyWorker`](#nemo_rl-models-policy-workers-base_policy_worker-AbstractPolicyWorker) | Base class for policy workers with shared functionality. | + +### API + + + + + +```python +class nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker() +``` + + + + + + +Base class for policy workers with shared functionality. + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.finish_training( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.get_free_memory_bytes() -> int +``` + + + + + + +Get the available free memory. + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.get_gpu_info() -> dict[str, typing.Any] +``` + + + + + + +Return information about the GPU being used by this worker. + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.get_reference_policy_logprobs( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + micro_batch_size: typing.Optional[int] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec] +``` + + + + + + +Get the logprobs from the reference policy for a batch of data. + +If micro_batch_size is provided, it will be used instead of the configured +logprob_batch_size. + +**Returns:** `BatchedDataDict[ReferenceLogprobOutputSpec]` + +a BatchedDataDict with key "reference_logprobs" and shape [batch_size, sequence_length]. + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.get_zmq_address() -> str +``` + + + + + + +Get the ZMQ address for the current device. + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.init_collective( + ip: str, + port: int, + world_size: int, + train_world_size: int +) -> None +``` + + + + + + +Initialize the collective communication. + +**Parameters:** + + +IP address for the process group + + + +Port for the process group + + + +Total world size (train_world_size + inference_world_size) + + + +Number of training workers (used in inference cluster) + + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.is_alive() -> bool +``` + + + + + + +Check if the worker is alive. + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.maybe_init_zmq() -> None +``` + + + + + + +Initialize the ZMQ socket if it doesn't exist. + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.report_device_id() -> str +``` + + + + + + +Report the UUID of the current CUDA device using NVML. + +**Returns:** `str` + +UUID of the device in the format "GPU-xxxxx" + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.report_node_ip_and_gpu_id() -> tuple[str, int] +``` + + + + + + +Report the node IP and GPU ID of the current worker. + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.reset_peak_memory_stats() -> None +``` + + + + + + +Reset peak memory statistics. + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.shutdown() -> bool +``` + + + + + + +Shutdown the policy. + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.start_gpu_profiling() -> None +``` + + + + + + +Start GPU profiling. + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.stop_gpu_profiling() -> None +``` + + + + + + +Stop GPU profiling. + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker.mdx new file mode 100644 index 0000000..6fb84a0 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker.mdx @@ -0,0 +1,693 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker +title: nemo_rl.models.policy.workers.dtensor_policy_worker +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`DTensorPolicyWorker`](#nemo_rl-models-policy-workers-dtensor_policy_worker-DTensorPolicyWorker) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`get_cpu_state_dict`](#nemo_rl-models-policy-workers-dtensor_policy_worker-get_cpu_state_dict) | Copy the state dict generator to CPU memory. | +| [`unshard_fsdp2_model`](#nemo_rl-models-policy-workers-dtensor_policy_worker-unshard_fsdp2_model) | Explicitly unshard and then reshard the FSDP2 modules. Useful for logprob inference. | + +### API + + + + + +```python +class nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker( + config: nemo_rl.models.policy.PolicyConfig, + tokenizer: transformers.AutoTokenizer, + processor: typing.Optional[transformers.AutoProcessor] = None, + weights_path: typing.Optional[str] = None, + optimizer_path: typing.Optional[str] = None, + init_optimizer: bool = True, + init_reference_model: bool = True, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [AbstractPolicyWorker](/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker#nemo_rl-models-policy-workers-base_policy_worker-AbstractPolicyWorker), [ColocatablePolicyInterface](/nemo-rl/nemo_rl/models/policy/interfaces#nemo_rl-models-policy-interfaces-ColocatablePolicyInterface) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.__repr__() -> str +``` + + + + + + +Customizes the actor's prefix in the Ray logs. + +This makes it easier to identify which worker is producing specific log messages. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker._add_noise_to_weights() -> None +``` + + + + + + +Add small Gaussian noise to the weights of the model. Note that this is used for testing purposes only. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker._apply_temperature_scaling( + logits: torch.Tensor +) -> torch.Tensor +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.broadcast_weights_for_collective( + kv_scales: typing.Optional[dict[str, float]] = None +) -> None +``` + + + + + + +Broadcast the weights for collective communication. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.calibrate_qkv_fp8_scales( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + micro_batch_size: typing.Optional[int] = None, + percentile: float = 99.9, + margin: float = 1.05, + include_q: bool = False +) -> dict[str, typing.Any] +``` + + + + + + +Placeholder for FP8 Q/K/V scale calibration, not implemented for DTensorPolicyWorker. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.create_context_parallel_ctx( + cp_mesh: torch.distributed.device_mesh.DeviceMesh, + cp_buffers: list[torch.Tensor], + cp_seq_dims: list[int], + cp_no_restore_buffers: typing.Set[torch.Tensor], + cp_rotate_method: typing.Optional[str] = None +) +``` + + + + + + +staticmethod + +Create a context parallel context. + +**Parameters:** + + +The device mesh for context parallel. + + + +The buffers for context parallel. + + + +The sequence dimensions for context parallel. + + + +The no restore buffers for context parallel. + + + +The rotation method for context parallel, such as "allgather" or "addtoall". + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.get_logprobs( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + micro_batch_size: typing.Optional[int] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec] +``` + + + + + + +Get the logprobs of the model for a batch of data. + +Uses the configured logprob_batch_size to do microbatching. + +Input data is assumed to be right-padded. The method internally converts to +left-padded format for computation, and returns outputs in right-padded format. + +**Returns:** `BatchedDataDict[LogprobOutputSpec]` + +a BatchedDataDict with key "logprobs" and shape [batch_size, sequence_length]. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.get_topk_logits( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + k: int, + micro_batch_size: typing.Optional[int] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] +``` + + + + + + +Return per-position top-k logits and corresponding global indices. + +Notes: +- Return shapes are [B, S, k]. +- Computes top-k over the full sequence (no trimming of the last position). +- If alignment with next-token targets is required, the caller should handle it. +- If logits are TP-sharded DTensor, performs distributed global top-k across TP. +- Supports context parallelism with proper CP gather. +- Otherwise, computes local top-k on full-vocab tensor. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.load_checkpoint( + weights_path: str, + optimizer_path: typing.Optional[str] = None +) -> None +``` + + + + + + +Load a checkpoint into the model. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.move_buffer_to_device( + model: torch.nn.Module, + device: str | torch.device +) -> torch.nn.Module +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.move_optimizer_to_device( + device: str | torch.device +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.move_to_cpu( + model: torch.nn.Module +) -> torch.nn.Module +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.move_to_cuda( + model: torch.nn.Module +) -> torch.nn.Module +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.move_to_device( + model: torch.nn.Module, + device: str | torch.device +) -> torch.nn.Module +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.offload_after_refit() -> None +``` + + + + + + +Offload as much as possible on the CPU. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.offload_before_refit() -> None +``` + + + + + + +Offload the optimizer to the CPU. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.prepare_for_lp_inference() -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.prepare_for_training( + args = (), + kwargs = {} +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.prepare_refit_info() -> typing.Optional[dict[str, typing.Any]] +``` + + + + + + +Prepare state dict metadata for weight refitting and IPC streaming. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.return_model_config() -> dict[str, typing.Any] +``` + + + + + + +Return the model configuration as a dictionary. + +**Returns:** `dict[str, Any]` + +Model configuration dictionary + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.return_state_dict() +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.save_checkpoint( + weights_path: str, + optimizer_path: typing.Optional[str] = None, + tokenizer_path: typing.Optional[str] = None +) -> None +``` + + + + + + +Save a checkpoint of the model. + +the optimizer states are saved only if `optimizer` and `optimizer_path` are provided. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.score( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ScoreOutputSpec] +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.stream_weights_via_ipc_zmq( + buffer_size_bytes: int = 0, + kv_scales: typing.Optional[dict[str, float]] = None +) -> None +``` + + + + + + +Stream model weights to peer process via ZMQ IPC socket. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.train( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + eval_mode: bool = False, + gbs: typing.Optional[int] = None, + mbs: typing.Optional[int] = None +) -> dict[str, typing.Any] +``` + + + + + + +Train the policy on a batch of data with a given loss function. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.train_context( + cp_context: typing.Optional[typing.Generator[None, None, None]] = None +) +``` + + + + + + +staticmethod + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.use_reference_model() -> typing.Generator[None, None, None] +``` + + + + + + +Context manager that temporarily swaps the reference model and active model. + +On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references +On exit: Restores original references and re-flips cuda/cpu + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.get_cpu_state_dict( + state_generator: typing.Iterable[tuple[str, typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]]], + pin_memory: bool = False +) -> dict[str, torch.Tensor] +``` + + + + + + +Copy the state dict generator to CPU memory. + +**Parameters:** + + + +An iterable that yields (key, tensor) pairs from a model state. + + + + +Whether to allocate the CPU tensors in pinned memory for faster GPU transfer. +Defaults to False. + + +**Returns:** `dict[str, torch.Tensor]` + +dict[str, torch.Tensor]: A dictionary mapping parameter names to CPU tensors. + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.unshard_fsdp2_model( + model: torch.nn.Module +) -> typing.Generator[None, None, None] +``` + + + + + + +Explicitly unshard and then reshard the FSDP2 modules. Useful for logprob inference. + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.mdx new file mode 100644 index 0000000..b6ff0e4 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.mdx @@ -0,0 +1,714 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2 +title: nemo_rl.models.policy.workers.dtensor_policy_worker_v2 +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`DTensorPolicyWorkerV2`](#nemo_rl-models-policy-workers-dtensor_policy_worker_v2-DTensorPolicyWorkerV2) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_maybe_adapt_tensor_to_hf`](#nemo_rl-models-policy-workers-dtensor_policy_worker_v2-_maybe_adapt_tensor_to_hf) | - | +| [`_maybe_merge_lora_weight`](#nemo_rl-models-policy-workers-dtensor_policy_worker_v2-_maybe_merge_lora_weight) | - | +| [`dtensor_params_generator`](#nemo_rl-models-policy-workers-dtensor_policy_worker_v2-dtensor_params_generator) | Generator that yields (name, tensor) pairs, converting DTensors to local tensors and adapting to HF format. | +| [`get_train_context`](#nemo_rl-models-policy-workers-dtensor_policy_worker_v2-get_train_context) | Create combined context manager for training with context parallel and autocast. | + +### API + + + + + +```python +class nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2( + config: nemo_rl.models.policy.PolicyConfig, + tokenizer: transformers.AutoTokenizer, + processor: typing.Optional[transformers.AutoProcessor] = None, + weights_path: typing.Optional[str] = None, + optimizer_path: typing.Optional[str] = None, + init_optimizer: bool = True, + init_reference_model: bool = True, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [AbstractPolicyWorker](/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker#nemo_rl-models-policy-workers-base_policy_worker-AbstractPolicyWorker), [ColocatablePolicyInterface](/nemo-rl/nemo_rl/models/policy/interfaces#nemo_rl-models-policy-interfaces-ColocatablePolicyInterface) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.__repr__() -> str +``` + + + + + + +Customizes the actor's prefix in the Ray logs. + +This makes it easier to identify which worker is producing specific log messages. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2._add_noise_to_weights() -> None +``` + + + + + + +Add small Gaussian noise to the weights of the model. Note that this is used for testing purposes only. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2._init_checkpoint_manager( + config_updates: typing.Optional[dict[str, typing.Any]] = None, + checkpoint_root: typing.Optional[str] = None +) -> None +``` + + + + + + +Initialize the AutomodelCheckpointManager for this worker. + +This creates the checkpoint manager bound to this worker's device meshes +and initializes its underlying checkpointer. + +**Parameters:** + + +Dict of CheckpointingConfig fields to set during initialization. + + + +Optional root directory for checkpoints. + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.broadcast_weights_for_collective( + kv_scales: typing.Optional[dict[str, float]] = None +) -> None +``` + + + + + + +Broadcast the weights for collective communication. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.calibrate_qkv_fp8_scales( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + micro_batch_size: typing.Optional[int] = None, + percentile: float = 99.9, + margin: float = 1.05, + include_q: bool = False +) -> dict[str, typing.Any] +``` + + + + + + +Placeholder for FP8 Q/K/V scale calibration, not implemented for DTensorPolicyWorkerV2. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.get_logprobs( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + micro_batch_size: typing.Optional[int] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec] +``` + + + + + + +Get the logprobs of the model for a batch of data. + +Uses the configured logprob_batch_size to do microbatching. + +Input data is assumed to be right-padded. The method internally converts to +left-padded format for computation, and returns outputs in right-padded format. + +**Returns:** `BatchedDataDict[LogprobOutputSpec]` + +a BatchedDataDict with key "logprobs" and shape [batch_size, sequence_length]. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.get_topk_logits( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + k: int, + micro_batch_size: typing.Optional[int] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] +``` + + + + + + +Return per-position top-k logits and corresponding global indices. + +Notes: +- Return shapes are [B, S, k]. +- Computes top-k over the full sequence (no trimming of the last position). +- If alignment with next-token targets is required, the caller should handle it. +- If logits are TP-sharded DTensor, performs distributed global top-k across TP. +- Supports context parallelism with proper CP gather. +- Otherwise, computes local top-k on full-vocab tensor. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.load_checkpoint( + weights_path: str, + optimizer_path: typing.Optional[str] = None +) -> None +``` + + + + + + +Load a checkpoint into the model using Automodel Checkpointer. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.move_buffer_to_device( + model: torch.nn.Module, + device: str | torch.device +) -> torch.nn.Module +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.move_optimizer_to_device( + device: str | torch.device +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.move_to_cpu( + model: torch.nn.Module +) -> torch.nn.Module +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.move_to_cuda( + model: torch.nn.Module +) -> torch.nn.Module +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.move_to_device( + model: torch.nn.Module, + device: str | torch.device +) -> torch.nn.Module +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.offload_after_refit() -> None +``` + + + + + + +Offload as much as possible on the CPU. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.offload_before_refit() -> None +``` + + + + + + +Offload the optimizer to the CPU. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.prepare_for_lp_inference() -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.prepare_for_training( + args = (), + kwargs = {} +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.prepare_refit_info() -> typing.Optional[dict[str, typing.Any]] +``` + + + + + + +Prepare state dict metadata for weight refitting and IPC streaming. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.return_model_config() -> dict[str, typing.Any] +``` + + + + + + +Return the model configuration as a dictionary. + +**Returns:** `dict[str, Any]` + +Model configuration dictionary + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.return_state_dict() +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.save_checkpoint( + weights_path: str, + optimizer_path: typing.Optional[str] = None, + tokenizer_path: typing.Optional[str] = None, + checkpointing_cfg: typing.Optional[nemo_rl.utils.checkpoint.CheckpointingConfig] = None +) -> None +``` + + + + + + +Save a checkpoint of the model. + +the optimizer states are saved only if `optimizer` and `optimizer_path` are provided. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.score( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ScoreOutputSpec] +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.stream_weights_via_http( + sglang_url_to_gpu_uuids: dict[str, list[str]] +) -> None +``` + + + + + + +Stream model weights to SGLang servers via HTTP API. + +**Parameters:** + + +Dict mapping SGLang server URL to list of GPU UUIDs it uses + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.stream_weights_via_ipc_zmq( + buffer_size_bytes: int = 0, + kv_scales: typing.Optional[dict[str, float]] = None +) -> None +``` + + + + + + +Stream model weights to peer process via ZMQ IPC socket. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.train( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + eval_mode: bool = False, + gbs: typing.Optional[int] = None, + mbs: typing.Optional[int] = None +) -> dict[str, typing.Any] +``` + + + + + + +Train the policy on a batch of data with a given loss function. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.use_reference_model() -> typing.Generator[None, None, None] +``` + + + + + + +Context manager that temporarily swaps the reference model and active model. + +On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references +On exit: Restores original references and re-flips cuda/cpu + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2._maybe_adapt_tensor_to_hf( + model_part: torch.nn.Module, + fqn: str, + tensor: torch.Tensor, + quantization: bool = False +) -> list[tuple[str, torch.Tensor]] +``` + + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2._maybe_merge_lora_weight( + module_map: dict[str, torch.nn.Module], + fqn: str, + tensor: torch.Tensor +) -> torch.Tensor +``` + + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.dtensor_params_generator( + model: torch.nn.Module, + target_dtype: torch.dtype +) -> typing.Generator[tuple[str, torch.Tensor], None, None] +``` + + + + + + +Generator that yields (name, tensor) pairs, converting DTensors to local tensors and adapting to HF format. + +**Parameters:** + + +The model whose parameters to generate. + + + +The dtype to convert tensors to. + + + +Optional LoRA config for filtering which layers to merge. + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.get_train_context( + cp_size: int, + cp_mesh: typing.Any, + cp_buffers: list, + sequence_dim: int, + dtype: torch.dtype, + autocast_enabled: bool = True +) -> typing.Generator[None, None, None] +``` + + + + + + +Create combined context manager for training with context parallel and autocast. + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker.mdx new file mode 100644 index 0000000..c8803a8 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker.mdx @@ -0,0 +1,682 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker +title: nemo_rl.models.policy.workers.megatron_policy_worker +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MegatronPolicyWorker`](#nemo_rl-models-policy-workers-megatron_policy_worker-MegatronPolicyWorker) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`broadcast_object_across_pp_ranks`](#nemo_rl-models-policy-workers-megatron_policy_worker-broadcast_object_across_pp_ranks) | Broadcast an object across pipeline parallel ranks. | + +### Data + +[`TokenizerType`](#nemo_rl-models-policy-workers-megatron_policy_worker-TokenizerType) + +### API + + + + + +```python +class nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker( + config: nemo_rl.models.policy.PolicyConfig, + tokenizer: nemo_rl.models.policy.workers.megatron_policy_worker.TokenizerType, + weights_path: typing.Optional[str] = None, + optimizer_path: typing.Optional[str] = None, + init_optimizer: bool = True, + init_reference_model: bool = True, + worker_sharding_annotations: nemo_rl.distributed.named_sharding.NamedSharding, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [AbstractPolicyWorker](/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker#nemo_rl-models-policy-workers-base_policy_worker-AbstractPolicyWorker), [ColocatablePolicyInterface](/nemo-rl/nemo_rl/models/policy/interfaces#nemo_rl-models-policy-interfaces-ColocatablePolicyInterface) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.__repr__() +``` + + + + + + +Customizes the actor's prefix in the Ray logs. + +This makes it easier to identify which worker is producing specific log messages. + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker._calculate_refit_param_info() -> list[tuple[str, int]] +``` + + + + + + +Calculate parameter information for refit. + +Each task contains: +- param_name: Local parameter name without module prefixes +- mapping: MegatronParamMapping instance for weight transformation +- pp_rank: Pipeline-parallel rank owning the parameter +- vp_stage: Virtual-pipeline stage index +- megatron_module: Reference to Megatron model/submodule +- param_weight: Target parameter tensor for converted weight + +**Returns:** `list[tuple[str, int]]` + +List of (parameter_name, size_in_bytes) tuples. + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker._iter_params_with_optional_kv_scales( + kv_scales: typing.Optional[dict[str, float]] = None +) -> typing.Iterator[tuple[str, torch.Tensor]] +``` + + + + + + +Yield exported HF parameters and optionally append FP8 KV/Q scale tensors. + +This helper is used by both IPC-based streaming and collective broadcast +so that the logic for adding KV scales stays consistent in one place. + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.broadcast_weights_for_collective( + kv_scales: typing.Optional[dict[str, float]] = None +) -> None +``` + + + + + + +Broadcast the weights for collective communication. + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.calibrate_qkv_fp8_scales( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + micro_batch_size: typing.Optional[int] = None, + percentile: float = 99.9, + margin: float = 1.05, + include_q: bool = False +) -> dict[str, typing.Any] +``` + + + + + + +One-shot calibration of Q/K/V activation scales (for FP8 KV cache). + +- Captures each layer's `query_key_value` output through forward hooks, splits Q/K/V, and computes percentile amax. +- In parallel (DP/TP/PP) environments, first computes local percentiles, then takes max across all ranks for conservativeness. +- By default only returns and saves K/V scales, optionally returns Q. + +**Parameters:** + + +Representative sample batch for calibration, following get_logprobs input conventions. + + + +Micro batch size during calibration; if None, reuses logprob_batch_size. + + + +Percentile for amax (e.g. 99.9). + + + +Margin factor, e.g. 1.05. + + + +If provided, rank0 will save results as JSON. + + + +Whether to also return Q scale (usually only K/V needed). + + +**Returns:** `dict[str, Any]` + +{ "format": "fp8", "percentile": float, "margin": float, +"layers": { layer_name: {"k_scale": float, "v_scale": float[, "q_scale": float] } } } + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.check_tensor_parallel_attributes() -> dict[str, typing.Any] +``` + + + + + + +Check tensor parallel attributes on model parameters. + +**Returns:** `dict[str, Any]` + +Dictionary containing information about tensor parallel parameters: + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.disable_forward_pre_hook( + param_sync = True +) +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.enable_forward_pre_hook() +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.generate( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] +``` + + + + + + +Generate a batch of data using huggingface framework generation. + +Returns: + BatchedDataDict conforming to GenerationOutputSpec: + - output_ids: input + generated token IDs + - logprobs: Log probabilities for each token + - generation_lengths: Lengths of each response + +**Parameters:** + + +BatchedDataDict containing input_ids and input_lengths tensors + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.get_logprobs( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + micro_batch_size: typing.Optional[int] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec] +``` + + + + + + +Get the logprobs of the model for a batch of data. + +Uses the configured logprob_batch_size to do microbatching. +Input data is assumed to be right-padded. The method internally converts to +left-padded format for computation, and returns outputs in right-padded format. +If micro_batch_size is provided, it will be used instead of the configured +logprob_batch_size. + +**Returns:** `BatchedDataDict[LogprobOutputSpec]` + +a BatchedDataDict with key "logprobs" and shape [batch_size, sequence_length]. + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.get_topk_logits( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + k: int, + micro_batch_size: typing.Optional[int] = None +) +``` + + + + + + +Get the top-k logits and indices for a batch of data. + +The major difference from get_logprobs is that we compute top-k logits and indices for each position in the sequence. + +**Returns:** + +BatchedDataDict containing: +- topk_logits: Tensor of top-k logits for each position in the sequence +- topk_indices: Tensor of top-k indices for each position in the sequence + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.load_checkpoint( + weights_path: str, + optimizer_path: typing.Optional[str] = None +) +``` + + + + + + +Load a training checkpoint. + +**Parameters:** + + +The exact directory path from which to load the checkpoint. + + + +If not None, attempts to load optimizer and scheduler states + if self.optimizer and self.scheduler are initialized. + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.move_model( + model: torch.nn.Module, + device: str, + move_params: bool = True, + move_grads: bool = True +) -> torch.nn.Module +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.move_optimizer( + device: str +) +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.offload_after_refit() +``` + + + + + + +Offload as much as possible on the CPU. + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.offload_before_refit() +``` + + + + + + +Offload the optimizer and buffers to the CPU. + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.prepare_for_lp_inference() +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.prepare_for_training( + args = (), + kwargs = {} +) +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.prepare_refit_info() -> None +``` + + + + + + +Prepare state dict metadata for weight refitting and IPC streaming. + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.save_checkpoint( + weights_path: str, + optimizer_path: typing.Optional[str] = None, + kwargs = {} +) +``` + + + + + + +Save a training checkpoint. + +**Parameters:** + + +The specific directory path where the checkpoint will be saved. + + + +If not None, optimizer and scheduler states are saved if they exist. + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.stream_weights_via_ipc_zmq( + buffer_size_bytes: int = 0, + kv_scales: typing.Optional[dict[str, float]] = None +) -> None +``` + + + + + + +Stream model weights to peer process via ZMQ IPC socket. + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.train( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict, + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + eval_mode: bool = False, + gbs: typing.Optional[int] = None, + mbs: typing.Optional[int] = None +) -> dict[str, typing.Any] +``` + + + + + + +Train the policy on a batch of data with a given loss function. + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.use_reference_model() +``` + + + + + + +Context manager that temporarily swaps the reference model and active model. + +On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references +On exit: Restores original references and re-flips cuda/cpu + + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.broadcast_object_across_pp_ranks( + obj +) +``` + + + + + + +Broadcast an object across pipeline parallel ranks. + +This utility function handles broadcasting an object from the rank that owns it +to all other pipeline parallel ranks. If only one rank has the object (non-None), +it will be broadcast to all other ranks. + +**Parameters:** + + +The object to broadcast. Can be None on ranks that don't own it. + + +**Returns:** + +The object on all ranks (either the original or the broadcast copy). + +**Raises:** + +- `ValueError`: If the object doesn't exist on any pipeline parallel rank. + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.TokenizerType = TypeVar('TokenizerType', bound=PreTrainedTokenizerBase) +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/patches.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/patches.mdx new file mode 100644 index 0000000..4250027 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/patches.mdx @@ -0,0 +1,85 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/policy/workers/patches +title: nemo_rl.models.policy.workers.patches +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_transformer_engine_file`](#nemo_rl-models-policy-workers-patches-_get_transformer_engine_file) | Return absolute path to a Transformer Engine file or raise if it cannot be found. | +| [`apply_torch_aten_alias_tensor_patch`](#nemo_rl-models-policy-workers-patches-apply_torch_aten_alias_tensor_patch) | Register a sharding rule for `torch.ops.aten.alias.default`. | +| [`apply_transformer_engine_patch`](#nemo_rl-models-policy-workers-patches-apply_transformer_engine_patch) | Apply patch from https://github.com/NVIDIA/TransformerEngine/pull/2286/files. | + +### API + + + + + +```python +nemo_rl.models.policy.workers.patches._get_transformer_engine_file( + relative_path: str +) -> str +``` + + + + + + +Return absolute path to a Transformer Engine file or raise if it cannot be found. + +The relative_path should be a POSIX-style path under the transformer_engine +package root, e.g. "pytorch/triton/permutation.py". + + + + + + + + +```python +nemo_rl.models.policy.workers.patches.apply_torch_aten_alias_tensor_patch() +``` + + + + + + +Register a sharding rule for `torch.ops.aten.alias.default`. + +Work around 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered' +in PyTorch 2.9. See https://github.com/pytorch/pytorch/pull/166867 for the upstream fix. +We can remove this patch when we upgrade torch to include this fix. + + + + + + + + +```python +nemo_rl.models.policy.workers.patches.apply_transformer_engine_patch() +``` + + + + + + +Apply patch from https://github.com/NVIDIA/TransformerEngine/pull/2286/files. + +This locates the target file via importlib metadata instead of importing +`transformer_engine`, to avoid side effects during initialization. If the +permutation module has already been imported, it will be reloaded so that +the patched source takes effect. + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/package_info.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/package_info.mdx new file mode 100644 index 0000000..2dc77ed --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/package_info.mdx @@ -0,0 +1,235 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/package_info +title: nemo_rl.package_info +--- + +## Module Contents + +### Data + +[`MAJOR`](#nemo_rl-package_info-MAJOR) + +[`MINOR`](#nemo_rl-package_info-MINOR) + +[`PATCH`](#nemo_rl-package_info-PATCH) + +[`PRE_RELEASE`](#nemo_rl-package_info-PRE_RELEASE) + +[`VERSION`](#nemo_rl-package_info-VERSION) + +[`__contact_emails__`](#nemo_rl-package_info-__contact_emails__) + +[`__contact_names__`](#nemo_rl-package_info-__contact_names__) + +[`__description__`](#nemo_rl-package_info-__description__) + +[`__download_url__`](#nemo_rl-package_info-__download_url__) + +[`__homepage__`](#nemo_rl-package_info-__homepage__) + +[`__keywords__`](#nemo_rl-package_info-__keywords__) + +[`__license__`](#nemo_rl-package_info-__license__) + +[`__package_name__`](#nemo_rl-package_info-__package_name__) + +[`__repository_url__`](#nemo_rl-package_info-__repository_url__) + +[`__shortversion__`](#nemo_rl-package_info-__shortversion__) + +[`__version__`](#nemo_rl-package_info-__version__) + +### API + + + + + +```python +nemo_rl.package_info.MAJOR = 0 +``` + + + + + + + + + +```python +nemo_rl.package_info.MINOR = 5 +``` + + + + + + + + + +```python +nemo_rl.package_info.PATCH = 0 +``` + + + + + + + + + +```python +nemo_rl.package_info.PRE_RELEASE = 'rc0' +``` + + + + + + + + + +```python +nemo_rl.package_info.VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE) +``` + + + + + + + + + +```python +nemo_rl.package_info.__contact_emails__ = 'nemo-tookit@nvidia.com' +``` + + + + + + + + + +```python +nemo_rl.package_info.__contact_names__ = 'NVIDIA' +``` + + + + + + + + + +```python +nemo_rl.package_info.__description__ = 'NeMo-RL - a toolkit for model alignment' +``` + + + + + + + + + +```python +nemo_rl.package_info.__download_url__ = 'https://github.com/NVIDIA-NeMo/RL/releases' +``` + + + + + + + + + +```python +nemo_rl.package_info.__homepage__ = 'https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/' +``` + + + + + + + + + +```python +nemo_rl.package_info.__keywords__ = 'deep learning, machine learning, gpu, NLP, NeMo, nvidia, pytorch, torch, langua... +``` + + + + + + + + + +```python +nemo_rl.package_info.__license__ = 'Apache2' +``` + + + + + + + + + +```python +nemo_rl.package_info.__package_name__ = 'nemo_rl' +``` + + + + + + + + + +```python +nemo_rl.package_info.__repository_url__ = 'https://github.com/NVIDIA-NeMo/RL' +``` + + + + + + + + + +```python +nemo_rl.package_info.__shortversion__ = '.'.join(map(str, VERSION[:3])) +``` + + + + + + + + + +```python +nemo_rl.package_info.__version__ = '.'.join(map(str, VERSION[:3])) + ''.join(VERSION[3:]) +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils.mdx new file mode 100644 index 0000000..b7dfc66 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils.mdx @@ -0,0 +1,22 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils +title: nemo_rl.utils +--- + +## Submodules + +- **[`nemo_rl.utils.automodel_checkpoint`](/nemo-rl/nemo_rl/utils/automodel_checkpoint)** +- **[`nemo_rl.utils.checkpoint`](/nemo-rl/nemo_rl/utils/checkpoint)** +- **[`nemo_rl.utils.config`](/nemo-rl/nemo_rl/utils/config)** +- **[`nemo_rl.utils.flops_formulas`](/nemo-rl/nemo_rl/utils/flops_formulas)** +- **[`nemo_rl.utils.flops_tracker`](/nemo-rl/nemo_rl/utils/flops_tracker)** +- **[`nemo_rl.utils.logger`](/nemo-rl/nemo_rl/utils/logger)** +- **[`nemo_rl.utils.memory_tracker`](/nemo-rl/nemo_rl/utils/memory_tracker)** +- **[`nemo_rl.utils.native_checkpoint`](/nemo-rl/nemo_rl/utils/native_checkpoint)** +- **[`nemo_rl.utils.nsys`](/nemo-rl/nemo_rl/utils/nsys)** +- **[`nemo_rl.utils.nvml`](/nemo-rl/nemo_rl/utils/nvml)** +- **[`nemo_rl.utils.packed_tensor`](/nemo-rl/nemo_rl/utils/packed_tensor)** +- **[`nemo_rl.utils.prefetch_venvs`](/nemo-rl/nemo_rl/utils/prefetch_venvs)** +- **[`nemo_rl.utils.timer`](/nemo-rl/nemo_rl/utils/timer)** +- **[`nemo_rl.utils.venvs`](/nemo-rl/nemo_rl/utils/venvs)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/automodel_checkpoint.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/automodel_checkpoint.mdx new file mode 100644 index 0000000..2afdec6 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/automodel_checkpoint.mdx @@ -0,0 +1,436 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/automodel_checkpoint +title: nemo_rl.utils.automodel_checkpoint +--- + +Automodel checkpoint utilities for DTensor policy workers. + +This module provides a wrapper class around the nemo_automodel Checkpointer +for saving and loading model checkpoints in DTensor-based policy workers. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AutomodelCheckpointManager`](#nemo_rl-utils-automodel_checkpoint-AutomodelCheckpointManager) | Manages checkpointing for DTensor-based models using nemo_automodel's Checkpointer. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_infer_checkpoint_root`](#nemo_rl-utils-automodel_checkpoint-_infer_checkpoint_root) | Infer checkpoint root directory from weights path. | +| [`detect_checkpoint_format`](#nemo_rl-utils-automodel_checkpoint-detect_checkpoint_format) | Detect model save format and PEFT status from checkpoint directory. | + +### API + + + + + +```python +class nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager( + dp_mesh: torch.distributed.device_mesh.DeviceMesh, + tp_mesh: torch.distributed.device_mesh.DeviceMesh, + model_state_dict_keys: typing.Optional[list[str]] = None, + moe_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None +) +``` + + + + + + +Manages checkpointing for DTensor-based models using nemo_automodel's Checkpointer. + +This class provides a clean interface for saving and loading model checkpoints, +wrapping the nemo_automodel Checkpointer with configuration management. + + + + + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager._get_dp_rank() -> int +``` + + + + + + +Get the data parallel rank. + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager._get_tp_rank() -> int +``` + + + + + + +Get the tensor parallel rank. + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager._rebuild_checkpointer_addons() -> None +``` + + + + + + +Rebuild the checkpointer's _addons list based on current config. + +The Checkpointer's _addons list is populated during __init__ based on config. +When config changes (e.g., model_save_format or is_peft), we need to rebuild +the addons list to match the new config. + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager.init_checkpointer( + config_updates: typing.Optional[dict[str, typing.Any]] = None, + checkpoint_root: typing.Optional[str] = None +) -> None +``` + + + + + + +Initialize the Automodel Checkpointer if not already created. + +This method creates a new Checkpointer instance with the provided configuration. +If a checkpointer already exists, this method does nothing. + +**Parameters:** + + +Dict of CheckpointingConfig fields to set during initialization. + + + +Optional root directory for checkpoints. + + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager.load_base_model( + model: torch.nn.Module, + model_name: str, + hf_cache_dir: typing.Optional[str] = None, + dequantize_base_checkpoint: bool = False, + peft_init_method: typing.Optional[str] = None +) -> None +``` + + + + + + +Load base model weights using the Automodel Checkpointer. + +This method loads the initial HuggingFace model weights into the parallelized model. + +**Parameters:** + + +The model to load weights into. + + + +Name or path of the model. + + + +Optional HuggingFace cache directory. + + + +Whether to dequantize the base checkpoint. + + +**Raises:** + +- `AssertionError`: If checkpointer has not been initialized. + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager.load_checkpoint( + model: torch.nn.Module, + weights_path: str, + optimizer: typing.Optional[torch.optim.Optimizer] = None, + optimizer_path: typing.Optional[str] = None, + scheduler: typing.Optional[torch.optim.lr_scheduler.LRScheduler] = None +) -> None +``` + + + + + + +Load a checkpoint into the model using Automodel Checkpointer. + +**Parameters:** + + +The model to load weights into. + + + +Path to the checkpoint weights. + + + +Optional optimizer to load state into. + + + +Optional path to optimizer checkpoint. + + + +Optional learning rate scheduler. + + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager.save_checkpoint( + model: torch.nn.Module, + weights_path: str, + optimizer: typing.Optional[torch.optim.Optimizer] = None, + optimizer_path: typing.Optional[str] = None, + scheduler: typing.Optional[torch.optim.lr_scheduler.LRScheduler] = None, + tokenizer: typing.Optional[transformers.AutoTokenizer] = None, + tokenizer_path: typing.Optional[str] = None, + checkpointing_cfg: typing.Optional[nemo_rl.utils.checkpoint.CheckpointingConfig] = None, + lora_enabled: bool = False, + peft_config: typing.Optional[nemo_automodel.components._peft.lora.PeftConfig] = None +) -> None +``` + + + + + + +Save a checkpoint of the model. + +The optimizer states are saved only if `optimizer` and `optimizer_path` are provided. + +**Parameters:** + + +The model to save. + + + +Path to save model weights. + + + +Optional optimizer to save. + + + +Optional path to save optimizer state. + + + +Optional learning rate scheduler. + + + +Optional tokenizer to save with the checkpoint. + + + +Optional path to save tokenizer separately. + + + +Checkpointing configuration. + + + +Whether LoRA is enabled. + + + +Optional PEFT configuration. + + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager.set_model_state_dict_keys( + keys: list[str] +) -> None +``` + + + + + + +Set the model state dict keys for checkpoint validation. + +**Parameters:** + + +List of model state dict keys. + + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager.update_checkpointer_config( + config_updates: typing.Optional[dict[str, typing.Any]] = None, + checkpoint_root: typing.Optional[str] = None +) -> None +``` + + + + + + +Update the configuration of an existing Checkpointer. + +This method updates the mutable config fields on the existing Checkpointer instance. +If no checkpointer exists, this method does nothing. + +Note: Some config changes (like model_save_format) require rebuilding the +checkpointer's internal addons list. This method handles that automatically. + +**Parameters:** + + +Dict of CheckpointingConfig fields to update. + + + +Optional root directory for checkpoints. + + + + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint._infer_checkpoint_root( + weights_path: str +) -> str +``` + + + + + + +Infer checkpoint root directory from weights path. + +When weights_path ends with "…/weights/model", we need the parent of +the weights directory (the checkpoint root), not the weights directory itself. + +**Parameters:** + + +Path to model weights (e.g., "/path/to/policy/weights/model") + + +**Returns:** `str` + +Checkpoint root directory (e.g., "/path/to/policy") + + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint.detect_checkpoint_format( + weights_path: str +) -> tuple[str, bool] +``` + + + + + + +Detect model save format and PEFT status from checkpoint directory. + +**Parameters:** + + +Path to the checkpoint directory (e.g., weights/model) + + +**Returns:** `tuple[str, bool]` + +(model_save_format, is_peft) where: + model_save_format is "torch_save" for DCP or "safetensors" for safetensors + is_peft is True if PEFT/adapter patterns are detected + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/checkpoint.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/checkpoint.mdx new file mode 100644 index 0000000..8c380d8 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/checkpoint.mdx @@ -0,0 +1,411 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/checkpoint +title: nemo_rl.utils.checkpoint +--- + +Checkpoint management utilities for the rl algorithm loop. + +It handles logic at the algorithm level. Each RL Actor is expected to have its +own checkpoint saving function (called by the algorithm loop). + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`CheckpointManager`](#nemo_rl-utils-checkpoint-CheckpointManager) | Manages model checkpoints during training. | +| [`CheckpointingConfig`](#nemo_rl-utils-checkpoint-CheckpointingConfig) | Configuration for checkpoint management. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_load_checkpoint_history`](#nemo_rl-utils-checkpoint-_load_checkpoint_history) | Load the history of checkpoints and their metrics. | + +### Data + +[`PathLike`](#nemo_rl-utils-checkpoint-PathLike) + +### API + + + + + +```python +class nemo_rl.utils.checkpoint.CheckpointManager( + config: nemo_rl.utils.checkpoint.CheckpointingConfig +) +``` + + + + + + +Manages model checkpoints during training. + +This class handles creating checkpoint dirs, saving training info, and +configurations. It also provides utilities for keeping just the top-k checkpoints. +The checkpointing structure looks like this: + + +```python +checkpoint_dir/ + step_0/ + training_info.json + config.yaml + policy.py (up to the algorithm loop to save here) + policy_optimizer.py (up to the algorithm loop to save here) + ... + step_1/ + ... +``` + + + +Attributes: Derived from the CheckpointingConfig. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.utils.checkpoint.CheckpointManager.finalize_checkpoint( + checkpoint_path: nemo_rl.utils.checkpoint.PathLike +) -> None +``` + + + + + + +Complete a checkpoint by moving it from temporary to permanent location. + +If a checkpoint at the target location already exists (i.e when resuming training), +we override the old one. +Also triggers cleanup of old checkpoints based on the keep_top_k setting. + +**Parameters:** + + +Path to the temporary checkpoint directory. + + + + + + + + +```python +nemo_rl.utils.checkpoint.CheckpointManager.get_best_checkpoint_path() -> typing.Optional[str] +``` + + + + + + +Get the path to the best checkpoint based on the metric. + +Returns the path to the checkpoint with the best metric value. If no checkpoints +exist, returns None. If some checkpoints are missing the metric, they are filtered +out with a warning. If no checkpoints have the metric, returns the latest checkpoint. + +**Returns:** `Optional[str]` + +Optional[str]: Path to the best checkpoint, or None if no checkpoints exist. + + + + + + + +```python +nemo_rl.utils.checkpoint.CheckpointManager.get_latest_checkpoint_path() -> typing.Optional[str] +``` + + + + + + +Get the path to the latest checkpoint. + +Returns the path to the checkpoint with the highest step number. + +**Returns:** `Optional[str]` + +Optional[str]: Path to the latest checkpoint, or None if no checkpoints exist. + + + + + + + +```python +nemo_rl.utils.checkpoint.CheckpointManager.init_tmp_checkpoint( + step: int, + training_info: typing.Mapping[str, typing.Any], + run_config: typing.Optional[typing.Mapping[str, typing.Any]] = None +) -> nemo_rl.utils.checkpoint.PathLike +``` + + + + + + +Initialize a temporary checkpoint directory. + +Creates a temporary directory for a new checkpoint and saves training info +and configuration. The directory is named 'tmp_step_{step}' and will be renamed +to 'step_{step}' when the checkpoint is completed. +We do it this way to allow the algorithm loop to save any files it wants to save +in a safe, temporary directory. + +**Parameters:** + + +The training step number. + + + +Dictionary containing training metrics and info. + + + +Optional configuration for the training run. + + +**Returns:** `PathLike` + +Path to the temporary checkpoint directory. + + + + + + + +```python +nemo_rl.utils.checkpoint.CheckpointManager.load_training_info( + checkpoint_path: typing.Optional[nemo_rl.utils.checkpoint.PathLike] = None +) -> typing.Optional[dict[str, typing.Any]] +``` + + + + + + +Load the training info from a checkpoint. + +**Parameters:** + + +Path to the checkpoint. If None, +returns None. + + +**Returns:** `Optional[dict[str, Any]]` + +Optional[dict[str, Any]]: Dictionary containing the training info, or None if +checkpoint_path is None. + + + + + + + +```python +nemo_rl.utils.checkpoint.CheckpointManager.remove_old_checkpoints( + exclude_latest: bool = True +) -> None +``` + + + + + + +Remove checkpoints that are not in the top-k or latest based on the (optional) metric. + +If keep_top_k is set, this method removes all checkpoints except the top-k +best ones. The "best" checkpoints are determined by: +- If a metric is provided: the given metric value and the higher_is_better setting. + When multiple checkpoints have the same metric value, more recent checkpoints + (higher step numbers) are preferred. +- If no metric is provided: the step number. The most recent k checkpoints are kept. + +**Parameters:** + + +Whether to exclude the latest checkpoint from deletion. (may result in K+1 checkpoints) + + + + + + + + + + +```python +class nemo_rl.utils.checkpoint.CheckpointingConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Configuration for checkpoint management. + +Attributes: +enabled (bool): Whether checkpointing is enabled. +checkpoint_dir (PathLike): Directory where checkpoints will be saved. +metric_name (str | None): Name of the metric to use for determining best checkpoints. + Must be of the form "val:<metric_name>" or "train:<metric_name>" to indicate whether + the metric should be taken from the validation or training metrics. +higher_is_better (bool): Whether higher values of the metric indicate better performance. +keep_top_k (Optional[int]): Number of best checkpoints to keep. If None, all checkpoints are kept. +model_save_format (str | None): Format for saving model (v2 allowed values: "torch_save" or "safetensors", v1 allowed values: None). +save_consolidated (bool): Whether to save consolidated checkpoints (for HF compatibility). +model_cache_dir (str): Directory for model cache (for safetensors format). +model_repo_id (str): Repository ID for the model (for safetensors format). +is_peft (bool): Whether the model uses PEFT. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.utils.checkpoint._load_checkpoint_history( + checkpoint_dir: pathlib.Path +) -> list[tuple[int, nemo_rl.utils.checkpoint.PathLike, dict[str, typing.Any]]] +``` + + + + + + +Load the history of checkpoints and their metrics. + +**Parameters:** + + +Directory containing the checkpoints. + + +**Returns:** `list[tuple[int, PathLike, dict[str, Any]]]` + +list[tuple[int, PathLike, dict[str, Any]]]: List of tuples containing +(step_number, checkpoint_path, info) for each checkpoint. + + + + + + + + +```python +nemo_rl.utils.checkpoint.PathLike = Union[str, 'os.PathLike[Any]'] +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/config.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/config.mdx new file mode 100644 index 0000000..ba6ab36 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/config.mdx @@ -0,0 +1,266 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/config +title: nemo_rl.utils.config +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`OverridesError`](#nemo_rl-utils-config-OverridesError) | Custom exception for Hydra override parsing errors. | + +### Functions + +| Name | Description | +|------|-------------| +| [`load_config`](#nemo_rl-utils-config-load_config) | Load a config file with inheritance support and convert it to an OmegaConf object. | +| [`load_config_with_inheritance`](#nemo_rl-utils-config-load_config_with_inheritance) | Load a config file with inheritance support. | +| [`merge_with_override`](#nemo_rl-utils-config-merge_with_override) | Merge configs with support for _override_ marker to completely override sections. | +| [`parse_hydra_overrides`](#nemo_rl-utils-config-parse_hydra_overrides) | Parse and apply Hydra overrides to an OmegaConf config. | +| [`register_omegaconf_resolvers`](#nemo_rl-utils-config-register_omegaconf_resolvers) | Register shared OmegaConf resolvers used in configs. | +| [`resolve_path`](#nemo_rl-utils-config-resolve_path) | Resolve a path relative to the base path. | + +### API + + + + + +```python +class nemo_rl.utils.config.OverridesError() +``` + + + + + + +Exception + +**Bases:** `Exception` + +Custom exception for Hydra override parsing errors. + + + + + + + + +```python +nemo_rl.utils.config.load_config( + config_path: typing.Union[str, pathlib.Path] +) -> omegaconf.DictConfig +``` + + + + + + +Load a config file with inheritance support and convert it to an OmegaConf object. + +The config inheritance system supports: + +1. Single inheritance: + ```python + # child.yaml + defaults: parent.yaml + common: + value: 43 + ``` + +2. Multiple inheritance: + ```python + # child.yaml + defaults: + - parent1.yaml + - parent2.yaml + common: + value: 44 + ``` + +3. Nested inheritance: + ```python + # parent.yaml + defaults: grandparent.yaml + common: + value: 43 + + # child.yaml + defaults: parent.yaml + common: + value: 44 + ``` + +4. Variable interpolation: + ```python + # parent.yaml + base_value: 42 + derived: + value: ${base_value} + + # child.yaml + defaults: parent.yaml + base_value: 43 # This will update both base_value and derived.value + ``` + +The system handles: +- Relative and absolute paths +- Multiple inheritance +- Nested inheritance +- Variable interpolation + +The inheritance is resolved depth-first, with later configs overriding earlier ones. +This means in multiple inheritance, the last config in the list takes precedence. + +**Parameters:** + + +Path to the config file + + +**Returns:** `DictConfig` + +Merged config dictionary + + + + + + + + +```python +nemo_rl.utils.config.load_config_with_inheritance( + config_path: typing.Union[str, pathlib.Path], + base_dir: typing.Optional[typing.Union[str, pathlib.Path]] = None +) -> omegaconf.DictConfig +``` + + + + + + +Load a config file with inheritance support. + +**Parameters:** + + +Path to the config file + + + +Base directory for resolving relative paths. If None, uses config_path's directory + + +**Returns:** `DictConfig` + +Merged config dictionary + + + + + + + + +```python +nemo_rl.utils.config.merge_with_override( + base_config: omegaconf.DictConfig, + override_config: omegaconf.DictConfig +) -> omegaconf.DictConfig +``` + + + + + + +Merge configs with support for _override_ marker to completely override sections. + + + + + + + + +```python +nemo_rl.utils.config.parse_hydra_overrides( + cfg: omegaconf.DictConfig, + overrides: list[str] +) -> omegaconf.DictConfig +``` + + + + + + +Parse and apply Hydra overrides to an OmegaConf config. + +**Parameters:** + + +OmegaConf config to apply overrides to + + + +List of Hydra override strings + + +**Returns:** `DictConfig` + +Updated config with overrides applied + +**Raises:** + +- `OverridesError`: If there's an error parsing or applying overrides + + + + + + + + +```python +nemo_rl.utils.config.register_omegaconf_resolvers() -> None +``` + + + + + + +Register shared OmegaConf resolvers used in configs. + + + + + + + + +```python +nemo_rl.utils.config.resolve_path( + base_path: pathlib.Path, + path: str +) -> pathlib.Path +``` + + + + + + +Resolve a path relative to the base path. + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_formulas.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_formulas.mdx new file mode 100644 index 0000000..a0ee8f2 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_formulas.mdx @@ -0,0 +1,501 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/flops_formulas +title: nemo_rl.utils.flops_formulas +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`FLOPSConfig`](#nemo_rl-utils-flops_formulas-FLOPSConfig) | Contains the model hparams needed for FLOPS computations. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_hybrid_model_flops`](#nemo_rl-utils-flops_formulas-_hybrid_model_flops) | Model FLOPs for hybrid model. | +| [`_mamba_layer_flops`](#nemo_rl-utils-flops_formulas-_mamba_layer_flops) | Model FLOPs for Mamba layer. We ignore part of the flops of scan because the chunk size is not known from model config. | +| [`_mlp_layer_flops`](#nemo_rl-utils-flops_formulas-_mlp_layer_flops) | Model FLOPs for MLP layer. | +| [`_non_mla_attn_layer_flops`](#nemo_rl-utils-flops_formulas-_non_mla_attn_layer_flops) | Model FLOPs for attention layer. | +| [`bert`](#nemo_rl-utils-flops_formulas-bert) | Model FLOPs for BERT family. | +| [`deepseekv3`](#nemo_rl-utils-flops_formulas-deepseekv3) | Model FLOPs for DeepSeek V3. | +| [`flux`](#nemo_rl-utils-flops_formulas-flux) | Model FLOPs for FLUX. | +| [`gpt3`](#nemo_rl-utils-flops_formulas-gpt3) | Model FLOPs for GPT3 family. | +| [`llama`](#nemo_rl-utils-flops_formulas-llama) | Model FLOPs for llama3 family. | +| [`mixtral`](#nemo_rl-utils-flops_formulas-mixtral) | Model FLOPs for mixtral family. | +| [`nemotron`](#nemo_rl-utils-flops_formulas-nemotron) | Model FLOPs for nemotron family. | +| [`nemotronh`](#nemo_rl-utils-flops_formulas-nemotronh) | Model FLOPs for NemotronH. | +| [`qwen2`](#nemo_rl-utils-flops_formulas-qwen2) | Model FLOPs for Qwen2 family. | +| [`qwen3`](#nemo_rl-utils-flops_formulas-qwen3) | Model FLOPs for Qwen3 family. | +| [`transformer`](#nemo_rl-utils-flops_formulas-transformer) | Calculate FLOPs for a standard Transformer model. | + +### API + + + + + +```python +class nemo_rl.utils.flops_formulas.FLOPSConfig( + gbs: int, + enc_seq_len: typing.Optional[int] = None, + hs: typing.Optional[int] = None, + layers: typing.Optional[int] = None, + ffn_hs: typing.Optional[int] = None, + attention_heads: typing.Optional[int] = None, + moe_router_topk: typing.Optional[int] = None, + query_groups: typing.Optional[int] = None, + img_seq_len: typing.Optional[int] = None, + img_h: typing.Optional[int] = None, + img_w: typing.Optional[int] = None, + in_channels: typing.Optional[int] = None, + patch_dim: typing.Optional[int] = None, + class_token_len: typing.Optional[int] = None, + projector_type: typing.Optional[str] = None, + inp_s: typing.Optional[int] = None, + model_pattern: typing.Optional[str] = None, + vocab_size: typing.Optional[int] = None, + model_channels: typing.Optional[int] = None, + vec_in_dim: typing.Optional[int] = None, + q_lora_rank: typing.Optional[int] = None, + kv_lora_rank: typing.Optional[int] = None, + qk_head_dim: typing.Optional[int] = None, + qk_pos_emb_head_dim: typing.Optional[int] = None, + v_head_dim: typing.Optional[int] = None, + moe_layer_freq: typing.Optional[typing.Union[int, typing.List[int]]] = None, + moe_shared_expert_intermediate_size: typing.Optional[int] = None, + moe_ffn_hidden_size: typing.Optional[int] = None, + mtp_num_layers: typing.Optional[int] = None, + causal_self_attn: typing.Optional[bool] = None, + is_hybrid_model: bool = False, + hybrid_override_pattern: typing.Optional[str] = None, + mamba_state_dim: typing.Optional[int] = None, + mamba_head_dim: typing.Optional[int] = None, + mamba_num_groups: typing.Optional[int] = None, + mamba_num_heads: typing.Optional[int] = None +) +``` + + + + + + +Dataclass + +Contains the model hparams needed for FLOPS computations. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.utils.flops_formulas._hybrid_model_flops( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for hybrid model. + + + + + + + + +```python +nemo_rl.utils.flops_formulas._mamba_layer_flops( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for Mamba layer. We ignore part of the flops of scan because the chunk size is not known from model config. + + + + + + + + +```python +nemo_rl.utils.flops_formulas._mlp_layer_flops( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for MLP layer. + + + + + + + + +```python +nemo_rl.utils.flops_formulas._non_mla_attn_layer_flops( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for attention layer. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.bert( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for BERT family. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.deepseekv3( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for DeepSeek V3. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.flux( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for FLUX. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.gpt3( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for GPT3 family. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.llama( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for llama3 family. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.mixtral( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for mixtral family. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.nemotron( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for nemotron family. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.nemotronh( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for NemotronH. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.qwen2( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for Qwen2 family. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.qwen3( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for Qwen3 family. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.transformer( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Calculate FLOPs for a standard Transformer model. + +Note: This does not cover encoder-decoder models. + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_tracker.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_tracker.mdx new file mode 100644 index 0000000..1965cd5 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_tracker.mdx @@ -0,0 +1,215 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/flops_tracker +title: nemo_rl.utils.flops_tracker +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`FLOPTracker`](#nemo_rl-utils-flops_tracker-FLOPTracker) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`convert_config_to_flops_config`](#nemo_rl-utils-flops_tracker-convert_config_to_flops_config) | Convert a pretrained config to a tuple containing a FLOPSConfig and a flops formula. | +| [`get_default_hf_config`](#nemo_rl-utils-flops_tracker-get_default_hf_config) | Get the default Hugging Face config for a model. | +| [`get_theoretical_tflops`](#nemo_rl-utils-flops_tracker-get_theoretical_tflops) | Get the theoretical total flops for a device name. | +| [`is_using_tf32`](#nemo_rl-utils-flops_tracker-is_using_tf32) | Check if the current device is using TF32. | + +### Data + +[`THEORETICAL_TFLOPS`](#nemo_rl-utils-flops_tracker-THEORETICAL_TFLOPS) + +### API + + + + + +```python +class nemo_rl.utils.flops_tracker.FLOPTracker( + model_name: str, + base_config: nemo_rl.utils.flops_formulas.FLOPSConfig | None = None, + flops_formula: typing.Callable[[FLOPSConfig], float] | None = None +) +``` + + + + + + + + + + + + +```python +nemo_rl.utils.flops_tracker.FLOPTracker.from_config( + model_name: str, + config: transformers.configuration_utils.PretrainedConfig +) -> nemo_rl.utils.flops_tracker.FLOPTracker +``` + + + + + + +classmethod + + + + + + + +```python +nemo_rl.utils.flops_tracker.FLOPTracker.reset() +``` + + + + + + + + + + + + +```python +nemo_rl.utils.flops_tracker.FLOPTracker.track( + n_samples: int, + padded_seq_len: int +) +``` + + + + + + + + + + + + +```python +nemo_rl.utils.flops_tracker.FLOPTracker.track_batch( + sequence_lengths: list[int] +) +``` + + + + + + +Track the flops for a batch of sequences. + + + + + + + + + +```python +nemo_rl.utils.flops_tracker.convert_config_to_flops_config( + config: transformers.configuration_utils.PretrainedConfig +) -> tuple[nemo_rl.utils.flops_formulas.FLOPSConfig, typing.Callable] +``` + + + + + + +Convert a pretrained config to a tuple containing a FLOPSConfig and a flops formula. + + + + + + + + +```python +nemo_rl.utils.flops_tracker.get_default_hf_config( + model_name: str +) -> transformers.configuration_utils.PretrainedConfig +``` + + + + + + +Get the default Hugging Face config for a model. + +Both the DTensor and MCore paths use the same default config, we initialize the model config +here to allow computation of theoretical flops which is agnostic to the backend. + + + + + + + + +```python +nemo_rl.utils.flops_tracker.get_theoretical_tflops( + device_name: str, + model_dtype: torch.dtype +) -> float +``` + + + + + + +Get the theoretical total flops for a device name. + + + + + + + + +```python +nemo_rl.utils.flops_tracker.is_using_tf32() -> bool +``` + + + + + + +Check if the current device is using TF32. + + + + + + + + +```python +nemo_rl.utils.flops_tracker.THEORETICAL_TFLOPS = {('NVIDIA A100 80GB PCIe', torch.bfloat16): 624 / 2, ('NVIDIA A100 80GB PCIe', t... +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/logger.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/logger.mdx new file mode 100644 index 0000000..b78a4b3 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/logger.mdx @@ -0,0 +1,1856 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/logger +title: nemo_rl.utils.logger +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`GPUMonitoringConfig`](#nemo_rl-utils-logger-GPUMonitoringConfig) | - | +| [`GpuMetricSnapshot`](#nemo_rl-utils-logger-GpuMetricSnapshot) | - | +| [`Logger`](#nemo_rl-utils-logger-Logger) | Main logger class that delegates to multiple backend loggers. | +| [`LoggerConfig`](#nemo_rl-utils-logger-LoggerConfig) | - | +| [`LoggerInterface`](#nemo_rl-utils-logger-LoggerInterface) | Abstract base class for logger backends. | +| [`MLflowConfig`](#nemo_rl-utils-logger-MLflowConfig) | - | +| [`MLflowLogger`](#nemo_rl-utils-logger-MLflowLogger) | MLflow logger backend. | +| [`RayGpuMonitorLogger`](#nemo_rl-utils-logger-RayGpuMonitorLogger) | Monitor GPU utilization across a Ray cluster and log metrics to a parent logger. | +| [`SwanlabConfig`](#nemo_rl-utils-logger-SwanlabConfig) | - | +| [`SwanlabLogger`](#nemo_rl-utils-logger-SwanlabLogger) | SwanLab logger backend. | +| [`TensorboardConfig`](#nemo_rl-utils-logger-TensorboardConfig) | - | +| [`TensorboardLogger`](#nemo_rl-utils-logger-TensorboardLogger) | Tensorboard logger backend. | +| [`WandbConfig`](#nemo_rl-utils-logger-WandbConfig) | - | +| [`WandbLogger`](#nemo_rl-utils-logger-WandbLogger) | Weights & Biases logger backend. | + +### Functions + +| Name | Description | +|------|-------------| +| [`configure_rich_logging`](#nemo_rl-utils-logger-configure_rich_logging) | Configure rich logging for more visually appealing log output. | +| [`flatten_dict`](#nemo_rl-utils-logger-flatten_dict) | Flatten a nested dictionary. | +| [`get_next_experiment_dir`](#nemo_rl-utils-logger-get_next_experiment_dir) | Create a new experiment directory with an incremented ID. | +| [`print_message_log_samples`](#nemo_rl-utils-logger-print_message_log_samples) | Visualization for message logs and rewards using a more visual approach with emoji indicators and horizontal layout. | + +### Data + +[`_rich_logging_configured`](#nemo_rl-utils-logger-_rich_logging_configured) + +### API + + + + + +```python +class nemo_rl.utils.logger.GPUMonitoringConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.utils.logger.GpuMetricSnapshot +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.utils.logger.Logger( + cfg: nemo_rl.utils.logger.LoggerConfig +) +``` + + + + + + +**Bases:** [LoggerInterface](#nemo_rl-utils-logger-LoggerInterface) + +Main logger class that delegates to multiple backend loggers. + + + + + + + + + + + +```python +nemo_rl.utils.logger.Logger.__del__() -> None +``` + + + + + + +Clean up resources when the logger is destroyed. + + + + + + + +```python +nemo_rl.utils.logger.Logger.log_batched_dict_as_jsonl( + to_log: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] | dict[str, typing.Any], + filename: str +) -> None +``` + + + + + + +Log a list of dictionaries to a JSONL file. + +**Parameters:** + + +BatchedDataDict to log + + + +Filename to log to (within the log directory) + + + + + + + + +```python +nemo_rl.utils.logger.Logger.log_histogram( + histogram: list[typing.Any], + step: int, + name: str +) -> None +``` + + + + + + +Log histogram metrics to all backends if available. + +**Parameters:** + + +List of histogram values + + + +Global step value + + + +Name of the metric + + + + + + + + +```python +nemo_rl.utils.logger.Logger.log_hyperparams( + params: typing.Mapping[str, typing.Any] +) -> None +``` + + + + + + +Log hyperparameters to all enabled backends. + +**Parameters:** + + +Dict of hyperparameters to log + + + + + + + + +```python +nemo_rl.utils.logger.Logger.log_metrics( + metrics: dict[str, typing.Any], + step: int, + prefix: typing.Optional[str] = '', + step_metric: typing.Optional[str] = None, + step_finished: bool = False +) -> None +``` + + + + + + +Log metrics to all enabled backends. + +**Parameters:** + + +Dict of metrics to log + + + +Global step value + + + +Optional prefix for metric names + + + +Optional name of a field in metrics to use as step instead + of the provided step value (currently only needed for wandb) + + + + + + + + +```python +nemo_rl.utils.logger.Logger.log_plot( + figure: matplotlib.pyplot.Figure, + step: int, + name: str +) -> None +``` + + + + + + +Log a matplotlib figure to all backends. + +**Parameters:** + + +Matplotlib figure to log + + + +Global step value + + + +Name of the plot + + + + + + + + +```python +nemo_rl.utils.logger.Logger.log_plot_per_worker_timeline_metrics( + metrics: dict[int, list[typing.Any]], + step: int, + prefix: str, + name: str, + timeline_interval: float +) -> None +``` + + + + + + +Log a plot of per-worker timeline metrics. + +**Parameters:** + + +Dictionary of metrics to log, where the keys are the worker IDs and the values are the lists of metric values + + + +dict[str, list[Any]] = {worker_id: [metric_value_1, metric_value_2, ...]} + + + +Global step value + + + +Name of the plot + + + +Interval between timeline points (in seconds) + + + + + + + + +```python +nemo_rl.utils.logger.Logger.log_plot_token_mult_prob_error( + data: dict[str, typing.Any], + step: int, + name: str +) -> None +``` + + + + + + +Log a plot of log probability errors in samples. + +This function logs & plots the per-token log-probabilities and errors over the sequence +for the sample with the highest multiplicative probability error in the batch. + +**Parameters:** + + +Dictionary of log probability samples + + + +Global step value + + + +Name of the plot + + + + + + + + +```python +nemo_rl.utils.logger.Logger.log_string_list_as_jsonl( + to_log: list[str], + filename: str +) -> None +``` + + + + + + +Log a list of strings to a JSONL file. + +**Parameters:** + + +list of strings to log + + + +Filename to log to (within the log directory) + + + + + + + + + + +```python +class nemo_rl.utils.logger.LoggerConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.utils.logger.LoggerInterface() +``` + + + + + + +Abstract + +Abstract base class for logger backends. + + + + + + +```python +nemo_rl.utils.logger.LoggerInterface.log_histogram( + histogram: list[typing.Any], + step: int, + name: str +) -> None +``` + + + + + + +abstract + +Log histogram metrics. + + + + + + + +```python +nemo_rl.utils.logger.LoggerInterface.log_hyperparams( + params: typing.Mapping[str, typing.Any] +) -> None +``` + + + + + + +abstract + +Log dictionary of hyperparameters. + + + + + + + +```python +nemo_rl.utils.logger.LoggerInterface.log_metrics( + metrics: dict[str, typing.Any], + step: int, + prefix: typing.Optional[str] = '', + step_metric: typing.Optional[str] = None, + step_finished: bool = False +) -> None +``` + + + + + + +abstract + +Log a dictionary of metrics. + + + + + + + +```python +nemo_rl.utils.logger.LoggerInterface.log_plot( + figure: matplotlib.pyplot.Figure, + step: int, + name: str +) -> None +``` + + + + + + +abstract + +Log a matplotlib figure. + + + + + + + + + +```python +class nemo_rl.utils.logger.MLflowConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.utils.logger.MLflowLogger( + cfg: nemo_rl.utils.logger.MLflowConfig, + log_dir: typing.Optional[str] = None +) +``` + + + + + + +**Bases:** [LoggerInterface](#nemo_rl-utils-logger-LoggerInterface) + +MLflow logger backend. + + + + + + + + +```python +nemo_rl.utils.logger.MLflowLogger.__del__() -> None +``` + + + + + + +Clean up resources when the logger is destroyed. + + + + + + + +```python +nemo_rl.utils.logger.MLflowLogger.log_histogram( + histogram: list[typing.Any], + step: int, + name: str +) -> None +``` + + + + + + +Log histogram metrics to MLflow. + + + + + + + +```python +nemo_rl.utils.logger.MLflowLogger.log_hyperparams( + params: typing.Mapping[str, typing.Any] +) -> None +``` + + + + + + +Log hyperparameters to MLflow. + +**Parameters:** + + +Dictionary of hyperparameters to log + + + + + + + + +```python +nemo_rl.utils.logger.MLflowLogger.log_metrics( + metrics: dict[str, typing.Any], + step: int, + prefix: typing.Optional[str] = '', + step_metric: typing.Optional[str] = None, + step_finished: bool = False +) -> None +``` + + + + + + +Log metrics to MLflow. + +**Parameters:** + + +Dict of metrics to log + + + +Global step value + + + +Optional prefix for metric names + + + +Optional step metric name (ignored in MLflow) + + + + + + + + +```python +nemo_rl.utils.logger.MLflowLogger.log_plot( + figure: matplotlib.pyplot.Figure, + step: int, + name: str +) -> None +``` + + + + + + +Log a plot to MLflow. + +**Parameters:** + + +Matplotlib figure to log + + + +Global step value + + + +Name of the plot + + + + + + + + + + +```python +class nemo_rl.utils.logger.RayGpuMonitorLogger( + collection_interval: int | float, + flush_interval: int | float, + metric_prefix: str, + step_metric: str, + parent_logger: typing.Optional[nemo_rl.utils.logger.Logger] = None +) +``` + + + + + + +Monitor GPU utilization across a Ray cluster and log metrics to a parent logger. + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.utils.logger.RayGpuMonitorLogger._collect( + metrics: bool = False, + sku: bool = False +) -> dict[str, typing.Any] +``` + + + + + + +Collect GPU metrics from all Ray nodes. + +**Returns:** `dict[str, Any]` + +Dictionary of collected metrics + + + + + + + +```python +nemo_rl.utils.logger.RayGpuMonitorLogger._collect_gpu_sku() -> dict[str, str] +``` + + + + + + +Collect GPU SKU from all Ray nodes. + +Note: This is an internal API and users are not expected to call this. + +**Returns:** `dict[str, str]` + +Dictionary of SKU types on all Ray nodes + + + + + + + +```python +nemo_rl.utils.logger.RayGpuMonitorLogger._collect_metrics() -> dict[str, typing.Any] +``` + + + + + + +Collect GPU metrics from all Ray nodes. + +**Returns:** `dict[str, Any]` + +Dictionary of collected metrics + + + + + + + +```python +nemo_rl.utils.logger.RayGpuMonitorLogger._collection_loop() -> None +``` + + + + + + +Main collection loop that runs in a separate thread. + + + + + + + +```python +nemo_rl.utils.logger.RayGpuMonitorLogger._fetch_and_parse_metrics( + node_idx: int, + metric_address: str, + parser_fn: typing.Callable +) +``` + + + + + + +Fetch metrics from a node and parse GPU metrics. + +**Parameters:** + + +Index of the node + + + +Address of the metrics endpoint + + +**Returns:** + +Dictionary of GPU metrics + + + + + + + +```python +nemo_rl.utils.logger.RayGpuMonitorLogger._parse_gpu_sku( + sample: prometheus_client.samples.Sample, + node_idx: int +) -> dict[str, str] +``` + + + + + + +Parse a GPU metric sample into a standardized format. + +**Parameters:** + + +Prometheus metric sample + + + +Index of the node + + +**Returns:** `dict[str, str]` + +Dictionary with metric name and value + + + + + + + +```python +nemo_rl.utils.logger.RayGpuMonitorLogger._parse_metric( + sample: prometheus_client.samples.Sample, + node_idx: int +) -> dict[str, typing.Any] +``` + + + + + + +Parse a metric sample into a standardized format. + +**Parameters:** + + +Prometheus metric sample + + + +Index of the node + + +**Returns:** `dict[str, Any]` + +Dictionary with metric name and value + + + + + + + +```python +nemo_rl.utils.logger.RayGpuMonitorLogger.flush() -> None +``` + + + + + + +Flush collected metrics to the parent logger. + + + + + + + +```python +nemo_rl.utils.logger.RayGpuMonitorLogger.start() -> None +``` + + + + + + +Start the GPU monitoring thread. + + + + + + + +```python +nemo_rl.utils.logger.RayGpuMonitorLogger.stop() -> None +``` + + + + + + +Stop the GPU monitoring thread. + + + + + + + + + +```python +class nemo_rl.utils.logger.SwanlabConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.utils.logger.SwanlabLogger( + cfg: nemo_rl.utils.logger.SwanlabConfig, + log_dir: typing.Optional[str] = None +) +``` + + + + + + +**Bases:** [LoggerInterface](#nemo_rl-utils-logger-LoggerInterface) + +SwanLab logger backend. + + + + + + + + +```python +nemo_rl.utils.logger.SwanlabLogger.log_histogram( + histogram: list[typing.Any], + step: int, + name: str +) -> None +``` + + + + + + +Log histogram metrics to swanlab. + + + + + + + +```python +nemo_rl.utils.logger.SwanlabLogger.log_hyperparams( + params: typing.Mapping[str, typing.Any] +) -> None +``` + + + + + + +Update the Swanlab run configuration with the provided hyperparameters. + +**Parameters:** + + +Mapping of hyperparameter names to values to store in the run configuration. + + + + + + + + +```python +nemo_rl.utils.logger.SwanlabLogger.log_metrics( + metrics: dict[str, typing.Any], + step: int, + prefix: typing.Optional[str] = '', + step_metric: typing.Optional[str] = None, + step_finished: bool = False +) -> None +``` + + + + + + +Log metrics to the associated Swanlab run. + +**Parameters:** + + +Mapping of metric names to metric values. + + + +Global step value to associate with all logged metrics. + + + +Optional prefix applied to metric names; metric names equal to `step_metric` are not prefixed. + + + +Name of a metric that should be excluded from prefixing. + + + + + + + + +```python +nemo_rl.utils.logger.SwanlabLogger.log_plot( + figure: matplotlib.pyplot.Figure, + step: int, + name: str +) -> None +``` + + + + + + +Log a plot to swanlab. + +**Parameters:** + + +Matplotlib figure to log + + + +Global step value + + + + + + + + + + +```python +class nemo_rl.utils.logger.TensorboardConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +class nemo_rl.utils.logger.TensorboardLogger( + cfg: nemo_rl.utils.logger.TensorboardConfig, + log_dir: typing.Optional[str] = None +) +``` + + + + + + +**Bases:** [LoggerInterface](#nemo_rl-utils-logger-LoggerInterface) + +Tensorboard logger backend. + + + + + + + + +```python +nemo_rl.utils.logger.TensorboardLogger._coerce_to_scalar( + value: typing.Any +) -> int | float | bool | str | None +``` + + + + + + +staticmethod + +Coerce a value to a Python scalar for TensorBoard logging. + +Returns the coerced value, or None if it can't be converted to a scalar. + + + + + + + +```python +nemo_rl.utils.logger.TensorboardLogger.log_histogram( + histogram: list[typing.Any], + step: int, + name: str +) -> None +``` + + + + + + +Log histogram metrics to Tensorboard. + + + + + + + +```python +nemo_rl.utils.logger.TensorboardLogger.log_hyperparams( + params: typing.Mapping[str, typing.Any] +) -> None +``` + + + + + + +Log hyperparameters to Tensorboard. + +**Parameters:** + + +Dictionary of hyperparameters to log + + + + + + + + +```python +nemo_rl.utils.logger.TensorboardLogger.log_metrics( + metrics: dict[str, typing.Any], + step: int, + prefix: typing.Optional[str] = '', + step_metric: typing.Optional[str] = None, + step_finished: bool = False +) -> None +``` + + + + + + +Log metrics to Tensorboard. + +**Parameters:** + + +Dict of metrics to log + + + +Global step value + + + +Optional prefix for metric names + + + +Optional step metric name (ignored in TensorBoard) + + + + + + + + +```python +nemo_rl.utils.logger.TensorboardLogger.log_plot( + figure: matplotlib.pyplot.Figure, + step: int, + name: str +) -> None +``` + + + + + + +Log a plot to Tensorboard. + +**Parameters:** + + +Dictionary of plot data + + + +Global step value + + + + + + + + + + +```python +class nemo_rl.utils.logger.WandbConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.utils.logger.WandbLogger( + cfg: nemo_rl.utils.logger.WandbConfig, + log_dir: typing.Optional[str] = None +) +``` + + + + + + +**Bases:** [LoggerInterface](#nemo_rl-utils-logger-LoggerInterface) + +Weights & Biases logger backend. + + + + + + + + +```python +nemo_rl.utils.logger.WandbLogger._log_code() +``` + + + + + + +Log code that is tracked by git to wandb. + +This function gets a list of all files tracked by git in the project root +and manually uploads them to the current wandb run as an artifact. + + + + + + + +```python +nemo_rl.utils.logger.WandbLogger._log_diffs() +``` + + + + + + +Log git diffs to wandb. + +This function captures and logs two types of diffs: +1. Uncommitted changes (working tree diff against HEAD) +2. All changes (including uncommitted) against the main branch + +Each diff is saved as a text file in a wandb artifact. + + + + + + + +```python +nemo_rl.utils.logger.WandbLogger.define_metric( + name: str, + step_metric: typing.Optional[str] = None +) -> None +``` + + + + + + +Define a metric with custom step metric. + +**Parameters:** + + +Name of the metric or pattern (e.g. 'ray/*') + + + +Optional name of the step metric to use + + + + + + + + +```python +nemo_rl.utils.logger.WandbLogger.log_histogram( + histogram: list[typing.Any], + step: int, + name: str +) -> None +``` + + + + + + +Log histogram metrics to wandb. + +**Parameters:** + + +List of histogram values + + + +Global step value + + + +Name of the metric + + + + + + + + +```python +nemo_rl.utils.logger.WandbLogger.log_hyperparams( + params: typing.Mapping[str, typing.Any] +) -> None +``` + + + + + + +Log hyperparameters to wandb. + +**Parameters:** + + +Dict of hyperparameters to log + + + + + + + + +```python +nemo_rl.utils.logger.WandbLogger.log_metrics( + metrics: dict[str, typing.Any], + step: int, + prefix: typing.Optional[str] = '', + step_metric: typing.Optional[str] = None, + step_finished: bool = False +) -> None +``` + + + + + + +Log metrics to wandb. + +**Parameters:** + + +Dict of metrics to log + + + +Global step value + + + +Optional prefix for metric names + + + +Optional name of a field in metrics to use as step instead + of the provided step value + + + + + + + + +```python +nemo_rl.utils.logger.WandbLogger.log_plot( + figure: matplotlib.pyplot.Figure, + step: int, + name: str +) -> None +``` + + + + + + +Log a plot to wandb. + +**Parameters:** + + +Matplotlib figure to log + + + +Global step value + + + + + + + + + + +```python +nemo_rl.utils.logger.configure_rich_logging( + level: str = 'INFO', + show_time: bool = True, + show_path: bool = True +) -> None +``` + + + + + + +Configure rich logging for more visually appealing log output. + +**Parameters:** + + +The logging level to use + + + +Whether to show timestamps in logs + + + +Whether to show file paths in logs + + + + + + + + + +```python +nemo_rl.utils.logger.flatten_dict( + d: typing.Mapping[str, typing.Any], + sep: str = '.' +) -> dict[str, typing.Any] +``` + + + + + + +Flatten a nested dictionary. + +Handles nested dictionaries and lists by creating keys with separators. +For lists, the index is used as part of the key. + +**Parameters:** + + +Dictionary to flatten + + + +Separator to use between nested keys + + +**Returns:** `dict[str, Any]` + +Flattened dictionary with compound keys + +**Examples:** + + + +```python +>>> from nemo_rl.utils.logger import flatten_dict +>>> flatten_dict({"a": 1, "b": {"c": 2}}) +{'a': 1, 'b.c': 2} + +>>> flatten_dict({"a": [1, 2], "b": {"c": [3, 4]}}) +{'a.0': 1, 'a.1': 2, 'b.c.0': 3, 'b.c.1': 4} + +>>> flatten_dict({"a": [{"b": 1}, {"c": 2}]}) +{'a.0.b': 1, 'a.1.c': 2} +``` + + + + + + + + + + +```python +nemo_rl.utils.logger.get_next_experiment_dir( + base_log_dir: str +) -> str +``` + + + + + + +Create a new experiment directory with an incremented ID. + +**Parameters:** + + +The base log directory path + + +**Returns:** `str` + +Path to the new experiment directory with incremented ID + + + + + + + + +```python +nemo_rl.utils.logger.print_message_log_samples( + message_logs: list[nemo_rl.data.interfaces.LLMMessageLogType], + rewards: list[float], + num_samples: int = 5, + step: int = 0 +) -> None +``` + + + + + + +Visualization for message logs and rewards using a more visual approach with emoji indicators and horizontal layout. + +**Parameters:** + + +List of message logs to sample from + + + +List of rewards corresponding to each message log + + + +Number of samples to display (default: 5) + + + +Current training step (for display purposes) + + + + + + + + + +```python +nemo_rl.utils.logger._rich_logging_configured = False +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/memory_tracker.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/memory_tracker.mdx new file mode 100644 index 0000000..e06cd16 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/memory_tracker.mdx @@ -0,0 +1,122 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/memory_tracker +title: nemo_rl.utils.memory_tracker +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MemoryTracker`](#nemo_rl-utils-memory_tracker-MemoryTracker) | - | +| [`MemoryTrackerDataPoint`](#nemo_rl-utils-memory_tracker-MemoryTrackerDataPoint) | - | + +### API + + + + + +```python +class nemo_rl.utils.memory_tracker.MemoryTracker() +``` + + + + + + +**Bases:** `BaseModel` + + + + + + + +```python +nemo_rl.utils.memory_tracker.MemoryTracker.model_post_init( + context +) +``` + + + + + + + + + + + + +```python +nemo_rl.utils.memory_tracker.MemoryTracker.snapshot_start_of_stage( + new_stage: str, + all_current_variables: typing.List[str] +) -> None +``` + + + + + + + + + + + + + + +```python +class nemo_rl.utils.memory_tracker.MemoryTrackerDataPoint() +``` + + + + + + +**Bases:** `BaseModel` + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.utils.memory_tracker.MemoryTrackerDataPoint.get_snapshot_str() -> str +``` + + + + + + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/native_checkpoint.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/native_checkpoint.mdx new file mode 100644 index 0000000..073652c --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/native_checkpoint.mdx @@ -0,0 +1,351 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/native_checkpoint +title: nemo_rl.utils.native_checkpoint +--- + +Checkpoint management utilities for HF models. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ModelState`](#nemo_rl-utils-native_checkpoint-ModelState) | Helper class for tracking model state in distributed checkpointing. | +| [`OptimizerState`](#nemo_rl-utils-native_checkpoint-OptimizerState) | Helper class for tracking optimizer state in distributed checkpointing. | + +### Functions + +| Name | Description | +|------|-------------| +| [`convert_dcp_to_hf`](#nemo_rl-utils-native_checkpoint-convert_dcp_to_hf) | Convert a Torch DCP checkpoint to a Hugging Face checkpoint. | +| [`load_checkpoint`](#nemo_rl-utils-native_checkpoint-load_checkpoint) | Load a model weights and optionally optimizer state. | +| [`save_checkpoint`](#nemo_rl-utils-native_checkpoint-save_checkpoint) | Save a checkpoint of the model and optionally optimizer state. | + +### API + + + + + +```python +class nemo_rl.utils.native_checkpoint.ModelState( + model: torch.nn.Module +) +``` + + + + + + +**Bases:** `Stateful` + +Helper class for tracking model state in distributed checkpointing. + +This class is compliant with the Stateful protocol, allowing DCP to automatically +call state_dict/load_state_dict as needed in the dcp.save/load APIs. + +**Parameters:** + + +The PyTorch model to track. + + + + + + + +```python +nemo_rl.utils.native_checkpoint.ModelState.load_state_dict( + state_dict: dict[str, typing.Any] +) -> None +``` + + + + + + +Load the state dictionary into the model. + +**Parameters:** + + +State dictionary to load. + + + + + + + + +```python +nemo_rl.utils.native_checkpoint.ModelState.state_dict() -> dict[str, typing.Any] +``` + + + + + + +Get the model's state dictionary. + +**Returns:** `dict[str, Any]` + +Dictionary containing the model's state dict with CPU offloading enabled. + + + + + + + + + +```python +class nemo_rl.utils.native_checkpoint.OptimizerState( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: typing.Optional[typing.Any] = None +) +``` + + + + + + +**Bases:** `Stateful` + +Helper class for tracking optimizer state in distributed checkpointing. + +This class is compliant with the Stateful protocol, allowing DCP to automatically +call state_dict/load_state_dict as needed in the dcp.save/load APIs. + +**Parameters:** + + +The PyTorch model associated with the optimizer. + + + +The optimizer to track. + + + +Optional learning rate scheduler. + + + + + + + +```python +nemo_rl.utils.native_checkpoint.OptimizerState.load_state_dict( + state_dict: dict[str, typing.Any] +) -> None +``` + + + + + + +Load the state dictionaries into the optimizer and scheduler. + +**Parameters:** + + +State dictionary containing optimizer and scheduler states to load. + + + + + + + + +```python +nemo_rl.utils.native_checkpoint.OptimizerState.state_dict() -> dict[str, typing.Any] +``` + + + + + + +Get the optimizer and scheduler state dictionaries. + +**Returns:** `dict[str, Any]` + +Dictionary containing the optimizer and scheduler state dicts with CPU offloading enabled. + + + + + + + + + +```python +nemo_rl.utils.native_checkpoint.convert_dcp_to_hf( + dcp_ckpt_path: str, + hf_ckpt_path: str, + model_name_or_path: str, + tokenizer_name_or_path: str, + overwrite: bool = False, + hf_overrides: typing.Optional[dict[str, typing.Any]] = {} +) -> str +``` + + + + + + +Convert a Torch DCP checkpoint to a Hugging Face checkpoint. + +This is not an optimized utility. If checkpoint is too large, consider saving DCP during training +and using this utility to convert to HF format. + +**Parameters:** + + +Path to DCP checkpoint + + + +Path to save HF checkpoint + + + +Model name or path for config + + + +Tokenizer name or path. + Defaults to model_name_or_path if None. + + + +Whether to overwrite existing checkpoint. Defaults to False. + + +**Returns:** `str` + +Path to the saved HF checkpoint + +**Raises:** + +- `FileExistsError`: If HF checkpoint already exists and overwrite is False + + + + + + + + +```python +nemo_rl.utils.native_checkpoint.load_checkpoint( + model: torch.nn.Module, + weights_path: str, + optimizer: typing.Optional[torch.optim.Optimizer] = None, + scheduler: typing.Optional[typing.Any] = None, + optimizer_path: typing.Optional[str] = None +) -> None +``` + + + + + + +Load a model weights and optionally optimizer state. + +**Parameters:** + + +The PyTorch model whose weights to update + + + +Path to load model weights from + + + +Optional optimizer to load state into + + + +Optional scheduler to load state into + + + +Path to load optimizer state from (required if optimizer provided) + + + + + + + + + +```python +nemo_rl.utils.native_checkpoint.save_checkpoint( + model: torch.nn.Module, + weights_path: str, + optimizer: typing.Optional[torch.optim.Optimizer] = None, + scheduler: typing.Optional[typing.Any] = None, + optimizer_path: typing.Optional[str] = None, + tokenizer: typing.Optional[typing.Any] = None, + tokenizer_path: typing.Optional[str] = None +) -> None +``` + + + + + + +Save a checkpoint of the model and optionally optimizer state. + +**Parameters:** + + +The PyTorch model to save + + + +Path to save model weights + + + +Optional optimizer to save + + + +Optional scheduler to save + + + +Path to save optimizer state (required if optimizer provided) + + + +Optional tokenizer to save + + + +Path to save tokenizer state (required if tokenizer provided) + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/nsys.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/nsys.mdx new file mode 100644 index 0000000..1c32117 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/nsys.mdx @@ -0,0 +1,138 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/nsys +title: nemo_rl.utils.nsys +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ProfilablePolicy`](#nemo_rl-utils-nsys-ProfilablePolicy) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`maybe_gpu_profile_step`](#nemo_rl-utils-nsys-maybe_gpu_profile_step) | - | +| [`wrap_with_nvtx_name`](#nemo_rl-utils-nsys-wrap_with_nvtx_name) | A decorator to wrap a function with an NVTX range with the given name. | + +### Data + +[`NRL_NSYS_PROFILE_STEP_RANGE`](#nemo_rl-utils-nsys-NRL_NSYS_PROFILE_STEP_RANGE) + +[`NRL_NSYS_WORKER_PATTERNS`](#nemo_rl-utils-nsys-NRL_NSYS_WORKER_PATTERNS) + +### API + + + + + +```python +class nemo_rl.utils.nsys.ProfilablePolicy() +``` + + + + + + +Protocol + + + + + +```python +nemo_rl.utils.nsys.ProfilablePolicy.start_gpu_profiling() -> None +``` + + + + + + + + + + + + +```python +nemo_rl.utils.nsys.ProfilablePolicy.stop_gpu_profiling() -> None +``` + + + + + + + + + + + + + + +```python +nemo_rl.utils.nsys.maybe_gpu_profile_step( + policy: nemo_rl.utils.nsys.ProfilablePolicy, + step: int +) +``` + + + + + + + + + + + + + +```python +nemo_rl.utils.nsys.wrap_with_nvtx_name( + name: str +) +``` + + + + + + +A decorator to wrap a function with an NVTX range with the given name. + + + + + + + + +```python +nemo_rl.utils.nsys.NRL_NSYS_PROFILE_STEP_RANGE = os.environ.get('NRL_NSYS_PROFILE_STEP_RANGE', '') +``` + + + + + + + + + +```python +nemo_rl.utils.nsys.NRL_NSYS_WORKER_PATTERNS = os.environ.get('NRL_NSYS_WORKER_PATTERNS', '') +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/nvml.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/nvml.mdx new file mode 100644 index 0000000..a8ada45 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/nvml.mdx @@ -0,0 +1,100 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/nvml +title: nemo_rl.utils.nvml +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`device_id_to_physical_device_id`](#nemo_rl-utils-nvml-device_id_to_physical_device_id) | Convert a logical device ID to a physical device ID considering CUDA_VISIBLE_DEVICES. | +| [`get_device_uuid`](#nemo_rl-utils-nvml-get_device_uuid) | Get the UUID of a CUDA device using NVML. | +| [`get_free_memory_bytes`](#nemo_rl-utils-nvml-get_free_memory_bytes) | Get the free memory of a CUDA device in bytes using NVML. | +| [`nvml_context`](#nemo_rl-utils-nvml-nvml_context) | Context manager for NVML initialization and shutdown. | + +### API + + + + + +```python +nemo_rl.utils.nvml.device_id_to_physical_device_id( + device_id: int +) -> int +``` + + + + + + +Convert a logical device ID to a physical device ID considering CUDA_VISIBLE_DEVICES. + + + + + + + + +```python +nemo_rl.utils.nvml.get_device_uuid( + device_idx: int +) -> str +``` + + + + + + +Get the UUID of a CUDA device using NVML. + + + + + + + + +```python +nemo_rl.utils.nvml.get_free_memory_bytes( + device_idx: int +) -> float +``` + + + + + + +Get the free memory of a CUDA device in bytes using NVML. + + + + + + + + +```python +nemo_rl.utils.nvml.nvml_context() -> typing.Generator[None, None, None] +``` + + + + + + +Context manager for NVML initialization and shutdown. + +**Raises:** + +- `RuntimeError`: If NVML initialization fails + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/packed_tensor.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/packed_tensor.mdx new file mode 100644 index 0000000..786e38c --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/packed_tensor.mdx @@ -0,0 +1,140 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/packed_tensor +title: nemo_rl.utils.packed_tensor +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`get_num_buffers`](#nemo_rl-utils-packed_tensor-get_num_buffers) | - | +| [`get_target_packed_tensor_size`](#nemo_rl-utils-packed_tensor-get_target_packed_tensor_size) | - | +| [`packed_broadcast_consumer`](#nemo_rl-utils-packed_tensor-packed_broadcast_consumer) | Consume a packed tensor and unpack it into a list of tensors. | +| [`packed_broadcast_producer`](#nemo_rl-utils-packed_tensor-packed_broadcast_producer) | Broadcast a list of tensors in a packed manner. | + +### API + + + + + +```python +nemo_rl.utils.packed_tensor.get_num_buffers() +``` + + + + + + + + + + + + + +```python +nemo_rl.utils.packed_tensor.get_target_packed_tensor_size() +``` + + + + + + + + + + + + + +```python +nemo_rl.utils.packed_tensor.packed_broadcast_consumer( + iterator, + group, + src, + post_unpack_func +) +``` + + + + + + +Consume a packed tensor and unpack it into a list of tensors. + +**Parameters:** + + +iterator of model parameters. Returns a tuple of (name, tensor) + + + +process group (vllm PyNcclCommunicator) + + + +source rank (0 in current implementation) + + + +function to apply to each tensor after unpacking + + +**Returns:** + +None + + + + + + + + +```python +nemo_rl.utils.packed_tensor.packed_broadcast_producer( + iterator, + group, + src, + post_iter_func +) +``` + + + + + + +Broadcast a list of tensors in a packed manner. + +**Parameters:** + + +iterator of model parameters. Returns a tuple of (name, tensor) + + + +process group (vllm PyNcclCommunicator) + + + +source rank (0 in current implementation) + + + +function to apply to each tensor before packing, should return a tensor + + +**Returns:** + +None + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/prefetch_venvs.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/prefetch_venvs.mdx new file mode 100644 index 0000000..b0fb043 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/prefetch_venvs.mdx @@ -0,0 +1,108 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/prefetch_venvs +title: nemo_rl.utils.prefetch_venvs +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`create_frozen_environment_symlinks`](#nemo_rl-utils-prefetch_venvs-create_frozen_environment_symlinks) | Create python-{ClassName} wrapper scripts in /usr/local/bin for frozen environment support. | +| [`prefetch_venvs`](#nemo_rl-utils-prefetch_venvs-prefetch_venvs) | Prefetch all virtual environments that will be used by workers. | + +### Data + +[`args`](#nemo_rl-utils-prefetch_venvs-args) + +[`parser`](#nemo_rl-utils-prefetch_venvs-parser) + +### API + + + + + +```python +nemo_rl.utils.prefetch_venvs.create_frozen_environment_symlinks( + venv_configs +) +``` + + + + + + +Create python-{ClassName} wrapper scripts in /usr/local/bin for frozen environment support. + +Only runs in container (when NRL_CONTAINER=1 is set). + +**Parameters:** + + +Dictionary mapping py_executable to list of actor FQNs + + + + + + + + + +```python +nemo_rl.utils.prefetch_venvs.prefetch_venvs( + filters = None, + negative_filters = None +) +``` + + + + + + +Prefetch all virtual environments that will be used by workers. + +**Parameters:** + + +List of strings to match against actor FQNs. If provided, only + actors whose FQN contains at least one of the filter strings will + be prefetched. If None, all venvs are prefetched. + + + +List of strings to exclude from prefetching. Actors whose + FQN contains any of these strings will be skipped. + + + + + + + + + +```python +nemo_rl.utils.prefetch_venvs.args = parser.parse_args() +``` + + + + + + + + + +```python +nemo_rl.utils.prefetch_venvs.parser = argparse.ArgumentParser(description='Prefetch virtual environments for Ray actor... +``` + + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/timer.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/timer.mdx new file mode 100644 index 0000000..13c7047 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/timer.mdx @@ -0,0 +1,441 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/timer +title: nemo_rl.utils.timer +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`TimeoutChecker`](#nemo_rl-utils-timer-TimeoutChecker) | - | +| [`Timer`](#nemo_rl-utils-timer-Timer) | A utility for timing code execution. | + +### Functions + +| Name | Description | +|------|-------------| +| [`convert_to_seconds`](#nemo_rl-utils-timer-convert_to_seconds) | Converts a time string in the format 'DD:HH:MM:SS' to total seconds. | + +### API + + + + + +```python +class nemo_rl.utils.timer.TimeoutChecker( + timeout: typing.Optional[str] = '00:03:45:00', + fit_last_save_time: bool = False +) +``` + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.utils.timer.TimeoutChecker.check_save() +``` + + + + + + + + + + + + +```python +nemo_rl.utils.timer.TimeoutChecker.mark_iteration() +``` + + + + + + + + + + + + +```python +nemo_rl.utils.timer.TimeoutChecker.start_iterations() +``` + + + + + + + + + + + + + + +```python +class nemo_rl.utils.timer.Timer() +``` + + + + + + +A utility for timing code execution. + +Supports two usage patterns: +1. Explicit start/stop: timer.start("label"), timer.stop("label") +2. Context manager: with timer.time("label"): ... + +The timer keeps track of multiple timing measurements for each label, +and supports different reductions on these measurements (mean, median, +min, max, std dev). + +Example usage: + + +```python +timer = Timer() + +# Method 1: start/stop +timer.start("load_data") +data = load_data() +timer.stop("load_data") + +# Method 2: context manager +with timer.time("model_forward"): + model_outputs = model(inputs) + +# Multiple timing measurements for the same operation +for batch in dataloader: + with timer.time("model_forward_multiple"): + outputs = model(batch) + +# Get all times for one label +model_forward_times = timer.get_elapsed("model_forward_multiple") + +# Get reductions for one label +mean_forward_time = timer.reduce("model_forward_multiple") +max_forward_time = timer.reduce("model_forward_multiple", "max") +``` + + + + + + + + + + + + + + + + +```python +nemo_rl.utils.timer.Timer.get_elapsed( + label: str +) -> list[float] +``` + + + + + + +Get all elapsed time measurements for a specific label. + +**Parameters:** + + +The timing label to get elapsed times for + + +**Returns:** `list[float]` + +A list of all elapsed time measurements in seconds + +**Raises:** + +- `KeyError`: If the label doesn't exist + + + + + + + +```python +nemo_rl.utils.timer.Timer.get_latest_elapsed( + label: str +) -> float +``` + + + + + + +Get the most recent elapsed time measurement for a specific label. + +**Parameters:** + + +The timing label to get the latest elapsed time for + + +**Returns:** `float` + +The most recent elapsed time measurement in seconds + +**Raises:** + +- `KeyError`: If the label doesn't exist +- `IndexError`: If the label exists but has no measurements + + + + + + + +```python +nemo_rl.utils.timer.Timer.get_timing_metrics( + reduction_op: typing.Union[str, dict[str, str]] = 'mean' +) -> dict[str, float | list[float]] +``` + + + + + + +Get all timing measurements with optional reduction. + +**Parameters:** + + +Either a string specifying a reduction operation to apply to all labels, + or a dictionary mapping specific labels to reduction operations. + Valid reduction operations are: "mean", "median", "min", "max", "std", "sum", "count". + If a label is not in the dictionary, no reduction is applied and all measurements are returned. + + +**Returns:** `dict[str, float | list[float]]` + +A dictionary mapping labels to either: + +**Raises:** + +- `ValueError`: If an invalid reduction operation is provided + + + + + + + +```python +nemo_rl.utils.timer.Timer.reduce( + label: str, + operation: str = 'mean' +) -> float +``` + + + + + + +Apply a reduction function to timing measurements for the specified label. + +**Parameters:** + + +The timing label to get reduction for + + + +The type of reduction to apply. Valid options are: +- "mean": Average time (default) +- "median": Median time +- "min": Minimum time +- "max": Maximum time +- "std": Standard deviation +- "sum": Total time +- "count": Number of measurements + + +**Returns:** `float` + +A single float with the reduction result + +**Raises:** + +- `KeyError`: If the label doesn't exist +- `ValueError`: If an invalid operation is provided + + + + + + + +```python +nemo_rl.utils.timer.Timer.reset( + label: typing.Optional[str] = None +) -> None +``` + + + + + + +Reset timings for the specified label or all labels. + +**Parameters:** + + +Optional label to reset. If None, resets all timers. + + + + + + + + +```python +nemo_rl.utils.timer.Timer.start( + label: str +) -> None +``` + + + + + + +Start timing for the given label. + + + + + + + +```python +nemo_rl.utils.timer.Timer.stop( + label: str +) -> float +``` + + + + + + +Stop timing for the given label and return the elapsed time. + +**Parameters:** + + +The label to stop timing for + + +**Returns:** `float` + +The elapsed time in seconds + +**Raises:** + +- `ValueError`: If the timer for the given label is not running + + + + + + + +```python +nemo_rl.utils.timer.Timer.time( + label: str +) -> typing.Generator[None, None, None] +``` + + + + + + +Context manager for timing a block of code. + +**Parameters:** + + +The label to use for this timing + + + + + + + + + + +```python +nemo_rl.utils.timer.convert_to_seconds( + time_string: str +) -> int +``` + + + + + + +Converts a time string in the format 'DD:HH:MM:SS' to total seconds. + +**Parameters:** + + +Time duration string, e.g., '00:03:45:00'. + + +**Returns:** `int` + +Total time in seconds. + + + diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/venvs.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/venvs.mdx new file mode 100644 index 0000000..f2e6973 --- /dev/null +++ b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/venvs.mdx @@ -0,0 +1,177 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/venvs +title: nemo_rl.utils.venvs +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_env_builder`](#nemo_rl-utils-venvs-_env_builder) | - | +| [`create_local_venv`](#nemo_rl-utils-venvs-create_local_venv) | Create a virtual environment using uv and execute a command within it. | +| [`create_local_venv_on_each_node`](#nemo_rl-utils-venvs-create_local_venv_on_each_node) | Create a virtual environment on each Ray node. | + +### Data + +[`DEFAULT_VENV_DIR`](#nemo_rl-utils-venvs-DEFAULT_VENV_DIR) + +[`dir_path`](#nemo_rl-utils-venvs-dir_path) + +[`git_root`](#nemo_rl-utils-venvs-git_root) + +[`logger`](#nemo_rl-utils-venvs-logger) + +### API + + + + + +```python +nemo_rl.utils.venvs._env_builder( + py_executable: str, + venv_name: str, + node_idx: int, + force_rebuild: bool = False +) +``` + + + + + + + + + + + + + +```python +nemo_rl.utils.venvs.create_local_venv( + py_executable: str, + venv_name: str, + force_rebuild: bool = False +) -> str +``` + + + + + + +Create a virtual environment using uv and execute a command within it. + +The output can be used as a py_executable for a Ray worker assuming the worker +nodes also have access to the same file system as the head node. + +This function is cached to avoid multiple calls to uv to create the same venv, +which avoids duplicate logging. + +**Parameters:** + + +Command to run with the virtual environment (e.g., "uv.sh run --locked") + + + +Name of the virtual environment (e.g., "foobar.Worker") + + + +If True, force rebuild the venv even if it already exists + + +**Returns:** `str` + +Path to the python executable in the created virtual environment + + + + + + + + +```python +nemo_rl.utils.venvs.create_local_venv_on_each_node( + py_executable: str, + venv_name: str +) +``` + + + + + + +Create a virtual environment on each Ray node. + +**Parameters:** + + +Command to run with the virtual environment + + + +Name of the virtual environment + + +**Returns:** + +Path to the python executable in the created virtual environment + + + + + + + + +```python +nemo_rl.utils.venvs.DEFAULT_VENV_DIR = os.path.join(git_root, 'venvs') +``` + + + + + + + + + +```python +nemo_rl.utils.venvs.dir_path = os.path.dirname(os.path.abspath(__file__)) +``` + + + + + + + + + +```python +nemo_rl.utils.venvs.git_root = os.path.abspath(os.path.join(dir_path, '../..')) +``` + + + + + + + + + +```python +nemo_rl.utils.venvs.logger = logging.getLogger(__name__) +``` + + + + From 21a99dd32d0ad523226f408c13ffd83b864e4dc7 Mon Sep 17 00:00:00 2001 From: Paarth Gupta Date: Wed, 11 Feb 2026 16:57:42 -0500 Subject: [PATCH 3/6] render multiple libraries --- fern/docs.yml | 14 + fern/static/ttl-docs/_navigation.yml | 149 +++ fern/static/ttl-docs/ttl/ttl.mdx | 60 ++ fern/static/ttl-docs/ttl/ttl/_mlir_libs.mdx | 9 + .../ttl/ttl/_mlir_libs/_site_initialize_1.mdx | 35 + fern/static/ttl-docs/ttl/ttl/_src.mdx | 11 + .../ttl-docs/ttl/ttl/_src/auto_profile.mdx | 479 +++++++++ .../ttl-docs/ttl/ttl/_src/tensor_registry.mdx | 169 ++++ fern/static/ttl-docs/ttl/ttl/_src/ttl_ast.mdx | 731 ++++++++++++++ .../ttl-docs/ttl/ttl/circular_buffer.mdx | 283 ++++++ fern/static/ttl-docs/ttl/ttl/constants.mdx | 41 + fern/static/ttl-docs/ttl/ttl/diagnostics.mdx | 466 +++++++++ fern/static/ttl-docs/ttl/ttl/dialects.mdx | 12 + .../ttl-docs/ttl/ttl/dialects/_ods_common.mdx | 39 + fern/static/ttl-docs/ttl/ttl/dialects/ttl.mdx | 81 ++ fern/static/ttl-docs/ttl/ttl/dtype_utils.mdx | 212 ++++ .../static/ttl-docs/ttl/ttl/kernel_runner.mdx | 274 ++++++ fern/static/ttl-docs/ttl/ttl/layouts.mdx | 126 +++ fern/static/ttl-docs/ttl/ttl/operators.mdx | 642 +++++++++++++ fern/static/ttl-docs/ttl/ttl/ttl.mdx | 27 + fern/static/ttl-docs/ttl/ttl/ttl_api.mdx | 907 ++++++++++++++++++ fern/static/ttl-docs/ttl/ttl/ttl_math.mdx | 29 + fern/static/ttl-docs/ttl/ttl/ttl_utils.mdx | 70 ++ 23 files changed, 4866 insertions(+) create mode 100644 fern/static/ttl-docs/_navigation.yml create mode 100644 fern/static/ttl-docs/ttl/ttl.mdx create mode 100644 fern/static/ttl-docs/ttl/ttl/_mlir_libs.mdx create mode 100644 fern/static/ttl-docs/ttl/ttl/_mlir_libs/_site_initialize_1.mdx create mode 100644 fern/static/ttl-docs/ttl/ttl/_src.mdx create mode 100644 fern/static/ttl-docs/ttl/ttl/_src/auto_profile.mdx create mode 100644 fern/static/ttl-docs/ttl/ttl/_src/tensor_registry.mdx create mode 100644 fern/static/ttl-docs/ttl/ttl/_src/ttl_ast.mdx create mode 100644 fern/static/ttl-docs/ttl/ttl/circular_buffer.mdx create mode 100644 fern/static/ttl-docs/ttl/ttl/constants.mdx create mode 100644 fern/static/ttl-docs/ttl/ttl/diagnostics.mdx create mode 100644 fern/static/ttl-docs/ttl/ttl/dialects.mdx create mode 100644 fern/static/ttl-docs/ttl/ttl/dialects/_ods_common.mdx create mode 100644 fern/static/ttl-docs/ttl/ttl/dialects/ttl.mdx create mode 100644 fern/static/ttl-docs/ttl/ttl/dtype_utils.mdx create mode 100644 fern/static/ttl-docs/ttl/ttl/kernel_runner.mdx create mode 100644 fern/static/ttl-docs/ttl/ttl/layouts.mdx create mode 100644 fern/static/ttl-docs/ttl/ttl/operators.mdx create mode 100644 fern/static/ttl-docs/ttl/ttl/ttl.mdx create mode 100644 fern/static/ttl-docs/ttl/ttl/ttl_api.mdx create mode 100644 fern/static/ttl-docs/ttl/ttl/ttl_math.mdx create mode 100644 fern/static/ttl-docs/ttl/ttl/ttl_utils.mdx diff --git a/fern/docs.yml b/fern/docs.yml index 14db967..1003620 100644 --- a/fern/docs.yml +++ b/fern/docs.yml @@ -25,6 +25,9 @@ tabs: Library Reference: display-name: Library Reference icon: book + TTL Reference: + display-name: TTL Reference + icon: book libraries: nemo-rl: @@ -33,6 +36,13 @@ libraries: output: path: ./static/nemo-rl-docs lang: python + ttl: + input: + git: https://github.com/tenstorrent/tt-lang + subpath: python/ttl + output: + path: ./static/ttl-docs + lang: python navigation: - tab: home @@ -80,6 +90,10 @@ navigation: - tab: Library Reference layout: - library: nemo-rl + - tab: TTL Reference + layout: + - library: ttl + navbar-links: - type: minimal diff --git a/fern/static/ttl-docs/_navigation.yml b/fern/static/ttl-docs/_navigation.yml new file mode 100644 index 0000000..e4a8b13 --- /dev/null +++ b/fern/static/ttl-docs/_navigation.yml @@ -0,0 +1,149 @@ +# AUTO-GENERATED by `fern docs md generate` — DO NOT EDIT +- type: section + title: _mlir_libs + slug: ttl/ttl/_mlir_libs + children: + - type: section + title: _site_initialize_1 + slug: ttl/ttl/_mlir_libs/_site_initialize_1 + children: + - type: page + title: _site_initialize_1 + slug: ttl/ttl/_mlir_libs/_site_initialize_1 + pageId: ttl/ttl/_mlir_libs/_site_initialize_1.mdx +- type: section + title: _src + slug: ttl/ttl/_src + children: + - type: section + title: auto_profile + slug: ttl/ttl/_src/auto_profile + children: + - type: page + title: auto_profile + slug: ttl/ttl/_src/auto_profile + pageId: ttl/ttl/_src/auto_profile.mdx + - type: section + title: tensor_registry + slug: ttl/ttl/_src/tensor_registry + children: + - type: page + title: tensor_registry + slug: ttl/ttl/_src/tensor_registry + pageId: ttl/ttl/_src/tensor_registry.mdx + - type: section + title: ttl_ast + slug: ttl/ttl/_src/ttl_ast + children: + - type: page + title: ttl_ast + slug: ttl/ttl/_src/ttl_ast + pageId: ttl/ttl/_src/ttl_ast.mdx +- type: section + title: circular_buffer + slug: ttl/ttl/circular_buffer + children: + - type: page + title: circular_buffer + slug: ttl/ttl/circular_buffer + pageId: ttl/ttl/circular_buffer.mdx +- type: section + title: constants + slug: ttl/ttl/constants + children: + - type: page + title: constants + slug: ttl/ttl/constants + pageId: ttl/ttl/constants.mdx +- type: section + title: diagnostics + slug: ttl/ttl/diagnostics + children: + - type: page + title: diagnostics + slug: ttl/ttl/diagnostics + pageId: ttl/ttl/diagnostics.mdx +- type: section + title: dialects + slug: ttl/ttl/dialects + children: + - type: section + title: _ods_common + slug: ttl/ttl/dialects/_ods_common + children: + - type: page + title: _ods_common + slug: ttl/ttl/dialects/_ods_common + pageId: ttl/ttl/dialects/_ods_common.mdx + - type: section + title: ttl + slug: ttl/ttl/dialects/ttl + children: + - type: page + title: ttl + slug: ttl/ttl/dialects/ttl + pageId: ttl/ttl/dialects/ttl.mdx +- type: section + title: dtype_utils + slug: ttl/ttl/dtype_utils + children: + - type: page + title: dtype_utils + slug: ttl/ttl/dtype_utils + pageId: ttl/ttl/dtype_utils.mdx +- type: section + title: kernel_runner + slug: ttl/ttl/kernel_runner + children: + - type: page + title: kernel_runner + slug: ttl/ttl/kernel_runner + pageId: ttl/ttl/kernel_runner.mdx +- type: section + title: layouts + slug: ttl/ttl/layouts + children: + - type: page + title: layouts + slug: ttl/ttl/layouts + pageId: ttl/ttl/layouts.mdx +- type: section + title: operators + slug: ttl/ttl/operators + children: + - type: page + title: operators + slug: ttl/ttl/operators + pageId: ttl/ttl/operators.mdx +- type: section + title: ttl + slug: ttl/ttl/ttl + children: + - type: page + title: ttl + slug: ttl/ttl/ttl + pageId: ttl/ttl/ttl.mdx +- type: section + title: ttl_api + slug: ttl/ttl/ttl_api + children: + - type: page + title: ttl_api + slug: ttl/ttl/ttl_api + pageId: ttl/ttl/ttl_api.mdx +- type: section + title: ttl_math + slug: ttl/ttl/ttl_math + children: + - type: page + title: ttl_math + slug: ttl/ttl/ttl_math + pageId: ttl/ttl/ttl_math.mdx +- type: section + title: ttl_utils + slug: ttl/ttl/ttl_utils + children: + - type: page + title: ttl_utils + slug: ttl/ttl/ttl_utils + pageId: ttl/ttl/ttl_utils.mdx diff --git a/fern/static/ttl-docs/ttl/ttl.mdx b/fern/static/ttl-docs/ttl/ttl.mdx new file mode 100644 index 0000000..30235a7 --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl.mdx @@ -0,0 +1,60 @@ +--- +layout: overview +slug: ttl/ttl +title: ttl +--- + +## Subpackages + +- **[`ttl._mlir_libs`](/ttl/ttl/_mlir_libs)** +- **[`ttl._src`](/ttl/ttl/_src)** +- **[`ttl.dialects`](/ttl/ttl/dialects)** + +## Submodules + +- **[`ttl.circular_buffer`](/ttl/ttl/circular_buffer)** +- **[`ttl.constants`](/ttl/ttl/constants)** +- **[`ttl.diagnostics`](/ttl/ttl/diagnostics)** +- **[`ttl.dtype_utils`](/ttl/ttl/dtype_utils)** +- **[`ttl.ir`](/ttl/ttl/ir)** +- **[`ttl.kernel_runner`](/ttl/ttl/kernel_runner)** +- **[`ttl.layouts`](/ttl/ttl/layouts)** +- **[`ttl.operators`](/ttl/ttl/operators)** +- **[`ttl.ttl`](/ttl/ttl/ttl)** +- **[`ttl.ttl_api`](/ttl/ttl/ttl_api)** +- **[`ttl.ttl_math`](/ttl/ttl/ttl_math)** +- **[`ttl.ttl_utils`](/ttl/ttl/ttl_utils)** + +## Package Contents + +### Data + +[`__all__`](#ttl-__all__) + +[`__version__`](#ttl-__version__) + +### API + + + + + +```python +ttl.__all__ = ['kernel', 'compute', 'datamovement', 'Program', 'CircularBuffer', 'TensorBlock'... +``` + + + + + + + + + +```python +ttl.__version__ = '0.1.0' +``` + + + + diff --git a/fern/static/ttl-docs/ttl/ttl/_mlir_libs.mdx b/fern/static/ttl-docs/ttl/ttl/_mlir_libs.mdx new file mode 100644 index 0000000..7dd8cd3 --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl/_mlir_libs.mdx @@ -0,0 +1,9 @@ +--- +layout: overview +slug: ttl/ttl/_mlir_libs +title: ttl._mlir_libs +--- + +## Submodules + +- **[`ttl._mlir_libs._site_initialize_1`](/ttl/ttl/_mlir_libs/_site_initialize_1)** diff --git a/fern/static/ttl-docs/ttl/ttl/_mlir_libs/_site_initialize_1.mdx b/fern/static/ttl-docs/ttl/ttl/_mlir_libs/_site_initialize_1.mdx new file mode 100644 index 0000000..414f0c4 --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl/_mlir_libs/_site_initialize_1.mdx @@ -0,0 +1,35 @@ +--- +layout: overview +slug: ttl/ttl/_mlir_libs/_site_initialize_1 +title: ttl._mlir_libs._site_initialize_1 +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`register_dialects`](#ttl-_mlir_libs-_site_initialize_1-register_dialects) | Called by MLIR site initialization to add TTL dialects to the registry. | + +### API + + + + + +```python +ttl._mlir_libs._site_initialize_1.register_dialects( + registry +) +``` + + + + + + +Called by MLIR site initialization to add TTL dialects to the registry. + + + diff --git a/fern/static/ttl-docs/ttl/ttl/_src.mdx b/fern/static/ttl-docs/ttl/ttl/_src.mdx new file mode 100644 index 0000000..20050a7 --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl/_src.mdx @@ -0,0 +1,11 @@ +--- +layout: overview +slug: ttl/ttl/_src +title: ttl._src +--- + +## Submodules + +- **[`ttl._src.auto_profile`](/ttl/ttl/_src/auto_profile)** +- **[`ttl._src.tensor_registry`](/ttl/ttl/_src/tensor_registry)** +- **[`ttl._src.ttl_ast`](/ttl/ttl/_src/ttl_ast)** diff --git a/fern/static/ttl-docs/ttl/ttl/_src/auto_profile.mdx b/fern/static/ttl-docs/ttl/ttl/_src/auto_profile.mdx new file mode 100644 index 0000000..6e2407b --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl/_src/auto_profile.mdx @@ -0,0 +1,479 @@ +--- +layout: overview +slug: ttl/ttl/_src/auto_profile +title: ttl._src.auto_profile +--- + +Auto-profiling infrastructure for tt-lang kernels. + +Enabled via TTLANG_AUTO_PROFILE=1 environment variable. +Automatically instruments every operation with signposts and generates +a visual profile report showing cycle counts per source line. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`Colors`](#ttl-_src-auto_profile-Colors) | ANSI color codes for terminal output. | +| [`ProfileResult`](#ttl-_src-auto_profile-ProfileResult) | Represents profiling results for a single signpost. | +| [`SourceLineMapper`](#ttl-_src-auto_profile-SourceLineMapper) | Maps signpost markers back to source code lines. | + +### Functions + +| Name | Description | +|------|-------------| +| [`build_cb_wait_to_dma_map`](#ttl-_src-auto_profile-build_cb_wait_to_dma_map) | Build mapping from cb_wait locations to DMA barrier locations. | +| [`build_dma_producer_to_cb_map`](#ttl-_src-auto_profile-build_dma_producer_to_cb_map) | Build mapping from DMA barrier locations to CB index. | +| [`generate_signpost_name`](#ttl-_src-auto_profile-generate_signpost_name) | Generate before/after signpost names for an operation. | +| [`get_line_mapper`](#ttl-_src-auto_profile-get_line_mapper) | Get the global line mapper instance. | +| [`is_auto_profile_enabled`](#ttl-_src-auto_profile-is_auto_profile_enabled) | Check if auto-profiling is enabled via environment variable. | +| [`load_cb_flow_graph`](#ttl-_src-auto_profile-load_cb_flow_graph) | Load CB flow graph JSON from same directory as CSV. | +| [`parse_device_profile_csv`](#ttl-_src-auto_profile-parse_device_profile_csv) | Parse the device profile CSV and extract signpost timing data. | +| [`parse_signpost_name`](#ttl-_src-auto_profile-parse_signpost_name) | Parse op name and implicit flag from signpost name. | +| [`print_profile_report`](#ttl-_src-auto_profile-print_profile_report) | Print a profile report organized by thread. | + +### Data + +[`_global_line_mapper`](#ttl-_src-auto_profile-_global_line_mapper) + +### API + + + + + +```python +class ttl._src.auto_profile.Colors() +``` + + + + + + +ANSI color codes for terminal output. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +ttl._src.auto_profile.Colors.cb_bg( + cb_index: int +) -> str +``` + + + + + + +classmethod + +Get background color for a CB index, or empty if out of range. + + + + + + + + + +```python +class ttl._src.auto_profile.ProfileResult( + signpost: str, + thread: str, + cycles: int, + lineno: int, + source: str +) +``` + + + + + + +Represents profiling results for a single signpost. + + + + + + + + + + +```python +class ttl._src.auto_profile.SourceLineMapper() +``` + + + + + + +Maps signpost markers back to source code lines. + + + + + + + + + + + + + + +```python +ttl._src.auto_profile.SourceLineMapper.get_line_info( + signpost_name: str +) -> typing.Optional[typing.Tuple[int, str]] +``` + + + + + + +Get line number and source for a signpost. + + + + + + + +```python +ttl._src.auto_profile.SourceLineMapper.register_signpost( + signpost_name: str, + lineno: int, + source: str +) +``` + + + + + + +Register a signpost with its source line information. + + + + + + + +```python +ttl._src.auto_profile.SourceLineMapper.set_source( + source_lines: typing.List[str] +) +``` + + + + + + +Set the source code lines for display. + + + + + + + + + +```python +ttl._src.auto_profile.build_cb_wait_to_dma_map( + cb_flow: typing.Optional[typing.Dict] +) -> typing.Dict[typing.Tuple[str, int], typing.Tuple[str, int, int]] +``` + + + + + + +Build mapping from cb_wait locations to DMA barrier locations. + +Only maps consumers waiting for DMA reads (data flowing into CB). +cb_wait ops waiting for compute output (where DMA is a write) are not mapped. + +**Returns:** `Dict[Tuple[str, int], Tuple[str, int, int]]` + +Dict mapping (kernel, line) of cb_wait -> (barrier_kernel, barrier_line, cb_index) + + + + + + + + +```python +ttl._src.auto_profile.build_dma_producer_to_cb_map( + cb_flow: typing.Optional[typing.Dict] +) -> typing.Dict[typing.Tuple[str, int], int] +``` + + + + + + +Build mapping from DMA barrier locations to CB index. + +**Returns:** `Dict[Tuple[str, int], int]` + +Dict mapping (kernel, line) of DMA read barrier -> cb_index + + + + + + + + +```python +ttl._src.auto_profile.generate_signpost_name( + operation: str, + lineno: int, + col: int +) -> typing.Tuple[str, str] +``` + + + + + + +Generate before/after signpost names for an operation. + +**Returns:** `Tuple[str, str]` + +Tuple of (before_name, after_name) + + + + + + + + +```python +ttl._src.auto_profile.get_line_mapper() -> ttl._src.auto_profile.SourceLineMapper +``` + + + + + + +Get the global line mapper instance. + + + + + + + + +```python +ttl._src.auto_profile.is_auto_profile_enabled() -> bool +``` + + + + + + +Check if auto-profiling is enabled via environment variable. + + + + + + + + +```python +ttl._src.auto_profile.load_cb_flow_graph( + csv_path: pathlib.Path +) -> typing.Optional[typing.Dict] +``` + + + + + + +Load CB flow graph JSON from same directory as CSV. + + + + + + + + +```python +ttl._src.auto_profile.parse_device_profile_csv( + csv_path: pathlib.Path, + line_mapper: ttl._src.auto_profile.SourceLineMapper +) -> typing.List[ttl._src.auto_profile.ProfileResult] +``` + + + + + + +Parse the device profile CSV and extract signpost timing data. + +**Parameters:** + + +Path to profile_log_device.csv + + + +Mapper to correlate signposts to source lines + + +**Returns:** `List[ProfileResult]` + +List of ProfileResult objects sorted by line number + + + + + + + + +```python +ttl._src.auto_profile.parse_signpost_name( + signpost: str +) -> typing.Tuple[typing.Optional[str], bool] +``` + + + + + + +Parse op name and implicit flag from signpost name. + +Returns (op_name, is_implicit) where op_name is None for line-only signposts. +Examples: + "line_52_before" -> (None, False) + "line_52_cb_wait_before" -> ("cb_wait", False) + "line_52_implicit_cb_pop_before" -> ("cb_pop", True) + + + + + + + + +```python +ttl._src.auto_profile.print_profile_report( + results: typing.List[ttl._src.auto_profile.ProfileResult], + all_source_lines: typing.Dict[str, typing.List[str]], + thread_to_kernel: typing.Dict[str, str], + line_mapper: typing.Optional[ttl._src.auto_profile.SourceLineMapper] = None, + cb_wait_to_dma: typing.Optional[typing.Dict[typing.Tuple[str, int], typing.Tuple[str, int, int]]] = None, + dma_producer_to_cb: typing.Optional[typing.Dict[typing.Tuple[str, int], int]] = None, + kernel_line_offsets: typing.Optional[typing.Dict[str, int]] = None +) +``` + + + + + + +Print a profile report organized by thread. + +Shows full source context with cycle annotations where available. +Each thread displays its corresponding kernel's source code. + +**Parameters:** + + +List of ProfileResult from CSV parsing + + + +Dict mapping kernel name to source lines + + + +Dict mapping RISC thread name to kernel name + + + +Optional SourceLineMapper with line offset info + + + +Optional mapping from (kernel, line) -> (dma_kernel, dma_line, cb_index) + + + +Optional mapping from (kernel, line) -> cb_index for DMA producers + + + +Optional mapping from kernel name to line offset + + + + + + + + + +```python +ttl._src.auto_profile._global_line_mapper = SourceLineMapper() +``` + + + + diff --git a/fern/static/ttl-docs/ttl/ttl/_src/tensor_registry.mdx b/fern/static/ttl-docs/ttl/ttl/_src/tensor_registry.mdx new file mode 100644 index 0000000..0c2cdf6 --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl/_src/tensor_registry.mdx @@ -0,0 +1,169 @@ +--- +layout: overview +slug: ttl/ttl/_src/tensor_registry +title: ttl._src.tensor_registry +--- + +Registry for tensor global names, used to track tensor parameter names. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`get_tensor_global_index`](#ttl-_src-tensor_registry-get_tensor_global_index) | Get the global index for a tensor. | +| [`get_tensor_global_name`](#ttl-_src-tensor_registry-get_tensor_global_name) | Get the global name for a tensor, checking registry first then attribute. | +| [`get_tensor_source`](#ttl-_src-tensor_registry-get_tensor_source) | Get the source location where a tensor was assigned, if tracked. | +| [`register_tensor_name`](#ttl-_src-tensor_registry-register_tensor_name) | Register a global name and index for a tensor. | +| [`register_tensor_source`](#ttl-_src-tensor_registry-register_tensor_source) | Register the source location where a tensor variable was assigned. | + +### Data + +[`_tensor_index_registry`](#ttl-_src-tensor_registry-_tensor_index_registry) + +[`_tensor_name_registry`](#ttl-_src-tensor_registry-_tensor_name_registry) + +[`_tensor_source_registry`](#ttl-_src-tensor_registry-_tensor_source_registry) + +### API + + + + + +```python +ttl._src.tensor_registry.get_tensor_global_index( + tensor +) -> int +``` + + + + + + +Get the global index for a tensor. + + + + + + + + +```python +ttl._src.tensor_registry.get_tensor_global_name( + tensor +) -> str +``` + + + + + + +Get the global name for a tensor, checking registry first then attribute. + + + + + + + + +```python +ttl._src.tensor_registry.get_tensor_source( + tensor +) -> typing.Optional[typing.Tuple[str, int]] +``` + + + + + + +Get the source location where a tensor was assigned, if tracked. + + + + + + + + +```python +ttl._src.tensor_registry.register_tensor_name( + tensor, + name: str, + index: int = -1 +) -> None +``` + + + + + + +Register a global name and index for a tensor. + + + + + + + + +```python +ttl._src.tensor_registry.register_tensor_source( + tensor, + source_file: str, + line: int +) -> None +``` + + + + + + +Register the source location where a tensor variable was assigned. + + + + + + + + +```python +ttl._src.tensor_registry._tensor_index_registry: Dict[int, int] = {} +``` + + + + + + + + + +```python +ttl._src.tensor_registry._tensor_name_registry: Dict[int, str] = {} +``` + + + + + + + + + +```python +ttl._src.tensor_registry._tensor_source_registry: Dict[int, Tuple[str, int]] = {} +``` + + + + diff --git a/fern/static/ttl-docs/ttl/ttl/_src/ttl_ast.mdx b/fern/static/ttl-docs/ttl/ttl/_src/ttl_ast.mdx new file mode 100644 index 0000000..cc33586 --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl/_src/ttl_ast.mdx @@ -0,0 +1,731 @@ +--- +layout: overview +slug: ttl/ttl/_src/ttl_ast +title: ttl._src.ttl_ast +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`CompilerContext`](#ttl-_src-ttl_ast-CompilerContext) | Immutable compilation context for TTL kernels. | +| [`TTLGenericCompiler`](#ttl-_src-ttl_ast-TTLGenericCompiler) | Compiler that generates TTL dialect ops from Python AST. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_build_tensor_type`](#ttl-_src-ttl_ast-_build_tensor_type) | Build MLIR tensor type for a ttnn tensor with TTNNLayoutAttr. | +| [`_get_annotation_name`](#ttl-_src-ttl_ast-_get_annotation_name) | Extract the type name from an annotation node. | +| [`_make_file_loc`](#ttl-_src-ttl_ast-_make_file_loc) | Create an MLIR file location from an AST node. | +| [`_raise_tensor_error`](#ttl-_src-ttl_ast-_raise_tensor_error) | Raise TTLangCompileError with tensor source location if available. | +| [`syntax`](#ttl-_src-ttl_ast-syntax) | - | + +### API + + + + + +```python +class ttl._src.ttl_ast.CompilerContext( + grid: typing.List[int], + memory_space: str, + tiled: bool +) +``` + + + + + + +Dataclass + +Immutable compilation context for TTL kernels. + + + + + + + + + + + + + + + + +```python +class ttl._src.ttl_ast.TTLGenericCompiler( + name, + kernel_type = None, + captures = {}, + args = (), + kwargs = {} +) +``` + + + + + + +**Bases:** `TTCompilerBase` + +Compiler that generates TTL dialect ops from Python AST. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._build_index_or_range( + node +) +``` + + + + + + +Convert AST node to (start_value, is_range) tuple. + +For slice syntax (start:end), returns (start_value, True). +For index syntax (value), returns (value, False). + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._close_final_signpost() +``` + + + + + + +Close the final signpost at the end of function body. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._emit_cb_from_capture( + cb +) +``` + + + + + + +Emit ttl.bind_cb for a captured CircularBuffer instance. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._emit_entry( + node +) +``` + + + + + + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._emit_line_signpost_if_needed( + node +) +``` + + + + + + +Emit signposts at line boundaries for auto-profiling. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._emit_op_signposts( + op_name: str, + node, + op_fn, + implicit = False +) +``` + + + + + + +Emit signposts for CB operations with op name included. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._emit_signpost( + name: str +) +``` + + + + + + +Emit a signpost operation into the MLIR. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._get_cb_tensor_type( + cb_val, + node = None +) +``` + + + + + + +Extract the tensor type from a TTL CB type. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._is_ttl_math_access( + node +) +``` + + + + + + +Check if node is ttl.math.XXX access pattern. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._is_ttl_module_access( + node +) +``` + + + + + + +Check if node is ttl.XXX access pattern. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._loc_for_node( + node +) +``` + + + + + + +Return file location for node if debug_locations enabled, else name location. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._raise_error( + node, + message: str +) +``` + + + + + + +Raise a TTLangCompileError with source location from AST node. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._resolve_ttl_function( + node, + func_args, + kwargs +) +``` + + + + + + +Resolve and call a ttl.XXX or ttl.math.XXX function. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._to_index_value( + node +) +``` + + + + + + +Convert AST node to MLIR index Value. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._try_emit_auto_signposts( + node, + visit_fn +) +``` + + + + + + +Emit line-based signposts if auto-profiling is enabled. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_Assign( + node +) +``` + + + + + + +Handle tuple unpacking for TTL functions like core(dims=2). + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_AsyncFunctionDef( + node +) +``` + + + + + + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_Attribute( + node, + func_args = [], + kwargs = {} +) +``` + + + + + + +Override to set location context and catch errors for method calls. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_BinOp( + node +) +``` + + + + + + +Override to inject auto-profiling and provide better error messages. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_Call( + node +) +``` + + + + + + +Override to set location context, catch errors, and inject auto-profiling. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_Constant( + node +) +``` + + + + + + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_FunctionDef( + node +) +``` + + + + + + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_List( + node +) +``` + + + + + + +Parse a list of constants. Returns a Python list, not MLIR values. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_Name( + node +) +``` + + + + + + +Override to check function globals for simple constants. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_Subscript( + node +) +``` + + + + + + +Handle tensor[row, col] or tensor[r0:r1, c0:c1] indexing. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_With( + node +) +``` + + + + + + +Handle 'with' for CircularBuffer acquire/release. + +Acquire ops (wait/reserve) are generated left-to-right. +Release ops (pop/push) are generated in reverse order at scope end. + + + + + + + + + +```python +ttl._src.ttl_ast._build_tensor_type( + ctx, + tensor, + grid, + tiled, + memory_space +) +``` + + + + + + +Build MLIR tensor type for a ttnn tensor with TTNNLayoutAttr. + + + + + + + + +```python +ttl._src.ttl_ast._get_annotation_name( + annotation +) +``` + + + + + + +Extract the type name from an annotation node. + +Handles both simple names (CircularBuffer) and qualified names (ttl.CircularBuffer). +Returns the simple type name (e.g., 'CircularBuffer') in both cases. + + + + + + + + +```python +ttl._src.ttl_ast._make_file_loc( + ctx, + source_file: str, + node, + line_offset: int = 0 +) -> Location +``` + + + + + + +Create an MLIR file location from an AST node. + + + + + + + + +```python +ttl._src.ttl_ast._raise_tensor_error( + tensor, + message: str +) +``` + + + + + + +Raise TTLangCompileError with tensor source location if available. + + + + + + + + +```python +ttl._src.ttl_ast.syntax( + syntax_name +) +``` + + + + + + + + diff --git a/fern/static/ttl-docs/ttl/ttl/circular_buffer.mdx b/fern/static/ttl-docs/ttl/ttl/circular_buffer.mdx new file mode 100644 index 0000000..38c6508 --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl/circular_buffer.mdx @@ -0,0 +1,283 @@ +--- +layout: overview +slug: ttl/ttl/circular_buffer +title: ttl.circular_buffer +--- + +Circular buffer operations for inter-thread communication. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`CircularBuffer`](#ttl-circular_buffer-CircularBuffer) | Circular buffer for inter-thread communication. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_cb_tensor_type`](#ttl-circular_buffer-_get_cb_tensor_type) | Extract the tensor type from a TTL CB type. | +| [`_next_cb_index`](#ttl-circular_buffer-_next_cb_index) | Get next CB index and increment counter. | +| [`_reset_cb_counter`](#ttl-circular_buffer-_reset_cb_counter) | Reset the CB index counter. Called at kernel start. | +| [`get_cb_count`](#ttl-circular_buffer-get_cb_count) | Return number of CBs allocated so far. | +| [`make_circular_buffer_like`](#ttl-circular_buffer-make_circular_buffer_like) | Create a circular buffer with properties derived from a tensor. | + +### Data + +[`_cb_index_counter`](#ttl-circular_buffer-_cb_index_counter) + +### API + + + + + +```python +class ttl.circular_buffer.CircularBuffer( + tensor: typing.Any, + shape: typing.Tuple[int, int], + buffer_factor: int +) +``` + + + + + + +Circular buffer for inter-thread communication. + +Circular buffers provide producer-consumer synchronization between +compute and data movement threads. + +Can be instantiated via make_circular_buffer_like() in kernel body, +then captured by thread closures. Methods generate TTL ops during compilation. + + + + + + + + +```python +ttl.circular_buffer.CircularBuffer.pop( + ast_self: ttl.circular_buffer.CircularBuffer +) -> None +``` + + + + + + +Signal that data has been consumed (consumer release). + +Use in consumer threads after wait() to signal that data has been +consumed and space is available for producers. + + + + + + + +```python +ttl.circular_buffer.CircularBuffer.push( + ast_self: ttl.circular_buffer.CircularBuffer +) -> None +``` + + + + + + +Signal that data is ready in the circular buffer (producer release). + +Use in producer threads after reserve() to signal that data has been +written and is ready for consumers. + + + + + + + +```python +ttl.circular_buffer.CircularBuffer.reserve( + ast_self: ttl.circular_buffer.CircularBuffer +) -> ttl.ttl_api.TensorBlock +``` + + + + + + +Reserve space in the circular buffer (producer acquire). + +Use in producer threads to acquire space for writing. Must be followed +by push() to signal data is ready. + +**Returns:** `TensorBlock` + +The reserved space with CB association. + + + + + + + +```python +ttl.circular_buffer.CircularBuffer.wait( + ast_self: ttl.circular_buffer.CircularBuffer +) -> ttl.ttl_api.TensorBlock +``` + + + + + + +Wait for data from the circular buffer (consumer acquire). + +Use in consumer threads to acquire data. Must be followed by pop() +to signal consumption is complete. + +**Returns:** `TensorBlock` + +The acquired data with CB association. + + + + + + + + + +```python +ttl.circular_buffer._get_cb_tensor_type( + cb_val +) +``` + + + + + + +Extract the tensor type from a TTL CB type. + + + + + + + + +```python +ttl.circular_buffer._next_cb_index() +``` + + + + + + +Get next CB index and increment counter. + + + + + + + + +```python +ttl.circular_buffer._reset_cb_counter() +``` + + + + + + +Reset the CB index counter. Called at kernel start. + + + + + + + + +```python +ttl.circular_buffer.get_cb_count() +``` + + + + + + +Return number of CBs allocated so far. + + + + + + + + +```python +ttl.circular_buffer.make_circular_buffer_like( + tensor: typing.Any, + shape: typing.Tuple[int, int], + buffer_factor: int = 2 +) -> ttl.circular_buffer.CircularBuffer +``` + + + + + + +Create a circular buffer with properties derived from a tensor. + +**Parameters:** + + +Tensor that determines the CB's data type + + + +(rows, cols) in tiles for wait/reserve operations + + + +Capacity multiplier (default 2 for double-buffering) + + +**Returns:** `CircularBuffer` + +CircularBuffer for use in thread function closures + + + + + + + + +```python +ttl.circular_buffer._cb_index_counter = 0 +``` + + + + diff --git a/fern/static/ttl-docs/ttl/ttl/constants.mdx b/fern/static/ttl-docs/ttl/ttl/constants.mdx new file mode 100644 index 0000000..5aa63c2 --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl/constants.mdx @@ -0,0 +1,41 @@ +--- +layout: overview +slug: ttl/ttl/constants +title: ttl.constants +--- + +Constants used throughout the DSL. + +## Module Contents + +### Data + +[`DEFAULT_TILE_SIZE`](#ttl-constants-DEFAULT_TILE_SIZE) + +[`SUPPORTED_MEMORY_SPACES`](#ttl-constants-SUPPORTED_MEMORY_SPACES) + +### API + + + + + +```python +ttl.constants.DEFAULT_TILE_SIZE = 32 +``` + + + + + + + + + +```python +ttl.constants.SUPPORTED_MEMORY_SPACES = frozenset(['L1', 'DRAM']) +``` + + + + diff --git a/fern/static/ttl-docs/ttl/ttl/diagnostics.mdx b/fern/static/ttl-docs/ttl/ttl/diagnostics.mdx new file mode 100644 index 0000000..709211b --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl/diagnostics.mdx @@ -0,0 +1,466 @@ +--- +layout: overview +slug: ttl/ttl/diagnostics +title: ttl.diagnostics +--- + +Diagnostic utilities for formatting compiler errors with source context. + +This module provides Rust/Swift-style error formatting that displays +source code snippets with ASCII arrows pointing to the error location. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`SourceDiagnostic`](#ttl-diagnostics-SourceDiagnostic) | Format errors with source context and ASCII arrows. | +| [`TTLangCompileError`](#ttl-diagnostics-TTLangCompileError) | Exception for tt-lang compilation errors with source context. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_extract_core_message`](#ttl-diagnostics-_extract_core_message) | Extract the core error message from MLIR diagnostic output. | +| [`_extract_note`](#ttl-diagnostics-_extract_note) | Extract any note from the MLIR error message. | +| [`_read_file_lines`](#ttl-diagnostics-_read_file_lines) | Read source lines from a file if it exists. | +| [`_verbose_errors_enabled`](#ttl-diagnostics-_verbose_errors_enabled) | Check if verbose MLIR error output is enabled. | +| [`extract_location_from_mlir_error`](#ttl-diagnostics-extract_location_from_mlir_error) | Extract source location from an MLIR error message. | +| [`find_variable_assignment`](#ttl-diagnostics-find_variable_assignment) | Find the line where a variable was assigned, searching backwards. | +| [`format_mlir_error`](#ttl-diagnostics-format_mlir_error) | Format an MLIR error with source context if location is available. | +| [`format_python_error`](#ttl-diagnostics-format_python_error) | Format a Python error with source context. | +| [`parse_mlir_location`](#ttl-diagnostics-parse_mlir_location) | Parse an MLIR location string to extract file, line, and column. | + +### API + + + + + +```python +class ttl.diagnostics.SourceDiagnostic( + source_lines: typing.List[str], + filename: str +) +``` + + + + + + +Format errors with source context and ASCII arrows. + +Produces error messages in the style of modern compilers (Rust, Swift): + + error: type mismatch in add operation + --> kernel.py:43:16 + | + 43 | result = l + r + | ^^^ expected bf16, got f32 + | + + + + + + +```python +ttl.diagnostics.SourceDiagnostic.format_error( + line: int, + col: int, + message: str, + label: str = 'error', + span_length: int = 1, + note: typing.Optional[str] = None +) -> str +``` + + + + + + +Format an error with source context. + +**Parameters:** + + +1-based line number + + + +1-based column number + + + +Main error message + + + +Error label (e.g., "error", "warning") + + + +Length of the underline (^^^) + + + +Optional additional note + + +**Returns:** `str` + +Formatted error string with source context + + + + + + + +```python +ttl.diagnostics.SourceDiagnostic.format_error_chain( + errors: typing.List[typing.Tuple[int, int, str, typing.Optional[str]]] +) -> str +``` + + + + + + +Format multiple related errors. + +**Parameters:** + + +List of (line, col, message, note) tuples + + +**Returns:** `str` + +Formatted error chain + + + + + + + + + +```python +class ttl.diagnostics.TTLangCompileError( + message: str, + source_file: typing.Optional[str] = None, + line: typing.Optional[int] = None, + col: typing.Optional[int] = None, + source_lines: typing.Optional[typing.List[str]] = None +) +``` + + + + + + +Exception + +**Bases:** `Exception` + +Exception for tt-lang compilation errors with source context. + +This exception carries enough information to produce pretty error messages +pointing to the exact source location where the error occurred. + + + + + + +```python +ttl.diagnostics.TTLangCompileError.format() -> str +``` + + + + + + +Format error with source context if available. + + + + + + + + + +```python +ttl.diagnostics._extract_core_message( + error_msg: str +) -> str +``` + + + + + + +Extract the core error message from MLIR diagnostic output. + +This extracts: "expects transfer handle to be synchronized with ttl.wait" + + + + + + + + +```python +ttl.diagnostics._extract_note( + error_msg: str +) -> typing.Optional[str] +``` + + + + + + +Extract any note from the MLIR error message. + + + + + + + + +```python +ttl.diagnostics._read_file_lines( + filepath: str +) -> typing.Optional[typing.List[str]] +``` + + + + + + +Read source lines from a file if it exists. + + + + + + + + +```python +ttl.diagnostics._verbose_errors_enabled() -> bool +``` + + + + + + +Check if verbose MLIR error output is enabled. + + + + + + + + +```python +ttl.diagnostics.extract_location_from_mlir_error( + error_msg: str +) -> typing.Optional[typing.Tuple[str, int, int]] +``` + + + + + + +Extract source location from an MLIR error message. + +**Parameters:** + + +Full MLIR error message + + +**Returns:** `Optional[Tuple[str, int, int]]` + +Tuple of (filename, line, col) or None if no location found + + + + + + + + +```python +ttl.diagnostics.find_variable_assignment( + source_lines: typing.List[str], + var_name: str, + before_line: int +) -> typing.Optional[int] +``` + + + + + + +Find the line where a variable was assigned, searching backwards. + +**Parameters:** + + +List of source lines (0-indexed) + + + +Variable name to search for + + + +Search backwards from this 1-based line number + + +**Returns:** `Optional[int]` + +1-based line number where assignment was found, or None + + + + + + + + +```python +ttl.diagnostics.format_mlir_error( + error_msg: str, + source_lines: typing.Optional[typing.List[str]] = None, + source_file: typing.Optional[str] = None +) -> str +``` + + + + + + +Format an MLIR error with source context if location is available. + +**Parameters:** + + +The MLIR error message + + + +Original Python source lines (optional, will read from file if needed) + + + +Source filename (optional, extracted from error if not provided) + + +**Returns:** `str` + +Formatted error message, with source context if available + + + + + + + + +```python +ttl.diagnostics.format_python_error( + error: Exception, + source_file: str, + line: int, + source_lines: typing.Optional[typing.List[str]] = None +) -> str +``` + + + + + + +Format a Python error with source context. + +**Parameters:** + + +The Python exception + + + +Source file path + + + +Line number in source file + + + +Source lines (will read from file if not provided) + + +**Returns:** `str` + +Formatted error message with source context + + + + + + + + +```python +ttl.diagnostics.parse_mlir_location( + loc_str: str +) -> typing.Optional[typing.Tuple[str, int, int]] +``` + + + + + + +Parse an MLIR location string to extract file, line, and column. + +MLIR locations can appear in several formats: +- loc("filename":line:col) +- loc("filename":line:col to :line:col) +- loc(#loc1) with #loc1 = loc("filename":line:col) + +**Parameters:** + + +MLIR location string + + +**Returns:** `Optional[Tuple[str, int, int]]` + +Tuple of (filename, line, col) or None if not parseable + + + diff --git a/fern/static/ttl-docs/ttl/ttl/dialects.mdx b/fern/static/ttl-docs/ttl/ttl/dialects.mdx new file mode 100644 index 0000000..2eb4b50 --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl/dialects.mdx @@ -0,0 +1,12 @@ +--- +layout: overview +slug: ttl/ttl/dialects +title: ttl.dialects +--- + +TTLang dialect modules. + +## Submodules + +- **[`ttl.dialects._ods_common`](/ttl/ttl/dialects/_ods_common)** +- **[`ttl.dialects.ttl`](/ttl/ttl/dialects/ttl)** diff --git a/fern/static/ttl-docs/ttl/ttl/dialects/_ods_common.mdx b/fern/static/ttl-docs/ttl/ttl/dialects/_ods_common.mdx new file mode 100644 index 0000000..7b2c71e --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl/dialects/_ods_common.mdx @@ -0,0 +1,39 @@ +--- +layout: overview +slug: ttl/ttl/dialects/_ods_common +title: ttl.dialects._ods_common +--- + +## Module Contents + +### Data + +[`__all__`](#ttl-dialects-_ods_common-__all__) + +[`_cext`](#ttl-dialects-_ods_common-_cext) + +### API + + + + + +```python +ttl.dialects._ods_common.__all__ = ['_cext'] +``` + + + + + + + + + +```python +ttl.dialects._ods_common._cext = _upstream._cext +``` + + + + diff --git a/fern/static/ttl-docs/ttl/ttl/dialects/ttl.mdx b/fern/static/ttl-docs/ttl/ttl/dialects/ttl.mdx new file mode 100644 index 0000000..d8c69a1 --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl/dialects/ttl.mdx @@ -0,0 +1,81 @@ +--- +layout: overview +slug: ttl/ttl/dialects/ttl +title: ttl.dialects.ttl +--- + +TTL (TT-Lang) dialect Python bindings. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`ensure_dialects_registered`](#ttl-dialects-ttl-ensure_dialects_registered) | Ensure TTL dialect is registered with the given MLIR context. | + +### Data + +[`CircularBufferType`](#ttl-dialects-ttl-CircularBufferType) + +[`SliceAttr`](#ttl-dialects-ttl-SliceAttr) + +[`__all__`](#ttl-dialects-ttl-__all__) + +### API + + + + + +```python +ttl.dialects.ttl.ensure_dialects_registered( + ctx +) +``` + + + + + + +Ensure TTL dialect is registered with the given MLIR context. + + + + + + + + +```python +ttl.dialects.ttl.CircularBufferType = ir.CircularBufferType +``` + + + + + + + + + +```python +ttl.dialects.ttl.SliceAttr = ir.SliceAttr +``` + + + + + + + + + +```python +ttl.dialects.ttl.__all__ = [*[name for name in (globals().keys()) if not name.startswith('_')]] +``` + + + + diff --git a/fern/static/ttl-docs/ttl/ttl/dtype_utils.mdx b/fern/static/ttl-docs/ttl/ttl/dtype_utils.mdx new file mode 100644 index 0000000..fce940c --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl/dtype_utils.mdx @@ -0,0 +1,212 @@ +--- +layout: overview +slug: ttl/ttl/dtype_utils +title: ttl.dtype_utils +--- + +Data type conversion utilities between PyTorch, TTNN, and MLIR types. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`is_ttnn_tensor`](#ttl-dtype_utils-is_ttnn_tensor) | Check if tensor is a ttnn.Tensor. | +| [`tensor_dtype_to_ttcore_datatype`](#ttl-dtype_utils-tensor_dtype_to_ttcore_datatype) | Convert tensor dtype to ttcore.DataType, supporting both torch and ttnn dtypes. | +| [`tile_bytes_from_dtype`](#ttl-dtype_utils-tile_bytes_from_dtype) | Calculate tile size in bytes from ttnn dtype. | +| [`torch_dtype_to_ttcore_datatype`](#ttl-dtype_utils-torch_dtype_to_ttcore_datatype) | Convert PyTorch dtype to ttcore.DataType enum. | +| [`torch_dtype_to_ttnn_datatype`](#ttl-dtype_utils-torch_dtype_to_ttnn_datatype) | Convert PyTorch dtype to ttnn.DataType enum. | +| [`ttnn_dtype_to_ttcore_datatype`](#ttl-dtype_utils-ttnn_dtype_to_ttcore_datatype) | Convert ttnn.DataType to ttcore.DataType enum. | + +### API + + + + + +```python +ttl.dtype_utils.is_ttnn_tensor( + tensor +) -> bool +``` + + + + + + +Check if tensor is a ttnn.Tensor. + + + + + + + + +```python +ttl.dtype_utils.tensor_dtype_to_ttcore_datatype( + dtype +) +``` + + + + + + +Convert tensor dtype to ttcore.DataType, supporting both torch and ttnn dtypes. + +**Parameters:** + + +Either torch dtype or ttnn.DataType + + +**Returns:** + +ttcore.DataType enum value + + + + + + + + +```python +ttl.dtype_utils.tile_bytes_from_dtype( + dtype +) -> int +``` + + + + + + +Calculate tile size in bytes from ttnn dtype. + +For tiled tensors, each tile is 32x32 elements. The byte size depends on +the data type's element size plus any format-specific overhead. + +**Parameters:** + + +ttnn.DataType enum value + + +**Returns:** `int` + +Tile size in bytes + +**Raises:** + +- `ValueError`: If dtype is not supported + + + + + + + + +```python +ttl.dtype_utils.torch_dtype_to_ttcore_datatype( + torch_dtype +) +``` + + + + + + +Convert PyTorch dtype to ttcore.DataType enum. + +**Parameters:** + + +PyTorch dtype (torch.float32, torch.int32, etc.) + + +**Returns:** + +ttcore.DataType enum value + +**Raises:** + +- `ValueError`: If dtype is not supported + + + + + + + + +```python +ttl.dtype_utils.torch_dtype_to_ttnn_datatype( + torch_dtype +) +``` + + + + + + +Convert PyTorch dtype to ttnn.DataType enum. + +**Parameters:** + + +PyTorch dtype (torch.float32, torch.bfloat16, etc.) + + +**Returns:** + +ttnn.DataType enum value + +**Raises:** + +- `ImportError`: If ttnn is not available +- `ValueError`: If dtype is not supported + + + + + + + + +```python +ttl.dtype_utils.ttnn_dtype_to_ttcore_datatype( + ttnn_dtype +) +``` + + + + + + +Convert ttnn.DataType to ttcore.DataType enum. + +**Parameters:** + + +ttnn.DataType enum value + + +**Returns:** + +ttcore.DataType enum value + +**Raises:** + +- `ValueError`: If dtype is not supported + + + diff --git a/fern/static/ttl-docs/ttl/ttl/kernel_runner.mdx b/fern/static/ttl-docs/ttl/ttl/kernel_runner.mdx new file mode 100644 index 0000000..a87c92a --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl/kernel_runner.mdx @@ -0,0 +1,274 @@ +--- +layout: overview +slug: ttl/ttl/kernel_runner +title: ttl.kernel_runner +--- + +Shared kernel execution logic for tt-lang. + +Provides functions for building kernel descriptors, CB descriptors, and +executing kernels on device via ttnn.generic_op. Used by both the Python +DSL (CompiledTTNNKernel) and ME2E tests. + +This module provides a single reusable implementation of kernel argument +building and execution. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`KernelSpec`](#ttl-kernel_runner-KernelSpec) | Specification for a single kernel to execute. | + +### Functions + +| Name | Description | +|------|-------------| +| [`build_cb_descriptors`](#ttl-kernel_runner-build_cb_descriptors) | Build circular buffer descriptors for ttnn.generic_op. | +| [`build_kernel_descriptors`](#ttl-kernel_runner-build_kernel_descriptors) | Build kernel descriptors for ttnn.generic_op. | +| [`build_tensor_accessor_args`](#ttl-kernel_runner-build_tensor_accessor_args) | Build compile-time args for tensor accessors. | +| [`run_kernel_on_device`](#ttl-kernel_runner-run_kernel_on_device) | Execute kernels on device using ttnn.generic_op. | + +### Data + +[`__all__`](#ttl-kernel_runner-__all__) + +### API + + + + + +```python +class ttl.kernel_runner.KernelSpec( + path: str, + thread_type: str, + tensor_indices: typing.List[int], + config: typing.Any +) +``` + + + + + + +Dataclass + +Specification for a single kernel to execute. + + + + + + + + + + + + + + + + +```python +ttl.kernel_runner.build_cb_descriptors( + tensors: typing.List[typing.Any], + cb_configs: typing.List[typing.Any], + core_ranges: typing.Any +) -> typing.List[typing.Any] +``` + + + + + + +Build circular buffer descriptors for ttnn.generic_op. + +**Parameters:** + + +List of ttnn.Tensor objects. Each tensor's position (0, 1, 2, ...) +corresponds to its CB index. For intermediate CBs (not backed by +input/output tensors), pass None in the corresponding position. + + + +List of CircularBuffer objects for each CB, indexed by CB index. +Each CB has shape, buffer_factor, tensor (for dtype), and _cb_index attributes. + + + +ttnn.CoreRangeSet for CB allocation. + + +**Returns:** `List[Any]` + +List of ttnn.CBDescriptor objects. + + + + + + + + +```python +ttl.kernel_runner.build_kernel_descriptors( + kernel_specs: typing.List[ttl.kernel_runner.KernelSpec], + tensors: typing.List[typing.Any], + tensor_accessor_args: typing.List[int], + core_ranges: typing.Any, + grid_cols: int, + grid_rows: int, + num_cbs: int +) -> typing.List[typing.Any] +``` + + + + + + +Build kernel descriptors for ttnn.generic_op. + +**Parameters:** + + +List of kernel specifications. + + + +List of ttnn.Tensor objects. Position in this list determines +the global tensor index. Individual kernels access subsets via +tensor_indices in each KernelSpec. + + + +Flattened compile-time args from all tensors. + + + +ttnn.CoreRangeSet for kernel execution. + + + +Number of grid columns (x dimension). + + + +Number of grid rows (y dimension). + + + +Total number of circular buffers (including intermediate CBs). + + +**Returns:** `List[Any]` + +List of ttnn.KernelDescriptor objects. + + + + + + + + +```python +ttl.kernel_runner.build_tensor_accessor_args( + tensors: typing.List[typing.Any] +) -> typing.List[int] +``` + + + + + + +Build compile-time args for tensor accessors. + +**Parameters:** + + +List of ttnn.Tensor objects on device. + + +**Returns:** `List[int]` + +List of compile-time args (flattened TensorAccessorArgs for all tensors). + + + + + + + + +```python +ttl.kernel_runner.run_kernel_on_device( + kernel_specs: typing.List[ttl.kernel_runner.KernelSpec], + tensors: typing.List[typing.Any], + cb_configs: typing.List[typing.Any], + core_ranges: typing.Any, + program_hash: int = None +) -> typing.Any +``` + + + + + + +Execute kernels on device using ttnn.generic_op. + +This is the main entry point for kernel execution. It builds all +descriptors and runs the program. + +**Parameters:** + + +List of kernel specifications (path, thread_type, tensor_indices, config). + + + +List of ttnn.Tensor objects. Position in this list determines the +global tensor index. Individual kernels access subsets via tensor_indices +in each KernelSpec. + + + +List of CircularBuffer objects for each CB, indexed by CB index. +Includes both tensor-backed CBs and intermediate CBs. Each CB has shape, +buffer_factor, tensor (for dtype), and _cb_index attributes. + + + +ttnn.CoreRangeSet for kernel execution. + + + +Hash for tt-metal program cache (not yet used). + + +**Returns:** `Any` + +Result from ttnn.generic_op (typically None or output tensor). + + + + + + + + +```python +ttl.kernel_runner.__all__ = ['KernelSpec', 'build_tensor_accessor_args', 'build_kernel_descriptors', 'build_... +``` + + + + diff --git a/fern/static/ttl-docs/ttl/ttl/layouts.mdx b/fern/static/ttl-docs/ttl/ttl/layouts.mdx new file mode 100644 index 0000000..41784be --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl/layouts.mdx @@ -0,0 +1,126 @@ +--- +layout: overview +slug: ttl/ttl/layouts +title: ttl.layouts +--- + +Layout creation utilities for tensor distribution across cores. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`TTNNLayoutConfig`](#ttl-layouts-TTNNLayoutConfig) | Configuration for TTNN layout creation. Supports L1/DRAM interleaved tiled layouts. | + +### Functions + +| Name | Description | +|------|-------------| +| [`create_ttnn_layout`](#ttl-layouts-create_ttnn_layout) | Create a TTNNLayoutAttr for L1 interleaved tiled tensors. | + +### Data + +[`_TTNN_BUFFER_TYPE_L1`](#ttl-layouts-_TTNN_BUFFER_TYPE_L1) + +[`_TTNN_TENSOR_MEMORY_LAYOUT_INTERLEAVED`](#ttl-layouts-_TTNN_TENSOR_MEMORY_LAYOUT_INTERLEAVED) + +### API + + + + + +```python +class ttl.layouts.TTNNLayoutConfig( + logical_shape: typing.List[int], + grid: typing.List[int], + dtype: str +) +``` + + + + + + +Dataclass + +Configuration for TTNN layout creation. Supports L1/DRAM interleaved tiled layouts. + + + + + + + + + + + + + + + + +```python +ttl.layouts.create_ttnn_layout( + ctx, + config: ttl.layouts.TTNNLayoutConfig +) +``` + + + + + + +Create a TTNNLayoutAttr for L1 interleaved tiled tensors. + +Supports: L1/DRAM memory, Interleaved layout, tiled (32x32 tiles). + +**Parameters:** + + +MLIR context + + + +Configuration with logical_shape, grid, and dtype + + +**Returns:** + +TTNNLayoutAttr + +**Raises:** + +- `ValueError`: If configuration is unsupported + + + + + + + + +```python +ttl.layouts._TTNN_BUFFER_TYPE_L1 = 1 +``` + + + + + + + + + +```python +ttl.layouts._TTNN_TENSOR_MEMORY_LAYOUT_INTERLEAVED = 0 +``` + + + + diff --git a/fern/static/ttl-docs/ttl/ttl/operators.mdx b/fern/static/ttl-docs/ttl/ttl/operators.mdx new file mode 100644 index 0000000..14fabc9 --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl/operators.mdx @@ -0,0 +1,642 @@ +--- +layout: overview +slug: ttl/ttl/operators +title: ttl.operators +--- + +DSL operators for tensor operations and data movement. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`CopyTransferHandler`](#ttl-operators-CopyTransferHandler) | Transfer handle for asynchronous copy operations. | +| [`TensorBlock`](#ttl-operators-TensorBlock) | Represents a block of tensor data in the TTL dialect. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_cb_from_block`](#ttl-operators-_get_cb_from_block) | Extract the CB from a block (result of ttl.attach_cb). | +| [`_get_cb_shape`](#ttl-operators-_get_cb_shape) | Extract the block shape from a CB value. | +| [`_get_constant_int`](#ttl-operators-_get_constant_int) | Extract Python int from MLIR arith.ConstantOp or return as-is if already int. | +| [`_get_current_grid`](#ttl-operators-_get_current_grid) | Get the current grid dimensions. | +| [`_is_block`](#ttl-operators-_is_block) | Check if a value is a block (result of cb.reserve() or cb.wait()). | +| [`_make_tensor_slice`](#ttl-operators-_make_tensor_slice) | Create a ttl.tensor_slice from a tensor, tile indices, and shape. | +| [`_process_tensor_subscript`](#ttl-operators-_process_tensor_subscript) | Process tensor subscript and create tensor slice. | +| [`_set_current_grid`](#ttl-operators-_set_current_grid) | Set the current grid dimensions. Called before compiling threads. | +| [`broadcast`](#ttl-operators-broadcast) | Broadcast over specified dimensions. | +| [`copy`](#ttl-operators-copy) | Initiate an asynchronous data transfer using ttl.copy. | +| [`core`](#ttl-operators-core) | Get the coordinates of the current core. | +| [`grid_size`](#ttl-operators-grid_size) | Get the size of the grid. | +| [`signpost`](#ttl-operators-signpost) | Emit a profiling marker visible in Tracy. | + +### Data + +[`CoreCoordinate`](#ttl-operators-CoreCoordinate) + +[`IndexedTensor`](#ttl-operators-IndexedTensor) + +[`__all__`](#ttl-operators-__all__) + +[`_current_grid`](#ttl-operators-_current_grid) + +### API + + + + + +```python +class ttl.operators.CopyTransferHandler() +``` + + + + + + +Transfer handle for asynchronous copy operations. + +CopyTransferHandler objects are returned by copy() calls and must be +explicitly waited on to ensure transfer completion. + + + + + + +```python +ttl.operators.CopyTransferHandler.wait( + ast_self: ttl.operators.CopyTransferHandler +) +``` + + + + + + +Block until the copy operation completes. + + + + + + + + + +```python +class ttl.operators.TensorBlock( + shape, + dtype +) +``` + + + + + + +Represents a block of tensor data in the TTL dialect. + +TensorBlock supports arithmetic operations through operator +overloading. Operations generate TTL high-level ops that get lowered +to ttl.compute blocks. + + + + + + +```python +ttl.operators.TensorBlock.__add__( + ast_self: ttl.operators.TensorBlock, + rhs: ttl.operators.TensorBlock +) -> ttl.operators.TensorBlock +``` + + + + + + +Element-wise addition using ttl.add. + +**Parameters:** + + +Right operand tensor. Must have the same shape as self. + + +**Returns:** `TensorBlock` + +Result tensor with the same shape as inputs. + + + + + + + +```python +ttl.operators.TensorBlock.__matmul__( + ast_self: ttl.operators.TensorBlock, + rhs: ttl.operators.TensorBlock +) -> ttl.operators.TensorBlock +``` + + + + + + +Matrix multiplication is not yet supported in TTL mode. + + + + + + + +```python +ttl.operators.TensorBlock.__mul__( + ast_self: ttl.operators.TensorBlock, + rhs: ttl.operators.TensorBlock +) -> ttl.operators.TensorBlock +``` + + + + + + +Element-wise multiplication using ttl.mul. + + + + + + + +```python +ttl.operators.TensorBlock.__sub__( + ast_self: ttl.operators.TensorBlock, + rhs: ttl.operators.TensorBlock +) -> ttl.operators.TensorBlock +``` + + + + + + +Element-wise subtraction using ttl.sub. + + + + + + + +```python +ttl.operators.TensorBlock.store( + ast_self: ttl.operators.TensorBlock, + rhs: ttl.operators.TensorBlock +) -> None +``` + + + + + + +Store result tensor to CB by propagating CB association from output view. + + + + + + + + + +```python +ttl.operators._get_cb_from_block( + block +) +``` + + + + + + +Extract the CB from a block (result of ttl.attach_cb). + +The attach_cb op has signature: (tensor, cb) -> tensor +So the CB is operand[1]. + + + + + + + + +```python +ttl.operators._get_cb_shape( + cb_val +) +``` + + + + + + +Extract the block shape from a CB value. + + + + + + + + +```python +ttl.operators._get_constant_int( + val +) +``` + + + + + + +Extract Python int from MLIR arith.ConstantOp or return as-is if already int. + + + + + + + + +```python +ttl.operators._get_current_grid() -> typing.Tuple[int, int] +``` + + + + + + +Get the current grid dimensions. + + + + + + + + +```python +ttl.operators._is_block( + value +) -> bool +``` + + + + + + +Check if a value is a block (result of cb.reserve() or cb.wait()). + +A block is a tensor with an attached CB, produced by ttl.attach_cb. + + + + + + + + +```python +ttl.operators._make_tensor_slice( + tensor, + indices, + slice_shape +) +``` + + + + + + +Create a ttl.tensor_slice from a tensor, tile indices, and shape. + +**Parameters:** + + +The source tensor to slice from + + + +(row, col) tile indices for the slice start position + + + +(rows, cols) shape for the slice in tiles + + + + + + + + + +```python +ttl.operators._process_tensor_subscript( + subscript_tuple, + cb_shape +) +``` + + + + + + +Process tensor subscript and create tensor slice. + +**Parameters:** + + +(tensor, indices) where indices are [(value, is_range), ...] + + + +[rows, cols] shape from the CB + + +**Returns:** + +Tensor slice with shape matching cb_shape + + + + + + + + +```python +ttl.operators._set_current_grid( + grid: typing.Tuple[int, int] +) -> None +``` + + + + + + +Set the current grid dimensions. Called before compiling threads. + + + + + + + + +```python +ttl.operators.broadcast( + input: ttl.operators.TensorBlock, + output: ttl.operators.TensorBlock, + dims: typing.List[int] +) -> ttl.operators.TensorBlock +``` + + + + + + +Broadcast over specified dimensions. + +**Parameters:** + + +Input tensor (CB-attached) + + + +Output tensor (CB-attached, used for output CB tracking) + + + +Dimensions to broadcast over + + +**Returns:** `TensorBlock` + +Result tensor with broadcast values + + + + + + + + +```python +ttl.operators.copy( + src, + dst +) -> ttl.operators.CopyTransferHandler +``` + + + + + + +Initiate an asynchronous data transfer using ttl.copy. + +For multi-tile CBs (shape > 1x1), use range syntax: tensor[0:2, 0:2] +For single-tile CBs (shape 1x1), use index syntax: tensor[0, 0] + +**Parameters:** + + +Source tensor/slice (for reads) or block (for writes) + + + +Destination block (for reads) or tensor/slice (for writes) + + +**Returns:** `CopyTransferHandler` + +CopyTransferHandler handle that must be waited on for completion + + + + + + + + +```python +ttl.operators.core( + dims +) +``` + + + + + + +Get the coordinates of the current core. + +Currently only dims=2 is supported (temporary restriction). + +**Parameters:** + + +Number of dimensions to return (must be 2) + + +**Returns:** + +For dims=2: Tuple (x, y) where x is column coordinate and y is row coordinate + +**Raises:** + +- `ValueError`: If dims is not 2 + + + + + + + + +```python +ttl.operators.grid_size( + dims +) +``` + + + + + + +Get the size of the grid. + +Currently only dims=2 is supported (temporary restriction). + +**Parameters:** + + +Number of dimensions to return (must be 2) + + +**Returns:** + +For dims=2: Tuple (x_size, y_size) where x_size is columns and y_size is rows + +**Raises:** + +- `ValueError`: If dims is not 2 + + + + + + + + +```python +ttl.operators.signpost( + name: str +) +``` + + + + + + +Emit a profiling marker visible in Tracy. + +The marker creates a DeviceZoneScopedN in the generated C++ code, +which will appear in Tracy profiler traces when TT_METAL_DEVICE_PROFILER=1. + +**Parameters:** + + +Name for the profiling region (must be a string literal) + + + + + + + + + +```python +ttl.operators.CoreCoordinate = Tuple[int, int] +``` + + + + + + + + + +```python +ttl.operators.IndexedTensor = Union['TensorBlock', Tuple['TensorBlock', Tuple[int, ...]]] +``` + + + + + + + + + +```python +ttl.operators.__all__ = ['TensorBlock', 'CopyTransferHandler', 'copy', 'core', 'grid_size', 'signpost', ... +``` + + + + + + + + + +```python +ttl.operators._current_grid: Tuple[int, int] = (-1, -1) +``` + + + + diff --git a/fern/static/ttl-docs/ttl/ttl/ttl.mdx b/fern/static/ttl-docs/ttl/ttl/ttl.mdx new file mode 100644 index 0000000..80fe5c9 --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl/ttl.mdx @@ -0,0 +1,27 @@ +--- +layout: overview +slug: ttl/ttl/ttl +title: ttl.ttl +--- + +TTL DSL module providing the unified ttl.* API namespace. + +## Module Contents + +### Data + +[`__all__`](#ttl-ttl-__all__) + +### API + + + + + +```python +ttl.ttl.__all__ = ['kernel', 'compute', 'datamovement', 'Program', 'make_circular_buffer_like', 'c... +``` + + + + diff --git a/fern/static/ttl-docs/ttl/ttl/ttl_api.mdx b/fern/static/ttl-docs/ttl/ttl/ttl_api.mdx new file mode 100644 index 0000000..e890735 --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl/ttl_api.mdx @@ -0,0 +1,907 @@ +--- +layout: overview +slug: ttl/ttl/ttl_api +title: ttl.ttl_api +--- + +Main API for the TTL dialect Python DSL. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`CompiledTTNNKernel`](#ttl-ttl_api-CompiledTTNNKernel) | A compiled tt-lang kernel ready for execution via ttnn.generic_op. | +| [`Program`](#ttl-ttl_api-Program) | Immutable container for kernel threads and their arguments. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_clear_thread_registry`](#ttl-ttl_api-_clear_thread_registry) | Clear the thread registry before kernel execution. | +| [`_collect_captures`](#ttl-ttl_api-_collect_captures) | Collect and convert captured variables from function closure. | +| [`_collect_cb_configs`](#ttl-ttl_api-_collect_cb_configs) | Extract CircularBuffer objects from thread closures, indexed by cb_index. | +| [`_compile`](#ttl-ttl_api-_compile) | Internal decorator for compiling kernel threads. | +| [`_compile_kernel`](#ttl-ttl_api-_compile_kernel) | Compile kernel function to MLIR and return CompiledTTNNKernel. | +| [`_compile_ttnn_kernel`](#ttl-ttl_api-_compile_ttnn_kernel) | Compile kernel to CompiledTTNNKernel for execution via ttnn.generic_op. | +| [`_detect_memory_space_from_tensor`](#ttl-ttl_api-_detect_memory_space_from_tensor) | Detect memory space (L1/DRAM) from a ttnn tensor's buffer type. | +| [`_get_registered_threads`](#ttl-ttl_api-_get_registered_threads) | Get all registered threads and clear the registry. | +| [`_get_source_line_offset`](#ttl-ttl_api-_get_source_line_offset) | Get the line offset to convert parsed AST line numbers to actual file lines. | +| [`_get_tensor_cache_info`](#ttl-ttl_api-_get_tensor_cache_info) | Extract cache-relevant info from a tensor: (shape, dtype, memory_space, layout). | +| [`_has_float32_args`](#ttl-ttl_api-_has_float32_args) | Check if any input tensor uses float32 dtype. | +| [`_is_interleaved_tensor`](#ttl-ttl_api-_is_interleaved_tensor) | Check if a ttnn tensor has interleaved memory layout. | +| [`_make_cache_key`](#ttl-ttl_api-_make_cache_key) | Create cache key from tensor properties and runtime compute config parameters. | +| [`_register_thread`](#ttl-ttl_api-_register_thread) | Register a thread function during decoration. | +| [`_resolve_grid`](#ttl-ttl_api-_resolve_grid) | Resolve grid, evaluating callable or 'auto' if needed. | +| [`_run_profiling_pipeline`](#ttl-ttl_api-_run_profiling_pipeline) | Read device profiler data and display profile report. | +| [`_should_execute`](#ttl-ttl_api-_should_execute) | Check if kernel execution should proceed (not compile-only mode). | +| [`_track_tensor_sources`](#ttl-ttl_api-_track_tensor_sources) | Track source locations for tensor arguments. | +| [`_write_kernel_to_tmp`](#ttl-ttl_api-_write_kernel_to_tmp) | Write kernel source to /tmp and return the file path. | +| [`compute`](#ttl-ttl_api-compute) | Decorator for compute thread functions. | +| [`datamovement`](#ttl-ttl_api-datamovement) | Decorator for data movement thread functions. | +| [`pykernel_gen`](#ttl-ttl_api-pykernel_gen) | Decorator for generating TTL kernels from Python functions. | + +### Data + +[`__all__`](#ttl-ttl_api-__all__) + +[`_thread_registry`](#ttl-ttl_api-_thread_registry) + +[`kernel`](#ttl-ttl_api-kernel) + +### API + + + + + +```python +class ttl.ttl_api.CompiledTTNNKernel( + kernel_paths, + kernel_configs, + kernel_arg_specs, + num_tensors, + core_ranges, + kernel_tensor_indices, + cb_configs = None, + program_hash = None, + source_lines = None, + all_source_lines = None, + thread_to_kernel = None, + kernel_line_offsets = None +) +``` + + + + + + +A compiled tt-lang kernel ready for execution via ttnn.generic_op. + +Caches compilation artifacts (kernel paths, CB descriptors) so the kernel +can be executed multiple times with different tensors without recompiling. + + + + + + + + + + + + + + + + + +```python +ttl.ttl_api.CompiledTTNNKernel.__call__( + args = () +) +``` + + + + + + +Execute the kernel with the given tensors. + + + + + + + + + +```python +class ttl.ttl_api.Program( + threads = (), + args = (), + kwargs = None +) +``` + + + + + + +Immutable container for kernel threads and their arguments. + +A Program encapsulates compute and data movement threads along with +the arguments to be passed during execution. After construction, all +fields should be treated as read-only. + + + + + + + + + + + + + + + + + +```python +ttl.ttl_api.Program.__call__( + args = (), + kwargs = {} +) +``` + + + + + + + + + + + + + + +```python +ttl.ttl_api._clear_thread_registry() -> None +``` + + + + + + +Clear the thread registry before kernel execution. + + + + + + + + +```python +ttl.ttl_api._collect_captures( + f: typing.Callable +) -> typing.Dict[str, typing.Union[int, ttl.circular_buffer.CircularBuffer]] +``` + + + + + + +Collect and convert captured variables from function closure. + +**Parameters:** + + +Function with closure to inspect + + +**Returns:** `Dict[str, Union[int, CircularBuffer]]` + +Dictionary mapping variable names to converted values + +**Raises:** + +- `TypeError`: If closure contains unsupported variable types + + + + + + + + +```python +ttl.ttl_api._collect_cb_configs( + threads +) +``` + + + + + + +Extract CircularBuffer objects from thread closures, indexed by cb_index. + +Returns a list of CircularBuffer objects indexed by cb_index. Each CB has +shape, buffer_factor, tensor (for dtype), and _cb_index attributes. + + + + + + + + +```python +ttl.ttl_api._compile( + kernel_type: typing.Optional[str] = None, + verbose: bool = False +) -> typing.Callable +``` + + + + + + +Internal decorator for compiling kernel threads. + +**Parameters:** + + +Type of kernel ("compute" or "datamovement") + + + +Enable verbose compilation output + + +**Returns:** `Callable` + +Decorator function for kernel compilation + + + + + + + + +```python +ttl.ttl_api._compile_kernel( + f: typing.Callable, + args: tuple, + kwargs: dict, + grid: typing.Union[tuple, typing.List[int]], + indexing_maps: typing.List[typing.Callable], + iterator_types: typing.List[str], + num_outs: int, + memory_space: str, + tiled: bool, + program_hash: int, + fp32_dest_acc_en: typing.Optional[bool] = None, + dst_full_sync_en: typing.Optional[bool] = None +) -> typing.Optional[ttl.ttl_api.CompiledTTNNKernel] +``` + + + + + + +Compile kernel function to MLIR and return CompiledTTNNKernel. + +**Parameters:** + + +User kernel function + + + +Positional arguments for the kernel + + + +Keyword arguments for the kernel + + + +Grid dimensions + + + +List of lambda functions for indexing + + + +List of iterator type strings + + + +Number of output arguments + + + +"L1" or "DRAM" + + + +Whether to use tiled layout + + + +Hash for tt-metal program cache + + + +Optional override for fp32_dest_acc_en + + + +Optional override for dst_full_sync_en + + +**Returns:** `Optional[CompiledTTNNKernel]` + +CompiledTTNNKernel ready for execution + + + + + + + + +```python +ttl.ttl_api._compile_ttnn_kernel( + module, + args, + grid, + num_outs, + thread_tensor_indices, + cb_configs = None, + program_hash = None, + fp32_dest_acc_en: typing.Optional[bool] = None, + dst_full_sync_en: typing.Optional[bool] = None, + verbose = True, + source_lines = None, + all_source_lines = None, + kernel_line_offsets = None +) +``` + + + + + + +Compile kernel to CompiledTTNNKernel for execution via ttnn.generic_op. + +Builds kernel paths, configs, and CB descriptors from compiled MLIR module. + +**Parameters:** + + +MLIR module after D2M pipeline (with EmitC kernels) + + + +Input/output tensors (used for shape/dtype info) + + + +Grid dimensions tuple + + + +Number of output tensors + + + +Hash for tt-metal program cache + + + +Print compilation info + + + +Source code lines for auto-profiling reports + + +**Returns:** + +CompiledTTNNKernel ready for execution + + + + + + + + +```python +ttl.ttl_api._detect_memory_space_from_tensor( + tensor, + default: str +) -> str +``` + + + + + + +Detect memory space (L1/DRAM) from a ttnn tensor's buffer type. + + + + + + + + +```python +ttl.ttl_api._get_registered_threads() -> typing.List[typing.Callable] +``` + + + + + + +Get all registered threads and clear the registry. + + + + + + + + +```python +ttl.ttl_api._get_source_line_offset( + f +) -> int +``` + + + + + + +Get the line offset to convert parsed AST line numbers to actual file lines. + + + + + + + + +```python +ttl.ttl_api._get_tensor_cache_info( + tensor +) -> tuple +``` + + + + + + +Extract cache-relevant info from a tensor: (shape, dtype, memory_space, layout). + + + + + + + + +```python +ttl.ttl_api._has_float32_args( + args +) -> bool +``` + + + + + + +Check if any input tensor uses float32 dtype. + +Inspects the tensor arguments to detect float32. This is used to +automatically enable fp32_dest_acc_en configuration for compute kernels. + +**Parameters:** + + +List of tensor arguments (torch or ttnn) + + +**Returns:** `bool` + +True if any tensor uses float32 dtype, False otherwise + + + + + + + + +```python +ttl.ttl_api._is_interleaved_tensor( + tensor +) -> bool +``` + + + + + + +Check if a ttnn tensor has interleaved memory layout. + + + + + + + + +```python +ttl.ttl_api._make_cache_key( + args: tuple, + fp32_dest_acc_en: typing.Optional[bool], + dst_full_sync_en: typing.Optional[bool] +) -> tuple +``` + + + + + + +Create cache key from tensor properties and runtime compute config parameters. + + + + + + + + +```python +ttl.ttl_api._register_thread( + thread_fn: typing.Callable +) -> None +``` + + + + + + +Register a thread function during decoration. + + + + + + + + +```python +ttl.ttl_api._resolve_grid( + grid, + args, + kwargs +) +``` + + + + + + +Resolve grid, evaluating callable or 'auto' if needed. + + + + + + + + +```python +ttl.ttl_api._run_profiling_pipeline( + tensors: tuple, + all_source_lines: typing.Dict[str, typing.List[str]], + thread_to_kernel: typing.Dict[str, str], + kernel_line_offsets: typing.Optional[typing.Dict[str, int]] = None +) +``` + + + + + + +Read device profiler data and display profile report. + +Called after kernel execution when auto-profiling is enabled. + +**Parameters:** + + +Tuple of tensor arguments passed to the kernel + + + +Dict mapping kernel name to source lines + + + +Dict mapping RISC thread name to kernel name + + + + + + + + + +```python +ttl.ttl_api._should_execute() -> bool +``` + + + + + + +Check if kernel execution should proceed (not compile-only mode). + + + + + + + + +```python +ttl.ttl_api._track_tensor_sources( + f_params, + args, + source_file: str +) -> None +``` + + + + + + +Track source locations for tensor arguments. + +Searches backwards from the kernel call site to find where each +tensor variable was assigned, then registers that location. + + + + + + + + +```python +ttl.ttl_api._write_kernel_to_tmp( + name: str, + source: str +) -> str +``` + + + + + + +Write kernel source to /tmp and return the file path. + + + + + + + + +```python +ttl.ttl_api.compute( + verbose: bool = False +) -> typing.Callable +``` + + + + + + +Decorator for compute thread functions. + +Compute threads execute on Tensix cores and perform mathematical operations. + +**Parameters:** + + +Enable verbose compilation output + + +**Returns:** `Callable` + +Decorator for compute kernel compilation + + + + + + + + +```python +ttl.ttl_api.datamovement( + verbose: bool = False +) -> typing.Callable +``` + + + + + + +Decorator for data movement thread functions. + +Data movement threads handle DMA operations between memory hierarchies. + +**Parameters:** + + +Enable verbose compilation output + + +**Returns:** `Callable` + +Decorator for data movement kernel compilation + + + + + + + + +```python +ttl.ttl_api.pykernel_gen( + grid: typing.Optional[typing.Union[tuple, typing.Callable]] = None, + indexing_maps: typing.Optional[typing.List[typing.Callable]] = None, + iterator_types: typing.Optional[typing.List[str]] = None, + num_outs: int = 1, + memory_space: str = 'L1', + tiled: bool = True, + fp32_dest_acc_en: typing.Optional[bool] = None, + dst_full_sync_en: typing.Optional[bool] = None +) -> typing.Callable +``` + + + + + + +Decorator for generating TTL kernels from Python functions. + +This decorator compiles Python functions into TTL dialect operations, +handling thread compilation, stream creation, and pipeline execution. +Kernels are compiled to C++ for execution via ttnn.generic_op. + +**Parameters:** + + +Grid dimensions as tuple (e.g., (2, 2)) or callable + + + +List of lambda functions for indexing (optional) + + + +List of iterator types ("parallel", "reduction") + + + +Number of output arguments + + + +"L1" or "DRAM" + + + +Whether to use tiled layout + + + +Optional override for fp32_dest_acc_en + + + +Optional override for dst_full_sync_en + + +**Returns:** `Callable` + +Decorated function that compiles and executes the kernel + +**Raises:** + +- `AssertionError`: If required parameters are missing or invalid + + + + + + + + +```python +ttl.ttl_api.__all__ = ['pykernel_gen', 'kernel', 'Program', 'compute', 'datamovement', 'TensorBlock', ... +``` + + + + + + + + + +```python +ttl.ttl_api._thread_registry: List[Callable] = [] +``` + + + + + + + + + +```python +ttl.ttl_api.kernel = pykernel_gen +``` + + + + diff --git a/fern/static/ttl-docs/ttl/ttl/ttl_math.mdx b/fern/static/ttl-docs/ttl/ttl/ttl_math.mdx new file mode 100644 index 0000000..9472960 --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl/ttl_math.mdx @@ -0,0 +1,29 @@ +--- +layout: overview +slug: ttl/ttl/ttl_math +title: ttl.ttl_math +--- + +TTL math operations namespace (ttl.math). + +Re-exports elementwise operations from the generated module. + +## Module Contents + +### Data + +[`__all__`](#ttl-ttl_math-__all__) + +### API + + + + + +```python +ttl.ttl_math.__all__ = ['broadcast', *_generated_all] +``` + + + + diff --git a/fern/static/ttl-docs/ttl/ttl/ttl_utils.mdx b/fern/static/ttl-docs/ttl/ttl/ttl_utils.mdx new file mode 100644 index 0000000..813bd24 --- /dev/null +++ b/fern/static/ttl-docs/ttl/ttl/ttl_utils.mdx @@ -0,0 +1,70 @@ +--- +layout: overview +slug: ttl/ttl/ttl_utils +title: ttl.ttl_utils +--- + +Utility functions for tt-lang. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`get_thread_type_string`](#ttl-ttl_utils-get_thread_type_string) | Map kernel type to thread type string. | + +### Data + +[`_KERNEL_TYPE_TO_THREAD_TYPE`](#ttl-ttl_utils-_KERNEL_TYPE_TO_THREAD_TYPE) + +### API + + + + + +```python +ttl.ttl_utils.get_thread_type_string( + input: typing.Union[str, object] +) -> str +``` + + + + + + +Map kernel type to thread type string. + +Handles both string kernel types and MLIR ThreadTypeAttr. + +**Parameters:** + + +Either a string kernel type ("compute", "datamovement", "ethernet") + or a ttkernel.ThreadTypeAttr from MLIR IR + + +**Returns:** `str` + +Thread type string: "compute", "noc", "ethernet" + +**Raises:** + +- `ValueError`: If input is a string that's not a valid kernel type + + + + + + + + +```python +ttl.ttl_utils._KERNEL_TYPE_TO_THREAD_TYPE = {'compute': 'compute', 'datamovement': 'noc', 'ethernet': 'ethernet'} +``` + + + + From 4e26ae30c67e612093084d461b81f8f6f343215b Mon Sep 17 00:00:00 2001 From: Paarth Gupta Date: Fri, 13 Feb 2026 16:00:36 -0500 Subject: [PATCH 4/6] more testing --- fern/docs.yml | 17 +- fern/fern.config.json | 2 +- fern/static/nemo-rl-docs/_navigation.yml | 1017 --------- fern/static/nemo-rl-docs/nemo-rl/nemo_rl.mdx | 149 -- .../nemo-rl/nemo_rl/algorithms.mdx | 19 - .../algorithms/advantage_estimator.mdx | 196 -- .../nemo_rl/algorithms/async_utils.mdx | 572 ----- .../nemo_rl/algorithms/distillation.mdx | 326 --- .../nemo-rl/nemo_rl/algorithms/dpo.mdx | 378 ---- .../nemo-rl/nemo_rl/algorithms/grpo.mdx | 864 -------- .../nemo-rl/nemo_rl/algorithms/interfaces.mdx | 123 -- .../nemo_rl/algorithms/loss_functions.mdx | 875 -------- .../nemo_rl/algorithms/reward_functions.mdx | 102 - .../nemo-rl/nemo_rl/algorithms/rm.mdx | 320 --- .../nemo-rl/nemo_rl/algorithms/sft.mdx | 258 --- .../nemo-rl/nemo_rl/algorithms/utils.mdx | 379 ---- .../nemo-rl-docs/nemo-rl/nemo_rl/data.mdx | 466 ----- .../nemo-rl/nemo_rl/data/chat_templates.mdx | 35 - .../nemo-rl/nemo_rl/data/collate_fn.mdx | 166 -- .../nemo-rl/nemo_rl/data/datasets.mdx | 37 - .../nemo_rl/data/datasets/eval_datasets.mdx | 60 - .../data/datasets/eval_datasets/aime.mdx | 64 - .../data/datasets/eval_datasets/gpqa.mdx | 64 - .../eval_datasets/local_math_dataset.mdx | 65 - .../data/datasets/eval_datasets/math.mdx | 61 - .../data/datasets/eval_datasets/mmlu.mdx | 61 - .../data/datasets/eval_datasets/mmlu_pro.mdx | 60 - .../data/datasets/preference_datasets.mdx | 72 - .../binary_preference_dataset.mdx | 102 - .../preference_datasets/helpsteer3.mdx | 66 - .../preference_dataset.mdx | 77 - .../datasets/preference_datasets/tulu3.mdx | 59 - .../data/datasets/processed_dataset.mdx | 135 -- .../nemo_rl/data/datasets/raw_dataset.mdx | 94 - .../data/datasets/response_datasets.mdx | 82 - .../datasets/response_datasets/aime24.mdx | 66 - .../data/datasets/response_datasets/clevr.mdx | 97 - .../datasets/response_datasets/dapo_math.mdx | 84 - .../datasets/response_datasets/deepscaler.mdx | 59 - .../datasets/response_datasets/geometry3k.mdx | 76 - .../datasets/response_datasets/helpsteer3.mdx | 66 - .../response_datasets/nemogym_dataset.mdx | 54 - .../response_datasets/oai_format_dataset.mdx | 214 -- .../data/datasets/response_datasets/oasst.mdx | 127 -- .../response_datasets/openmathinstruct2.mdx | 84 - .../datasets/response_datasets/refcoco.mdx | 160 -- .../response_datasets/response_dataset.mdx | 104 - .../data/datasets/response_datasets/squad.mdx | 66 - .../data/datasets/response_datasets/tulu3.mdx | 76 - .../nemo-rl/nemo_rl/data/datasets/utils.mdx | 191 -- .../nemo-rl/nemo_rl/data/interfaces.mdx | 284 --- .../nemo_rl/data/llm_message_utils.mdx | 548 ----- .../nemo-rl/nemo_rl/data/multimodal_utils.mdx | 298 --- .../nemo-rl/nemo_rl/data/packing.mdx | 30 - .../nemo_rl/data/packing/algorithms.mdx | 791 ------- .../nemo-rl/nemo_rl/data/packing/metrics.mdx | 177 -- .../nemo-rl/nemo_rl/data/processors.mdx | 353 ---- .../nemo-rl/nemo_rl/data/utils.mdx | 104 - .../nemo-rl/nemo_rl/distributed.mdx | 17 - .../nemo_rl/distributed/batched_data_dict.mdx | 671 ------ .../nemo_rl/distributed/collectives.mdx | 108 - .../nemo_rl/distributed/model_utils.mdx | 851 -------- .../nemo_rl/distributed/named_sharding.mdx | 236 --- .../ray_actor_environment_registry.mdx | 105 - .../distributed/stateless_process_group.mdx | 73 - .../nemo_rl/distributed/virtual_cluster.mdx | 514 ----- .../distributed/worker_group_utils.mdx | 81 - .../nemo_rl/distributed/worker_groups.mdx | 603 ------ .../nemo-rl/nemo_rl/environments.mdx | 19 - .../nemo_rl/environments/code_environment.mdx | 290 --- .../environments/code_jaccard_environment.mdx | 268 --- .../environments/dapo_math_verifier.mdx | 316 --- .../nemo_rl/environments/interfaces.mdx | 151 -- .../nemo_rl/environments/math_environment.mdx | 356 ---- .../nemo-rl/nemo_rl/environments/metrics.mdx | 42 - .../nemo-rl/nemo_rl/environments/nemo_gym.mdx | 210 -- .../environments/reward_model_environment.mdx | 276 --- .../nemo-rl/nemo_rl/environments/rewards.mdx | 180 -- .../nemo-rl/nemo_rl/environments/utils.mdx | 152 -- .../nemo_rl/environments/vlm_environment.mdx | 243 --- .../nemo-rl-docs/nemo-rl/nemo_rl/evals.mdx | 10 - .../nemo-rl/nemo_rl/evals/answer_parsing.mdx | 86 - .../nemo-rl/nemo_rl/evals/eval.mdx | 399 ---- .../nemo-rl/nemo_rl/experience.mdx | 9 - .../nemo-rl/nemo_rl/experience/rollouts.mdx | 469 ----- .../nemo-rl-docs/nemo-rl/nemo_rl/models.mdx | 14 - .../nemo-rl/nemo_rl/models/automodel.mdx | 12 - .../nemo_rl/models/automodel/config.mdx | 125 -- .../nemo-rl/nemo_rl/models/automodel/data.mdx | 374 ---- .../nemo_rl/models/automodel/setup.mdx | 229 -- .../nemo_rl/models/automodel/train.mdx | 841 -------- .../nemo-rl/nemo_rl/models/dtensor.mdx | 9 - .../nemo_rl/models/dtensor/parallelize.mdx | 454 ---- .../nemo-rl/nemo_rl/models/generation.mdx | 62 - .../nemo_rl/models/generation/interfaces.mdx | 569 ----- .../nemo_rl/models/generation/sglang.mdx | 33 - .../models/generation/sglang/config.mdx | 299 --- .../generation/sglang/sglang_copied_utils.mdx | 307 --- .../generation/sglang/sglang_generation.mdx | 369 ---- .../generation/sglang/sglang_worker.mdx | 529 ----- .../models/generation/sglang/utils.mdx | 109 - .../nemo_rl/models/generation/vllm.mdx | 34 - .../nemo_rl/models/generation/vllm/config.mdx | 111 - .../nemo_rl/models/generation/vllm/utils.mdx | 113 - .../models/generation/vllm/vllm_backend.mdx | 236 --- .../generation/vllm/vllm_generation.mdx | 656 ------ .../models/generation/vllm/vllm_worker.mdx | 545 ----- .../generation/vllm/vllm_worker_async.mdx | 485 ----- .../nemo-rl/nemo_rl/models/huggingface.mdx | 9 - .../nemo_rl/models/huggingface/common.mdx | 303 --- .../nemo-rl/nemo_rl/models/megatron.mdx | 13 - .../nemo_rl/models/megatron/common.mdx | 212 -- .../models/megatron/community_import.mdx | 76 - .../nemo_rl/models/megatron/config.mdx | 146 -- .../nemo-rl/nemo_rl/models/megatron/data.mdx | 471 ----- .../nemo-rl/nemo_rl/models/megatron/setup.mdx | 535 ----- .../nemo-rl/nemo_rl/models/policy.mdx | 948 --------- .../nemo_rl/models/policy/interfaces.mdx | 574 ----- .../nemo_rl/models/policy/lm_policy.mdx | 609 ------ .../nemo-rl/nemo_rl/models/policy/utils.mdx | 624 ------ .../nemo-rl/nemo_rl/models/policy/workers.mdx | 13 - .../policy/workers/base_policy_worker.mdx | 309 --- .../policy/workers/dtensor_policy_worker.mdx | 693 ------ .../workers/dtensor_policy_worker_v2.mdx | 714 ------- .../policy/workers/megatron_policy_worker.mdx | 682 ------ .../nemo_rl/models/policy/workers/patches.mdx | 85 - .../nemo-rl/nemo_rl/package_info.mdx | 235 --- .../nemo-rl-docs/nemo-rl/nemo_rl/utils.mdx | 22 - .../nemo_rl/utils/automodel_checkpoint.mdx | 436 ---- .../nemo-rl/nemo_rl/utils/checkpoint.mdx | 411 ---- .../nemo-rl/nemo_rl/utils/config.mdx | 266 --- .../nemo-rl/nemo_rl/utils/flops_formulas.mdx | 501 ----- .../nemo-rl/nemo_rl/utils/flops_tracker.mdx | 215 -- .../nemo-rl/nemo_rl/utils/logger.mdx | 1856 ----------------- .../nemo-rl/nemo_rl/utils/memory_tracker.mdx | 122 -- .../nemo_rl/utils/native_checkpoint.mdx | 351 ---- .../nemo-rl/nemo_rl/utils/nsys.mdx | 138 -- .../nemo-rl/nemo_rl/utils/nvml.mdx | 100 - .../nemo-rl/nemo_rl/utils/packed_tensor.mdx | 140 -- .../nemo-rl/nemo_rl/utils/prefetch_venvs.mdx | 108 - .../nemo-rl/nemo_rl/utils/timer.mdx | 441 ---- .../nemo-rl/nemo_rl/utils/venvs.mdx | 177 -- fern/static/ttl-docs/_navigation.yml | 149 -- fern/static/ttl-docs/ttl/ttl.mdx | 60 - fern/static/ttl-docs/ttl/ttl/_mlir_libs.mdx | 9 - .../ttl/ttl/_mlir_libs/_site_initialize_1.mdx | 35 - fern/static/ttl-docs/ttl/ttl/_src.mdx | 11 - .../ttl-docs/ttl/ttl/_src/auto_profile.mdx | 479 ----- .../ttl-docs/ttl/ttl/_src/tensor_registry.mdx | 169 -- fern/static/ttl-docs/ttl/ttl/_src/ttl_ast.mdx | 731 ------- .../ttl-docs/ttl/ttl/circular_buffer.mdx | 283 --- fern/static/ttl-docs/ttl/ttl/constants.mdx | 41 - fern/static/ttl-docs/ttl/ttl/diagnostics.mdx | 466 ----- fern/static/ttl-docs/ttl/ttl/dialects.mdx | 12 - .../ttl-docs/ttl/ttl/dialects/_ods_common.mdx | 39 - fern/static/ttl-docs/ttl/ttl/dialects/ttl.mdx | 81 - fern/static/ttl-docs/ttl/ttl/dtype_utils.mdx | 212 -- .../static/ttl-docs/ttl/ttl/kernel_runner.mdx | 274 --- fern/static/ttl-docs/ttl/ttl/layouts.mdx | 126 -- fern/static/ttl-docs/ttl/ttl/operators.mdx | 642 ------ fern/static/ttl-docs/ttl/ttl/ttl.mdx | 27 - fern/static/ttl-docs/ttl/ttl/ttl_api.mdx | 907 -------- fern/static/ttl-docs/ttl/ttl/ttl_math.mdx | 29 - fern/static/ttl-docs/ttl/ttl/ttl_utils.mdx | 70 - 164 files changed, 16 insertions(+), 42153 deletions(-) delete mode 100644 fern/static/nemo-rl-docs/_navigation.yml delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/advantage_estimator.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/async_utils.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/distillation.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/dpo.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/grpo.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/interfaces.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/loss_functions.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/reward_functions.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/rm.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/sft.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/utils.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/chat_templates.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/collate_fn.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/aime.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/math.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/processed_dataset.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/raw_dataset.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/aime24.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/clevr.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oasst.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/squad.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/utils.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/interfaces.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/llm_message_utils.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/multimodal_utils.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/algorithms.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/metrics.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/processors.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/utils.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/batched_data_dict.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/collectives.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/model_utils.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/named_sharding.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/ray_actor_environment_registry.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/stateless_process_group.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/virtual_cluster.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_group_utils.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_groups.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_environment.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_jaccard_environment.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/dapo_math_verifier.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/interfaces.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/math_environment.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/metrics.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/nemo_gym.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/reward_model_environment.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/rewards.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/utils.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/vlm_environment.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals/answer_parsing.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals/eval.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/experience.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/experience/rollouts.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/config.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/data.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/setup.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/train.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor/parallelize.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/interfaces.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/config.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_generation.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_worker.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/utils.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/config.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/utils.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_backend.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_generation.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface/common.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/common.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/community_import.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/config.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/data.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/setup.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/interfaces.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/lm_policy.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/utils.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/patches.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/package_info.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/automodel_checkpoint.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/checkpoint.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/config.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_formulas.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_tracker.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/logger.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/memory_tracker.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/native_checkpoint.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/nsys.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/nvml.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/packed_tensor.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/prefetch_venvs.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/timer.mdx delete mode 100644 fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/venvs.mdx delete mode 100644 fern/static/ttl-docs/_navigation.yml delete mode 100644 fern/static/ttl-docs/ttl/ttl.mdx delete mode 100644 fern/static/ttl-docs/ttl/ttl/_mlir_libs.mdx delete mode 100644 fern/static/ttl-docs/ttl/ttl/_mlir_libs/_site_initialize_1.mdx delete mode 100644 fern/static/ttl-docs/ttl/ttl/_src.mdx delete mode 100644 fern/static/ttl-docs/ttl/ttl/_src/auto_profile.mdx delete mode 100644 fern/static/ttl-docs/ttl/ttl/_src/tensor_registry.mdx delete mode 100644 fern/static/ttl-docs/ttl/ttl/_src/ttl_ast.mdx delete mode 100644 fern/static/ttl-docs/ttl/ttl/circular_buffer.mdx delete mode 100644 fern/static/ttl-docs/ttl/ttl/constants.mdx delete mode 100644 fern/static/ttl-docs/ttl/ttl/diagnostics.mdx delete mode 100644 fern/static/ttl-docs/ttl/ttl/dialects.mdx delete mode 100644 fern/static/ttl-docs/ttl/ttl/dialects/_ods_common.mdx delete mode 100644 fern/static/ttl-docs/ttl/ttl/dialects/ttl.mdx delete mode 100644 fern/static/ttl-docs/ttl/ttl/dtype_utils.mdx delete mode 100644 fern/static/ttl-docs/ttl/ttl/kernel_runner.mdx delete mode 100644 fern/static/ttl-docs/ttl/ttl/layouts.mdx delete mode 100644 fern/static/ttl-docs/ttl/ttl/operators.mdx delete mode 100644 fern/static/ttl-docs/ttl/ttl/ttl.mdx delete mode 100644 fern/static/ttl-docs/ttl/ttl/ttl_api.mdx delete mode 100644 fern/static/ttl-docs/ttl/ttl/ttl_math.mdx delete mode 100644 fern/static/ttl-docs/ttl/ttl/ttl_utils.mdx diff --git a/fern/docs.yml b/fern/docs.yml index 1003620..28406aa 100644 --- a/fern/docs.yml +++ b/fern/docs.yml @@ -28,20 +28,30 @@ tabs: TTL Reference: display-name: TTL Reference icon: book + LangChain Core Reference: + display-name: LangChain Core + icon: book libraries: nemo-rl: input: git: https://github.com/NVIDIA-NeMo/RL output: - path: ./static/nemo-rl-docs + path: ./library-docs/nemo-rl-docs lang: python ttl: input: git: https://github.com/tenstorrent/tt-lang subpath: python/ttl output: - path: ./static/ttl-docs + path: ./library-docs/ttl-docs + lang: python + langchain-core: + input: + git: https://github.com/langchain-ai/langchain + subpath: libs/core/langchain_core + output: + path: ./library-docs/langchain-core-docs lang: python navigation: @@ -93,6 +103,9 @@ navigation: - tab: TTL Reference layout: - library: ttl + - tab: LangChain Core Reference + layout: + - library: langchain-core navbar-links: diff --git a/fern/fern.config.json b/fern/fern.config.json index 208d5a3..e188917 100644 --- a/fern/fern.config.json +++ b/fern/fern.config.json @@ -1,4 +1,4 @@ { "organization": "fern", - "version": "3.63.0" + "version": "3.78.0" } diff --git a/fern/static/nemo-rl-docs/_navigation.yml b/fern/static/nemo-rl-docs/_navigation.yml deleted file mode 100644 index 2d1a93b..0000000 --- a/fern/static/nemo-rl-docs/_navigation.yml +++ /dev/null @@ -1,1017 +0,0 @@ -# AUTO-GENERATED by `fern docs md generate` — DO NOT EDIT -- type: section - title: algorithms - slug: nemo-rl/nemo_rl/algorithms - children: - - type: section - title: advantage_estimator - slug: nemo-rl/nemo_rl/algorithms/advantage_estimator - children: - - type: page - title: advantage_estimator - slug: nemo-rl/nemo_rl/algorithms/advantage_estimator - pageId: nemo-rl/nemo_rl/algorithms/advantage_estimator.mdx - - type: section - title: async_utils - slug: nemo-rl/nemo_rl/algorithms/async_utils - children: - - type: page - title: async_utils - slug: nemo-rl/nemo_rl/algorithms/async_utils - pageId: nemo-rl/nemo_rl/algorithms/async_utils.mdx - - type: section - title: distillation - slug: nemo-rl/nemo_rl/algorithms/distillation - children: - - type: page - title: distillation - slug: nemo-rl/nemo_rl/algorithms/distillation - pageId: nemo-rl/nemo_rl/algorithms/distillation.mdx - - type: section - title: dpo - slug: nemo-rl/nemo_rl/algorithms/dpo - children: - - type: page - title: dpo - slug: nemo-rl/nemo_rl/algorithms/dpo - pageId: nemo-rl/nemo_rl/algorithms/dpo.mdx - - type: section - title: grpo - slug: nemo-rl/nemo_rl/algorithms/grpo - children: - - type: page - title: grpo - slug: nemo-rl/nemo_rl/algorithms/grpo - pageId: nemo-rl/nemo_rl/algorithms/grpo.mdx - - type: section - title: interfaces - slug: nemo-rl/nemo_rl/algorithms/interfaces - children: - - type: page - title: interfaces - slug: nemo-rl/nemo_rl/algorithms/interfaces - pageId: nemo-rl/nemo_rl/algorithms/interfaces.mdx - - type: section - title: loss_functions - slug: nemo-rl/nemo_rl/algorithms/loss_functions - children: - - type: page - title: loss_functions - slug: nemo-rl/nemo_rl/algorithms/loss_functions - pageId: nemo-rl/nemo_rl/algorithms/loss_functions.mdx - - type: section - title: reward_functions - slug: nemo-rl/nemo_rl/algorithms/reward_functions - children: - - type: page - title: reward_functions - slug: nemo-rl/nemo_rl/algorithms/reward_functions - pageId: nemo-rl/nemo_rl/algorithms/reward_functions.mdx - - type: section - title: rm - slug: nemo-rl/nemo_rl/algorithms/rm - children: - - type: page - title: rm - slug: nemo-rl/nemo_rl/algorithms/rm - pageId: nemo-rl/nemo_rl/algorithms/rm.mdx - - type: section - title: sft - slug: nemo-rl/nemo_rl/algorithms/sft - children: - - type: page - title: sft - slug: nemo-rl/nemo_rl/algorithms/sft - pageId: nemo-rl/nemo_rl/algorithms/sft.mdx - - type: section - title: utils - slug: nemo-rl/nemo_rl/algorithms/utils - children: - - type: page - title: utils - slug: nemo-rl/nemo_rl/algorithms/utils - pageId: nemo-rl/nemo_rl/algorithms/utils.mdx -- type: section - title: data - slug: nemo-rl/nemo_rl/data - children: - - type: section - title: chat_templates - slug: nemo-rl/nemo_rl/data/chat_templates - children: - - type: page - title: chat_templates - slug: nemo-rl/nemo_rl/data/chat_templates - pageId: nemo-rl/nemo_rl/data/chat_templates.mdx - - type: section - title: collate_fn - slug: nemo-rl/nemo_rl/data/collate_fn - children: - - type: page - title: collate_fn - slug: nemo-rl/nemo_rl/data/collate_fn - pageId: nemo-rl/nemo_rl/data/collate_fn.mdx - - type: section - title: datasets - slug: nemo-rl/nemo_rl/data/datasets - children: - - type: section - title: eval_datasets - slug: nemo-rl/nemo_rl/data/datasets/eval_datasets - children: - - type: section - title: aime - slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/aime - children: - - type: page - title: aime - slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/aime - pageId: nemo-rl/nemo_rl/data/datasets/eval_datasets/aime.mdx - - type: section - title: gpqa - slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa - children: - - type: page - title: gpqa - slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa - pageId: nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa.mdx - - type: section - title: local_math_dataset - slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset - children: - - type: page - title: local_math_dataset - slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset - pageId: nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset.mdx - - type: section - title: math - slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/math - children: - - type: page - title: math - slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/math - pageId: nemo-rl/nemo_rl/data/datasets/eval_datasets/math.mdx - - type: section - title: mmlu - slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu - children: - - type: page - title: mmlu - slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu - pageId: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu.mdx - - type: section - title: mmlu_pro - slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro - children: - - type: page - title: mmlu_pro - slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro - pageId: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro.mdx - - type: section - title: preference_datasets - slug: nemo-rl/nemo_rl/data/datasets/preference_datasets - children: - - type: section - title: binary_preference_dataset - slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset - children: - - type: page - title: binary_preference_dataset - slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset - pageId: nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.mdx - - type: section - title: helpsteer3 - slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3 - children: - - type: page - title: helpsteer3 - slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3 - pageId: nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3.mdx - - type: section - title: preference_dataset - slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset - children: - - type: page - title: preference_dataset - slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset - pageId: nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset.mdx - - type: section - title: tulu3 - slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3 - children: - - type: page - title: tulu3 - slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3 - pageId: nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3.mdx - - type: section - title: processed_dataset - slug: nemo-rl/nemo_rl/data/datasets/processed_dataset - children: - - type: page - title: processed_dataset - slug: nemo-rl/nemo_rl/data/datasets/processed_dataset - pageId: nemo-rl/nemo_rl/data/datasets/processed_dataset.mdx - - type: section - title: raw_dataset - slug: nemo-rl/nemo_rl/data/datasets/raw_dataset - children: - - type: page - title: raw_dataset - slug: nemo-rl/nemo_rl/data/datasets/raw_dataset - pageId: nemo-rl/nemo_rl/data/datasets/raw_dataset.mdx - - type: section - title: response_datasets - slug: nemo-rl/nemo_rl/data/datasets/response_datasets - children: - - type: section - title: aime24 - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/aime24 - children: - - type: page - title: aime24 - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/aime24 - pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/aime24.mdx - - type: section - title: clevr - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/clevr - children: - - type: page - title: clevr - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/clevr - pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/clevr.mdx - - type: section - title: dapo_math - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math - children: - - type: page - title: dapo_math - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math - pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math.mdx - - type: section - title: deepscaler - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler - children: - - type: page - title: deepscaler - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler - pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler.mdx - - type: section - title: geometry3k - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k - children: - - type: page - title: geometry3k - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k - pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k.mdx - - type: section - title: helpsteer3 - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3 - children: - - type: page - title: helpsteer3 - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3 - pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3.mdx - - type: section - title: nemogym_dataset - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset - children: - - type: page - title: nemogym_dataset - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset - pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset.mdx - - type: section - title: oai_format_dataset - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset - children: - - type: page - title: oai_format_dataset - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset - pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset.mdx - - type: section - title: oasst - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/oasst - children: - - type: page - title: oasst - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/oasst - pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/oasst.mdx - - type: section - title: openmathinstruct2 - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2 - children: - - type: page - title: openmathinstruct2 - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2 - pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2.mdx - - type: section - title: refcoco - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco - children: - - type: page - title: refcoco - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco - pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco.mdx - - type: section - title: response_dataset - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset - children: - - type: page - title: response_dataset - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset - pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset.mdx - - type: section - title: squad - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/squad - children: - - type: page - title: squad - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/squad - pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/squad.mdx - - type: section - title: tulu3 - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3 - children: - - type: page - title: tulu3 - slug: nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3 - pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3.mdx - - type: section - title: utils - slug: nemo-rl/nemo_rl/data/datasets/utils - children: - - type: page - title: utils - slug: nemo-rl/nemo_rl/data/datasets/utils - pageId: nemo-rl/nemo_rl/data/datasets/utils.mdx - - type: section - title: interfaces - slug: nemo-rl/nemo_rl/data/interfaces - children: - - type: page - title: interfaces - slug: nemo-rl/nemo_rl/data/interfaces - pageId: nemo-rl/nemo_rl/data/interfaces.mdx - - type: section - title: llm_message_utils - slug: nemo-rl/nemo_rl/data/llm_message_utils - children: - - type: page - title: llm_message_utils - slug: nemo-rl/nemo_rl/data/llm_message_utils - pageId: nemo-rl/nemo_rl/data/llm_message_utils.mdx - - type: section - title: multimodal_utils - slug: nemo-rl/nemo_rl/data/multimodal_utils - children: - - type: page - title: multimodal_utils - slug: nemo-rl/nemo_rl/data/multimodal_utils - pageId: nemo-rl/nemo_rl/data/multimodal_utils.mdx - - type: section - title: packing - slug: nemo-rl/nemo_rl/data/packing - children: - - type: section - title: algorithms - slug: nemo-rl/nemo_rl/data/packing/algorithms - children: - - type: page - title: algorithms - slug: nemo-rl/nemo_rl/data/packing/algorithms - pageId: nemo-rl/nemo_rl/data/packing/algorithms.mdx - - type: section - title: metrics - slug: nemo-rl/nemo_rl/data/packing/metrics - children: - - type: page - title: metrics - slug: nemo-rl/nemo_rl/data/packing/metrics - pageId: nemo-rl/nemo_rl/data/packing/metrics.mdx - - type: section - title: processors - slug: nemo-rl/nemo_rl/data/processors - children: - - type: page - title: processors - slug: nemo-rl/nemo_rl/data/processors - pageId: nemo-rl/nemo_rl/data/processors.mdx - - type: section - title: utils - slug: nemo-rl/nemo_rl/data/utils - children: - - type: page - title: utils - slug: nemo-rl/nemo_rl/data/utils - pageId: nemo-rl/nemo_rl/data/utils.mdx -- type: section - title: distributed - slug: nemo-rl/nemo_rl/distributed - children: - - type: section - title: batched_data_dict - slug: nemo-rl/nemo_rl/distributed/batched_data_dict - children: - - type: page - title: batched_data_dict - slug: nemo-rl/nemo_rl/distributed/batched_data_dict - pageId: nemo-rl/nemo_rl/distributed/batched_data_dict.mdx - - type: section - title: collectives - slug: nemo-rl/nemo_rl/distributed/collectives - children: - - type: page - title: collectives - slug: nemo-rl/nemo_rl/distributed/collectives - pageId: nemo-rl/nemo_rl/distributed/collectives.mdx - - type: section - title: model_utils - slug: nemo-rl/nemo_rl/distributed/model_utils - children: - - type: page - title: model_utils - slug: nemo-rl/nemo_rl/distributed/model_utils - pageId: nemo-rl/nemo_rl/distributed/model_utils.mdx - - type: section - title: named_sharding - slug: nemo-rl/nemo_rl/distributed/named_sharding - children: - - type: page - title: named_sharding - slug: nemo-rl/nemo_rl/distributed/named_sharding - pageId: nemo-rl/nemo_rl/distributed/named_sharding.mdx - - type: section - title: ray_actor_environment_registry - slug: nemo-rl/nemo_rl/distributed/ray_actor_environment_registry - children: - - type: page - title: ray_actor_environment_registry - slug: nemo-rl/nemo_rl/distributed/ray_actor_environment_registry - pageId: nemo-rl/nemo_rl/distributed/ray_actor_environment_registry.mdx - - type: section - title: stateless_process_group - slug: nemo-rl/nemo_rl/distributed/stateless_process_group - children: - - type: page - title: stateless_process_group - slug: nemo-rl/nemo_rl/distributed/stateless_process_group - pageId: nemo-rl/nemo_rl/distributed/stateless_process_group.mdx - - type: section - title: virtual_cluster - slug: nemo-rl/nemo_rl/distributed/virtual_cluster - children: - - type: page - title: virtual_cluster - slug: nemo-rl/nemo_rl/distributed/virtual_cluster - pageId: nemo-rl/nemo_rl/distributed/virtual_cluster.mdx - - type: section - title: worker_group_utils - slug: nemo-rl/nemo_rl/distributed/worker_group_utils - children: - - type: page - title: worker_group_utils - slug: nemo-rl/nemo_rl/distributed/worker_group_utils - pageId: nemo-rl/nemo_rl/distributed/worker_group_utils.mdx - - type: section - title: worker_groups - slug: nemo-rl/nemo_rl/distributed/worker_groups - children: - - type: page - title: worker_groups - slug: nemo-rl/nemo_rl/distributed/worker_groups - pageId: nemo-rl/nemo_rl/distributed/worker_groups.mdx -- type: section - title: environments - slug: nemo-rl/nemo_rl/environments - children: - - type: section - title: code_environment - slug: nemo-rl/nemo_rl/environments/code_environment - children: - - type: page - title: code_environment - slug: nemo-rl/nemo_rl/environments/code_environment - pageId: nemo-rl/nemo_rl/environments/code_environment.mdx - - type: section - title: code_jaccard_environment - slug: nemo-rl/nemo_rl/environments/code_jaccard_environment - children: - - type: page - title: code_jaccard_environment - slug: nemo-rl/nemo_rl/environments/code_jaccard_environment - pageId: nemo-rl/nemo_rl/environments/code_jaccard_environment.mdx - - type: section - title: dapo_math_verifier - slug: nemo-rl/nemo_rl/environments/dapo_math_verifier - children: - - type: page - title: dapo_math_verifier - slug: nemo-rl/nemo_rl/environments/dapo_math_verifier - pageId: nemo-rl/nemo_rl/environments/dapo_math_verifier.mdx - - type: section - title: interfaces - slug: nemo-rl/nemo_rl/environments/interfaces - children: - - type: page - title: interfaces - slug: nemo-rl/nemo_rl/environments/interfaces - pageId: nemo-rl/nemo_rl/environments/interfaces.mdx - - type: section - title: math_environment - slug: nemo-rl/nemo_rl/environments/math_environment - children: - - type: page - title: math_environment - slug: nemo-rl/nemo_rl/environments/math_environment - pageId: nemo-rl/nemo_rl/environments/math_environment.mdx - - type: section - title: metrics - slug: nemo-rl/nemo_rl/environments/metrics - children: - - type: page - title: metrics - slug: nemo-rl/nemo_rl/environments/metrics - pageId: nemo-rl/nemo_rl/environments/metrics.mdx - - type: section - title: nemo_gym - slug: nemo-rl/nemo_rl/environments/nemo_gym - children: - - type: page - title: nemo_gym - slug: nemo-rl/nemo_rl/environments/nemo_gym - pageId: nemo-rl/nemo_rl/environments/nemo_gym.mdx - - type: section - title: reward_model_environment - slug: nemo-rl/nemo_rl/environments/reward_model_environment - children: - - type: page - title: reward_model_environment - slug: nemo-rl/nemo_rl/environments/reward_model_environment - pageId: nemo-rl/nemo_rl/environments/reward_model_environment.mdx - - type: section - title: rewards - slug: nemo-rl/nemo_rl/environments/rewards - children: - - type: page - title: rewards - slug: nemo-rl/nemo_rl/environments/rewards - pageId: nemo-rl/nemo_rl/environments/rewards.mdx - - type: section - title: utils - slug: nemo-rl/nemo_rl/environments/utils - children: - - type: page - title: utils - slug: nemo-rl/nemo_rl/environments/utils - pageId: nemo-rl/nemo_rl/environments/utils.mdx - - type: section - title: vlm_environment - slug: nemo-rl/nemo_rl/environments/vlm_environment - children: - - type: page - title: vlm_environment - slug: nemo-rl/nemo_rl/environments/vlm_environment - pageId: nemo-rl/nemo_rl/environments/vlm_environment.mdx -- type: section - title: evals - slug: nemo-rl/nemo_rl/evals - children: - - type: section - title: answer_parsing - slug: nemo-rl/nemo_rl/evals/answer_parsing - children: - - type: page - title: answer_parsing - slug: nemo-rl/nemo_rl/evals/answer_parsing - pageId: nemo-rl/nemo_rl/evals/answer_parsing.mdx - - type: section - title: eval - slug: nemo-rl/nemo_rl/evals/eval - children: - - type: page - title: eval - slug: nemo-rl/nemo_rl/evals/eval - pageId: nemo-rl/nemo_rl/evals/eval.mdx -- type: section - title: experience - slug: nemo-rl/nemo_rl/experience - children: - - type: section - title: rollouts - slug: nemo-rl/nemo_rl/experience/rollouts - children: - - type: page - title: rollouts - slug: nemo-rl/nemo_rl/experience/rollouts - pageId: nemo-rl/nemo_rl/experience/rollouts.mdx -- type: section - title: models - slug: nemo-rl/nemo_rl/models - children: - - type: section - title: automodel - slug: nemo-rl/nemo_rl/models/automodel - children: - - type: section - title: config - slug: nemo-rl/nemo_rl/models/automodel/config - children: - - type: page - title: config - slug: nemo-rl/nemo_rl/models/automodel/config - pageId: nemo-rl/nemo_rl/models/automodel/config.mdx - - type: section - title: data - slug: nemo-rl/nemo_rl/models/automodel/data - children: - - type: page - title: data - slug: nemo-rl/nemo_rl/models/automodel/data - pageId: nemo-rl/nemo_rl/models/automodel/data.mdx - - type: section - title: setup - slug: nemo-rl/nemo_rl/models/automodel/setup - children: - - type: page - title: setup - slug: nemo-rl/nemo_rl/models/automodel/setup - pageId: nemo-rl/nemo_rl/models/automodel/setup.mdx - - type: section - title: train - slug: nemo-rl/nemo_rl/models/automodel/train - children: - - type: page - title: train - slug: nemo-rl/nemo_rl/models/automodel/train - pageId: nemo-rl/nemo_rl/models/automodel/train.mdx - - type: section - title: dtensor - slug: nemo-rl/nemo_rl/models/dtensor - children: - - type: section - title: parallelize - slug: nemo-rl/nemo_rl/models/dtensor/parallelize - children: - - type: page - title: parallelize - slug: nemo-rl/nemo_rl/models/dtensor/parallelize - pageId: nemo-rl/nemo_rl/models/dtensor/parallelize.mdx - - type: section - title: generation - slug: nemo-rl/nemo_rl/models/generation - children: - - type: section - title: interfaces - slug: nemo-rl/nemo_rl/models/generation/interfaces - children: - - type: page - title: interfaces - slug: nemo-rl/nemo_rl/models/generation/interfaces - pageId: nemo-rl/nemo_rl/models/generation/interfaces.mdx - - type: section - title: sglang - slug: nemo-rl/nemo_rl/models/generation/sglang - children: - - type: section - title: config - slug: nemo-rl/nemo_rl/models/generation/sglang/config - children: - - type: page - title: config - slug: nemo-rl/nemo_rl/models/generation/sglang/config - pageId: nemo-rl/nemo_rl/models/generation/sglang/config.mdx - - type: section - title: sglang_copied_utils - slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils - children: - - type: page - title: sglang_copied_utils - slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils - pageId: nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils.mdx - - type: section - title: sglang_generation - slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_generation - children: - - type: page - title: sglang_generation - slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_generation - pageId: nemo-rl/nemo_rl/models/generation/sglang/sglang_generation.mdx - - type: section - title: sglang_worker - slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_worker - children: - - type: page - title: sglang_worker - slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_worker - pageId: nemo-rl/nemo_rl/models/generation/sglang/sglang_worker.mdx - - type: section - title: utils - slug: nemo-rl/nemo_rl/models/generation/sglang/utils - children: - - type: page - title: utils - slug: nemo-rl/nemo_rl/models/generation/sglang/utils - pageId: nemo-rl/nemo_rl/models/generation/sglang/utils.mdx - - type: section - title: vllm - slug: nemo-rl/nemo_rl/models/generation/vllm - children: - - type: section - title: config - slug: nemo-rl/nemo_rl/models/generation/vllm/config - children: - - type: page - title: config - slug: nemo-rl/nemo_rl/models/generation/vllm/config - pageId: nemo-rl/nemo_rl/models/generation/vllm/config.mdx - - type: section - title: utils - slug: nemo-rl/nemo_rl/models/generation/vllm/utils - children: - - type: page - title: utils - slug: nemo-rl/nemo_rl/models/generation/vllm/utils - pageId: nemo-rl/nemo_rl/models/generation/vllm/utils.mdx - - type: section - title: vllm_backend - slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_backend - children: - - type: page - title: vllm_backend - slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_backend - pageId: nemo-rl/nemo_rl/models/generation/vllm/vllm_backend.mdx - - type: section - title: vllm_generation - slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_generation - children: - - type: page - title: vllm_generation - slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_generation - pageId: nemo-rl/nemo_rl/models/generation/vllm/vllm_generation.mdx - - type: section - title: vllm_worker - slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker - children: - - type: page - title: vllm_worker - slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker - pageId: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker.mdx - - type: section - title: vllm_worker_async - slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async - children: - - type: page - title: vllm_worker_async - slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async - pageId: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async.mdx - - type: section - title: huggingface - slug: nemo-rl/nemo_rl/models/huggingface - children: - - type: section - title: common - slug: nemo-rl/nemo_rl/models/huggingface/common - children: - - type: page - title: common - slug: nemo-rl/nemo_rl/models/huggingface/common - pageId: nemo-rl/nemo_rl/models/huggingface/common.mdx - - type: section - title: megatron - slug: nemo-rl/nemo_rl/models/megatron - children: - - type: section - title: common - slug: nemo-rl/nemo_rl/models/megatron/common - children: - - type: page - title: common - slug: nemo-rl/nemo_rl/models/megatron/common - pageId: nemo-rl/nemo_rl/models/megatron/common.mdx - - type: section - title: community_import - slug: nemo-rl/nemo_rl/models/megatron/community_import - children: - - type: page - title: community_import - slug: nemo-rl/nemo_rl/models/megatron/community_import - pageId: nemo-rl/nemo_rl/models/megatron/community_import.mdx - - type: section - title: config - slug: nemo-rl/nemo_rl/models/megatron/config - children: - - type: page - title: config - slug: nemo-rl/nemo_rl/models/megatron/config - pageId: nemo-rl/nemo_rl/models/megatron/config.mdx - - type: section - title: data - slug: nemo-rl/nemo_rl/models/megatron/data - children: - - type: page - title: data - slug: nemo-rl/nemo_rl/models/megatron/data - pageId: nemo-rl/nemo_rl/models/megatron/data.mdx - - type: section - title: setup - slug: nemo-rl/nemo_rl/models/megatron/setup - children: - - type: page - title: setup - slug: nemo-rl/nemo_rl/models/megatron/setup - pageId: nemo-rl/nemo_rl/models/megatron/setup.mdx - - type: section - title: policy - slug: nemo-rl/nemo_rl/models/policy - children: - - type: section - title: interfaces - slug: nemo-rl/nemo_rl/models/policy/interfaces - children: - - type: page - title: interfaces - slug: nemo-rl/nemo_rl/models/policy/interfaces - pageId: nemo-rl/nemo_rl/models/policy/interfaces.mdx - - type: section - title: lm_policy - slug: nemo-rl/nemo_rl/models/policy/lm_policy - children: - - type: page - title: lm_policy - slug: nemo-rl/nemo_rl/models/policy/lm_policy - pageId: nemo-rl/nemo_rl/models/policy/lm_policy.mdx - - type: section - title: utils - slug: nemo-rl/nemo_rl/models/policy/utils - children: - - type: page - title: utils - slug: nemo-rl/nemo_rl/models/policy/utils - pageId: nemo-rl/nemo_rl/models/policy/utils.mdx - - type: section - title: workers - slug: nemo-rl/nemo_rl/models/policy/workers - children: - - type: section - title: base_policy_worker - slug: nemo-rl/nemo_rl/models/policy/workers/base_policy_worker - children: - - type: page - title: base_policy_worker - slug: nemo-rl/nemo_rl/models/policy/workers/base_policy_worker - pageId: nemo-rl/nemo_rl/models/policy/workers/base_policy_worker.mdx - - type: section - title: dtensor_policy_worker - slug: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker - children: - - type: page - title: dtensor_policy_worker - slug: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker - pageId: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker.mdx - - type: section - title: dtensor_policy_worker_v2 - slug: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2 - children: - - type: page - title: dtensor_policy_worker_v2 - slug: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2 - pageId: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.mdx - - type: section - title: megatron_policy_worker - slug: nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker - children: - - type: page - title: megatron_policy_worker - slug: nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker - pageId: nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker.mdx - - type: section - title: patches - slug: nemo-rl/nemo_rl/models/policy/workers/patches - children: - - type: page - title: patches - slug: nemo-rl/nemo_rl/models/policy/workers/patches - pageId: nemo-rl/nemo_rl/models/policy/workers/patches.mdx -- type: section - title: package_info - slug: nemo-rl/nemo_rl/package_info - children: - - type: page - title: package_info - slug: nemo-rl/nemo_rl/package_info - pageId: nemo-rl/nemo_rl/package_info.mdx -- type: section - title: utils - slug: nemo-rl/nemo_rl/utils - children: - - type: section - title: automodel_checkpoint - slug: nemo-rl/nemo_rl/utils/automodel_checkpoint - children: - - type: page - title: automodel_checkpoint - slug: nemo-rl/nemo_rl/utils/automodel_checkpoint - pageId: nemo-rl/nemo_rl/utils/automodel_checkpoint.mdx - - type: section - title: checkpoint - slug: nemo-rl/nemo_rl/utils/checkpoint - children: - - type: page - title: checkpoint - slug: nemo-rl/nemo_rl/utils/checkpoint - pageId: nemo-rl/nemo_rl/utils/checkpoint.mdx - - type: section - title: config - slug: nemo-rl/nemo_rl/utils/config - children: - - type: page - title: config - slug: nemo-rl/nemo_rl/utils/config - pageId: nemo-rl/nemo_rl/utils/config.mdx - - type: section - title: flops_formulas - slug: nemo-rl/nemo_rl/utils/flops_formulas - children: - - type: page - title: flops_formulas - slug: nemo-rl/nemo_rl/utils/flops_formulas - pageId: nemo-rl/nemo_rl/utils/flops_formulas.mdx - - type: section - title: flops_tracker - slug: nemo-rl/nemo_rl/utils/flops_tracker - children: - - type: page - title: flops_tracker - slug: nemo-rl/nemo_rl/utils/flops_tracker - pageId: nemo-rl/nemo_rl/utils/flops_tracker.mdx - - type: section - title: logger - slug: nemo-rl/nemo_rl/utils/logger - children: - - type: page - title: logger - slug: nemo-rl/nemo_rl/utils/logger - pageId: nemo-rl/nemo_rl/utils/logger.mdx - - type: section - title: memory_tracker - slug: nemo-rl/nemo_rl/utils/memory_tracker - children: - - type: page - title: memory_tracker - slug: nemo-rl/nemo_rl/utils/memory_tracker - pageId: nemo-rl/nemo_rl/utils/memory_tracker.mdx - - type: section - title: native_checkpoint - slug: nemo-rl/nemo_rl/utils/native_checkpoint - children: - - type: page - title: native_checkpoint - slug: nemo-rl/nemo_rl/utils/native_checkpoint - pageId: nemo-rl/nemo_rl/utils/native_checkpoint.mdx - - type: section - title: nsys - slug: nemo-rl/nemo_rl/utils/nsys - children: - - type: page - title: nsys - slug: nemo-rl/nemo_rl/utils/nsys - pageId: nemo-rl/nemo_rl/utils/nsys.mdx - - type: section - title: nvml - slug: nemo-rl/nemo_rl/utils/nvml - children: - - type: page - title: nvml - slug: nemo-rl/nemo_rl/utils/nvml - pageId: nemo-rl/nemo_rl/utils/nvml.mdx - - type: section - title: packed_tensor - slug: nemo-rl/nemo_rl/utils/packed_tensor - children: - - type: page - title: packed_tensor - slug: nemo-rl/nemo_rl/utils/packed_tensor - pageId: nemo-rl/nemo_rl/utils/packed_tensor.mdx - - type: section - title: prefetch_venvs - slug: nemo-rl/nemo_rl/utils/prefetch_venvs - children: - - type: page - title: prefetch_venvs - slug: nemo-rl/nemo_rl/utils/prefetch_venvs - pageId: nemo-rl/nemo_rl/utils/prefetch_venvs.mdx - - type: section - title: timer - slug: nemo-rl/nemo_rl/utils/timer - children: - - type: page - title: timer - slug: nemo-rl/nemo_rl/utils/timer - pageId: nemo-rl/nemo_rl/utils/timer.mdx - - type: section - title: venvs - slug: nemo-rl/nemo_rl/utils/venvs - children: - - type: page - title: venvs - slug: nemo-rl/nemo_rl/utils/venvs - pageId: nemo-rl/nemo_rl/utils/venvs.mdx diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl.mdx deleted file mode 100644 index 002c19d..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl.mdx +++ /dev/null @@ -1,149 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl -title: nemo_rl ---- - -## Subpackages - -- **[`nemo_rl.algorithms`](/nemo-rl/nemo_rl/algorithms)** -- **[`nemo_rl.data`](/nemo-rl/nemo_rl/data)** -- **[`nemo_rl.distributed`](/nemo-rl/nemo_rl/distributed)** -- **[`nemo_rl.environments`](/nemo-rl/nemo_rl/environments)** -- **[`nemo_rl.evals`](/nemo-rl/nemo_rl/evals)** -- **[`nemo_rl.experience`](/nemo-rl/nemo_rl/experience)** -- **[`nemo_rl.models`](/nemo-rl/nemo_rl/models)** -- **[`nemo_rl.utils`](/nemo-rl/nemo_rl/utils)** - -## Submodules - -- **[`nemo_rl.package_info`](/nemo-rl/nemo_rl/package_info)** - -## Package Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`_check_container_fingerprint`](#nemo_rl-_check_container_fingerprint) | Check if container dependencies match the current code (container-only). | -| [`_is_build_isolation`](#nemo_rl-_is_build_isolation) | Detect if we're running in a uv build isolation environment. | -| [`_patch_nsight_file`](#nemo_rl-_patch_nsight_file) | Patch the nsight.py file to fix the context.py_executable assignment. | -| [`patch_transformers_module_dir`](#nemo_rl-patch_transformers_module_dir) | - | - -### Data - -[`megatron_path`](#nemo_rl-megatron_path) - -### API - - - - - -```python -nemo_rl._check_container_fingerprint() -``` - - - - - - -Check if container dependencies match the current code (container-only). - -This check only runs when NRL_CONTAINER=1 is set (inside containers). -It compares the container's fingerprint (computed at build time) with -the current code's fingerprint to detect dependency drift. - -This check is also skipped entirely if NRL_FORCE_REBUILD_VENVS=true is set, -since environment rebuilding will ensure dependencies are consistent regardless -of a mismatch. - -If there's a mismatch, raises RuntimeError unless NRL_IGNORE_VERSION_MISMATCH is set. - - - - - - - - -```python -nemo_rl._is_build_isolation() -``` - - - - - - -Detect if we're running in a uv build isolation environment. - -When running uv lock/sync, uv creates a temporary isolated environment -in ~/.cache/uv/builds-v*/ to build packages and introspect metadata. -We skip the fingerprint check in this context since the user is updating dependencies. - -Returns True if in build isolation, False otherwise. - - - - - - - - -```python -nemo_rl._patch_nsight_file() -``` - - - - - - -Patch the nsight.py file to fix the context.py_executable assignment. - -Until this fix is upstreamed, we will maintain this patch here. This patching -logic is only applied if the user intends to use nsys profiling which they enable with -NRL_NSYS_WORKER_PATTERNS. - -If enabled, will effectively apply the following patch in an idempotent manner: - -https://github.com/ray-project/ray/compare/master...terrykong:ray:tk/nsight-py-exeutable-fix?expand=1 - -This hack works b/c the nsight plugin is not called from the main driver process, so -as soon as nemo_rl is imported, the patch is applied and the source of the nsight.py module -is up to date before the nsight.py is actually needed. - - - - - - - - -```python -nemo_rl.patch_transformers_module_dir( - env_vars: dict[str, str] -) -``` - - - - - - - - - - - - - -```python -nemo_rl.megatron_path = Path(__file__).parent.parent / '3rdparty' / 'Megatron-LM-workspace' / 'Megatron-... -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms.mdx deleted file mode 100644 index 7f03746..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms.mdx +++ /dev/null @@ -1,19 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/algorithms -title: nemo_rl.algorithms ---- - -## Submodules - -- **[`nemo_rl.algorithms.advantage_estimator`](/nemo-rl/nemo_rl/algorithms/advantage_estimator)** -- **[`nemo_rl.algorithms.async_utils`](/nemo-rl/nemo_rl/algorithms/async_utils)** -- **[`nemo_rl.algorithms.distillation`](/nemo-rl/nemo_rl/algorithms/distillation)** -- **[`nemo_rl.algorithms.dpo`](/nemo-rl/nemo_rl/algorithms/dpo)** -- **[`nemo_rl.algorithms.grpo`](/nemo-rl/nemo_rl/algorithms/grpo)** -- **[`nemo_rl.algorithms.interfaces`](/nemo-rl/nemo_rl/algorithms/interfaces)** -- **[`nemo_rl.algorithms.loss_functions`](/nemo-rl/nemo_rl/algorithms/loss_functions)** -- **[`nemo_rl.algorithms.reward_functions`](/nemo-rl/nemo_rl/algorithms/reward_functions)** -- **[`nemo_rl.algorithms.rm`](/nemo-rl/nemo_rl/algorithms/rm)** -- **[`nemo_rl.algorithms.sft`](/nemo-rl/nemo_rl/algorithms/sft)** -- **[`nemo_rl.algorithms.utils`](/nemo-rl/nemo_rl/algorithms/utils)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/advantage_estimator.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/advantage_estimator.mdx deleted file mode 100644 index 0841909..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/advantage_estimator.mdx +++ /dev/null @@ -1,196 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/algorithms/advantage_estimator -title: nemo_rl.algorithms.advantage_estimator ---- - -Advantage Estimators for RL algorithms. - -This module provides different advantage estimation strategies: -- GRPOAdvantageEstimator: Standard GRPO advantage with leave-one-out baseline -- ReinforcePlusPlusAdvantageEstimator: Reinforce++ with optional baseline subtraction (minus_baseline) and KL penalty in reward -Reference papers: -- ProRLv2: https://developer.nvidia.com/blog/scaling-llm-reinforcement-learning-with-prolonged-training-using-prorl-v2/ -- Reinforce++: https://arxiv.org/abs/2501.03262 - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`GRPOAdvantageEstimator`](#nemo_rl-algorithms-advantage_estimator-GRPOAdvantageEstimator) | GRPO-style advantage estimator with leave-one-out baseline. | -| [`ReinforcePlusPlusAdvantageEstimator`](#nemo_rl-algorithms-advantage_estimator-ReinforcePlusPlusAdvantageEstimator) | Reinforce++ advantage estimator with optional baseline subtraction and KL penalty in reward. | - -### API - - - - - -```python -class nemo_rl.algorithms.advantage_estimator.GRPOAdvantageEstimator( - estimator_config: dict, - loss_config: dict -) -``` - - - - - - -GRPO-style advantage estimator with leave-one-out baseline. - -Note: GRPO computes advantages over all responses for each prompt. - - - - - - - - - - - -```python -nemo_rl.algorithms.advantage_estimator.GRPOAdvantageEstimator.compute_advantage( - prompt_ids, - rewards, - mask, - kwargs = {} -) -``` - - - - - - -Compute GRPO advantages. - -**Parameters:** - - -Tensor of shape [batch_size] identifying which prompt each sample belongs to. - - - -Tensor of shape [batch_size] containing reward for each sample. - - - -Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding. - Used only for expanding advantages to token-level shape. - - - -Additional arguments (unused). - - -**Returns:** - -Advantages tensor of shape [batch_size, seq_len]. - - - - - - - - - -```python -class nemo_rl.algorithms.advantage_estimator.ReinforcePlusPlusAdvantageEstimator( - estimator_config: dict, - loss_config: dict -) -``` - - - - - - -Reinforce++ advantage estimator with optional baseline subtraction and KL penalty in reward. - -**Parameters:** - - -If True, subtract per-prompt mean baseline from rewards. - - - -If True, add KL penalty to reward instead of loss. - - - - - - - - - - - - - - - - - - -```python -nemo_rl.algorithms.advantage_estimator.ReinforcePlusPlusAdvantageEstimator.compute_advantage( - prompt_ids, - rewards, - mask, - logprobs_policy = None, - logprobs_reference = None, - kwargs = {} -) -``` - - - - - - -Compute Reinforce++ advantages with optional KL penalty. - -**Parameters:** - - -Tensor of shape [batch_size] identifying which prompt each sample belongs to. - - - -Tensor of shape [batch_size] containing reward for each sample. - - - -Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding. - Used for: (1) expanding advantages to token-level shape, (2) global normalization - that only considers valid tokens. - - - -Policy log probabilities of shape [batch_size, seq_len], required if use_kl_in_reward. - - - -Reference policy log probabilities of shape [batch_size, seq_len], required if use_kl_in_reward. - - - -Additional arguments (unused). - - -**Returns:** - -Advantages tensor of shape [batch_size, seq_len], globally normalized across valid tokens. - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/async_utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/async_utils.mdx deleted file mode 100644 index f9ad506..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/async_utils.mdx +++ /dev/null @@ -1,572 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/algorithms/async_utils -title: nemo_rl.algorithms.async_utils ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`AsyncTrajectoryCollector`](#nemo_rl-algorithms-async_utils-AsyncTrajectoryCollector) | Collects trajectories asynchronously and adds them to replay buffer. | -| [`ReplayBuffer`](#nemo_rl-algorithms-async_utils-ReplayBuffer) | Replay buffer storing per-prompt groups. | - -### Data - -[`TokenizerType`](#nemo_rl-algorithms-async_utils-TokenizerType) - -### API - - - - - -```python -class nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector( - policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, - tokenizer: nemo_rl.algorithms.async_utils.TokenizerType, - task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], - master_config: nemo_rl.algorithms.grpo.MasterConfig, - replay_buffer: typing.Any, - start_step: int = 0 -) -``` - - - - - - -Collects trajectories asynchronously and adds them to replay buffer. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._calculate_target_weights( - generation_weight_version: int -) -> list[int] -``` - - - - - - -Calculate target weight versions for given generation weight version. - -The list of versions returned enumerate the possible version a generation -server can target. These versions are looped over to see what training -step they can target. If all target versions are exhausted, this generation -server will remain idle until the next weight update. - -Example: -generation_weight_version = 10 -max_trajectory_age_steps = 4 - -**Returns:** `list[int]` - -[11, 12, 13, 14] # Meaning this generation server can create trajectories for training step 11, 12, 13, 14 - - - - - - - -```python -nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._cleanup_finished_threads() -> None -``` - - - - - - - - - - - - -```python -nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._collection_loop() -``` - - - - - - -Run the collection loop in background thread. - - - - - - - -```python -nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._get_next_target_for_generation( - generation_weight_version: int -) -> typing.Optional[int] -``` - - - - - - -Get the next target weight that needs generation (if any). - - - - - - - -```python -nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._process_batch( - batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec] -) -> None -``` - - - - - - -Process a single batch and generate for one target weight. - - - - - - - -```python -nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._run_prompt_group_worker( - repeated_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], - generation_weight_version: int, - target_weight_version: int, - prompt_idx: int -) -> None -``` - - - - - - - - - - - - -```python -nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._should_pause_for_generation_limits() -> bool -``` - - - - - - -Check if collection should be paused due to generation limits. - - - - - - - -```python -nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.get_dataloader_state() -> dict -``` - - - - - - -Get the current dataloader state for checkpointing. - - - - - - - -```python -nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.get_weight_version() -> int -``` - - - - - - - - - - - - -```python -nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.pause() -> None -``` - - - - - - -Pause trajectory collection. - - - - - - - -```python -nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.prepare_for_refit() -> None -``` - - - - - - -Pause new generation starts and optionally wait for pending generations. - -For vLLM V1 async engine, leverages in-flight weight updates via collective_rpc, -allowing ongoing generations to continue with their current KV caches while -weights are updated. This significantly improves async performance. - -For non-async engines, waits for all pending generations to complete before refit. - - - - - - - -```python -nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.resume() -> None -``` - - - - - - -Resume trajectory collection. - - - - - - - -```python -nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.resume_after_refit() -> None -``` - - - - - - -Resume new generation starts after refit is complete. - - - - - - - -```python -nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.set_weight_version( - version: int -) -> None -``` - - - - - - - - - - - - -```python -nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.start_collection( - dataloader: torchdata.stateful_dataloader.StatefulDataLoader -) -> None -``` - - - - - - -Start collecting trajectories from dataloader. - - - - - - - -```python -nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.wait_for_pending_generations() -> None -``` - - - - - - -Wait for all in-flight generation threads to complete. - - - - - - - - - -```python -class nemo_rl.algorithms.async_utils.ReplayBuffer( - max_size: int -) -``` - - - - - - -Replay buffer storing per-prompt groups. - -A single entry corresponds to 1 prompt repeated by -grpo.num_generations_per_prompt (required to compute per-prompt advantages). - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.algorithms.async_utils.ReplayBuffer.clear() -> None -``` - - - - - - -Clear the buffer. - - - - - - - -```python -nemo_rl.algorithms.async_utils.ReplayBuffer.get_debug_info() -> dict -``` - - - - - - -Get debug information about buffer state. - - - - - - - -```python -nemo_rl.algorithms.async_utils.ReplayBuffer.get_existing_target_weights() -> set[int] -``` - - - - - - -Get set of target weight versions that already have trajectories. - - - - - - - -```python -nemo_rl.algorithms.async_utils.ReplayBuffer.get_last_target_weight_already_generated() -> int -``` - - - - - - - - - - - - -```python -nemo_rl.algorithms.async_utils.ReplayBuffer.push_with_wait_signal( - trajectory: dict[str, typing.Any], - weight_version: int, - target_weight_version: int -) -> str -``` - - - - - - -Add a per-prompt trajectory group with metadata. - -**Parameters:** - - -data dict - - - -version of the model weights used for generation - - - -version of the model weights this trajectory is intended for training - - - - - - - - -```python -nemo_rl.algorithms.async_utils.ReplayBuffer.sample( - num_prompt_groups: int, - current_weight_version: int, - max_age_steps: int -) -> typing.Optional[dict[str, typing.Any]] -``` - - - - - - -Sample per-prompt trajectory groups intended for the current training step. - -Only returns trajectories with target_weight_version == current_weight_version. -If insufficient trajectories are available, returns None to stall training -until the remaining trajectories are generated. This ensures no trajectory -loses its last chance to be used for its intended training step. - -**Returns:** `Optional[dict[str, Any]]` - -Dictionary with 'trajectories' and 'avg_trajectory_age' keys, or None if insufficient data - - - - - - - -```python -nemo_rl.algorithms.async_utils.ReplayBuffer.size() -> int -``` - - - - - - -Return current buffer size. - - - - - - - - - -```python -nemo_rl.algorithms.async_utils.TokenizerType = PreTrainedTokenizerBase -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/distillation.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/distillation.mdx deleted file mode 100644 index 2dede47..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/distillation.mdx +++ /dev/null @@ -1,326 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/algorithms/distillation -title: nemo_rl.algorithms.distillation ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`DistillationConfig`](#nemo_rl-algorithms-distillation-DistillationConfig) | - | -| [`DistillationSaveState`](#nemo_rl-algorithms-distillation-DistillationSaveState) | - | -| [`MasterConfig`](#nemo_rl-algorithms-distillation-MasterConfig) | Main configuration structure. | - -### Functions - -| Name | Description | -|------|-------------| -| [`_default_distillation_save_state`](#nemo_rl-algorithms-distillation-_default_distillation_save_state) | - | -| [`check_vocab_equality`](#nemo_rl-algorithms-distillation-check_vocab_equality) | Check if the vocab of the tokenizer (student) and the teacher tokenizer are equal. | -| [`distillation_train`](#nemo_rl-algorithms-distillation-distillation_train) | Run Distillation training algorithm. | -| [`setup`](#nemo_rl-algorithms-distillation-setup) | Main entry point for distillation algorithm. | -| [`validate`](#nemo_rl-algorithms-distillation-validate) | Run validation on the validation dataset. | - -### Data - -[`TokenizerType`](#nemo_rl-algorithms-distillation-TokenizerType) - -### API - - - - - -```python -class nemo_rl.algorithms.distillation.DistillationConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.distillation.DistillationSaveState -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.distillation.MasterConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - -Main configuration structure. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.algorithms.distillation._default_distillation_save_state() -> nemo_rl.algorithms.distillation.DistillationSaveState -``` - - - - - - - - - - - - - -```python -nemo_rl.algorithms.distillation.check_vocab_equality( - tokenizer: nemo_rl.algorithms.distillation.TokenizerType, - student_model_name: str, - teacher_model_name: str -) -> None -``` - - - - - - -Check if the vocab of the tokenizer (student) and the teacher tokenizer are equal. - - - - - - - - -```python -nemo_rl.algorithms.distillation.distillation_train( - student_policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, - teacher_policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, - student_generation: typing.Optional[nemo_rl.models.generation.interfaces.GenerationInterface], - dataloader: torchdata.stateful_dataloader.StatefulDataLoader, - val_dataloader: typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], - tokenizer: nemo_rl.algorithms.distillation.TokenizerType, - loss_fn: nemo_rl.algorithms.loss_functions.DistillationLossFn, - task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], - val_task_to_env: typing.Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]], - logger: nemo_rl.utils.logger.Logger, - checkpointer: nemo_rl.utils.checkpoint.CheckpointManager, - distillation_save_state: nemo_rl.algorithms.distillation.DistillationSaveState, - master_config: nemo_rl.algorithms.distillation.MasterConfig -) -> None -``` - - - - - - -Run Distillation training algorithm. - - - - - - - - -```python -nemo_rl.algorithms.distillation.setup( - master_config: nemo_rl.algorithms.distillation.MasterConfig, - tokenizer: nemo_rl.algorithms.distillation.TokenizerType, - train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset, - val_dataset: typing.Optional[nemo_rl.data.datasets.AllTaskProcessedDataset] -) -> tuple[nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, typing.Optional[nemo_rl.models.generation.interfaces.GenerationInterface], torchdata.stateful_dataloader.StatefulDataLoader, typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.DistillationLossFn, nemo_rl.utils.logger.Logger, nemo_rl.utils.checkpoint.CheckpointManager, nemo_rl.algorithms.distillation.DistillationSaveState, nemo_rl.algorithms.distillation.MasterConfig] -``` - - - - - - -Main entry point for distillation algorithm. - -**Returns:** `ColocatablePolicyInterface` - -tuple of student_policy, teacher_policy, student_generation, - - - - - - - - -```python -nemo_rl.algorithms.distillation.validate( - policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, - val_dataloader: typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], - tokenizer, - val_task_to_env: typing.Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]], - step: int, - master_config: nemo_rl.algorithms.distillation.MasterConfig -) -> tuple[dict[str, typing.Any], dict[str, typing.Any]] -``` - - - - - - -Run validation on the validation dataset. - - - - - - - - -```python -nemo_rl.algorithms.distillation.TokenizerType = TypeVar('TokenizerType', bound=PreTrainedTokenizerBase) -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/dpo.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/dpo.mdx deleted file mode 100644 index 3d57261..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/dpo.mdx +++ /dev/null @@ -1,378 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/algorithms/dpo -title: nemo_rl.algorithms.dpo ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`DPOConfig`](#nemo_rl-algorithms-dpo-DPOConfig) | - | -| [`DPOSaveState`](#nemo_rl-algorithms-dpo-DPOSaveState) | - | -| [`DPOValMetrics`](#nemo_rl-algorithms-dpo-DPOValMetrics) | - | -| [`MasterConfig`](#nemo_rl-algorithms-dpo-MasterConfig) | - | - -### Functions - -| Name | Description | -|------|-------------| -| [`_default_dpo_save_state`](#nemo_rl-algorithms-dpo-_default_dpo_save_state) | - | -| [`add_ref_logprobs_to_data`](#nemo_rl-algorithms-dpo-add_ref_logprobs_to_data) | - | -| [`dpo_train`](#nemo_rl-algorithms-dpo-dpo_train) | - | -| [`setup`](#nemo_rl-algorithms-dpo-setup) | Main entry point for running DPO algorithm. | -| [`validate`](#nemo_rl-algorithms-dpo-validate) | - | -| [`validate_one_dataset`](#nemo_rl-algorithms-dpo-validate_one_dataset) | Run validation on one validation dataset. | - -### API - - - - - -```python -class nemo_rl.algorithms.dpo.DPOConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.dpo.DPOSaveState -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.dpo.DPOValMetrics -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.dpo.MasterConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.algorithms.dpo._default_dpo_save_state() -> nemo_rl.algorithms.dpo.DPOSaveState -``` - - - - - - - - - - - - - -```python -nemo_rl.algorithms.dpo.add_ref_logprobs_to_data( - dataloader, - policy, - master_config, - is_val = False -) -``` - - - - - - - - - - - - - -```python -nemo_rl.algorithms.dpo.dpo_train( - policy, - train_dataloader, - val_dataloader, - tokenizer, - loss_fn, - master_config, - logger, - checkpointer, - dpo_save_state: nemo_rl.algorithms.dpo.DPOSaveState -) -> None -``` - - - - - - - - - - - - - -```python -nemo_rl.algorithms.dpo.setup( - master_config: nemo_rl.algorithms.dpo.MasterConfig, - tokenizer: transformers.AutoTokenizer, - train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset, - val_dataset: dict[str, nemo_rl.data.datasets.AllTaskProcessedDataset] -) -> tuple[nemo_rl.models.policy.lm_policy.Policy, nemo_rl.distributed.virtual_cluster.RayVirtualCluster, torchdata.stateful_dataloader.StatefulDataLoader, dict[str, torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.DPOLossFn, nemo_rl.utils.logger.Logger, nemo_rl.utils.checkpoint.CheckpointManager, nemo_rl.algorithms.dpo.DPOSaveState, nemo_rl.algorithms.dpo.MasterConfig] -``` - - - - - - -Main entry point for running DPO algorithm. - -**Returns:** `tuple[Policy, RayVirtualCluster, StatefulDataLoader, dict[str, StatefulDataLoader], DPOLossFn, Logger, CheckpointManager, DPOSaveState, MasterConfig]` - -Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, master_config, logger - - - - - - - - -```python -nemo_rl.algorithms.dpo.validate( - policy: nemo_rl.models.policy.interfaces.PolicyInterface, - val_dataloader: dict[str, torchdata.stateful_dataloader.StatefulDataLoader], - tokenizer, - loss_fn, - step: int, - master_config: nemo_rl.algorithms.dpo.MasterConfig, - val_batches: int, - val_batch_size: int, - val_mbs: int, - logger: nemo_rl.utils.logger.Logger -) -``` - - - - - - - - - - - - - -```python -nemo_rl.algorithms.dpo.validate_one_dataset( - policy: nemo_rl.models.policy.interfaces.PolicyInterface, - val_dataloader: torchdata.stateful_dataloader.StatefulDataLoader, - loss_fn, - step: int, - master_config: nemo_rl.algorithms.dpo.MasterConfig, - val_batches: int, - val_batch_size: int, - val_mbs: int, - dataset_name: str -) -``` - - - - - - -Run validation on one validation dataset. - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/grpo.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/grpo.mdx deleted file mode 100644 index b8db0fe..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/grpo.mdx +++ /dev/null @@ -1,864 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/algorithms/grpo -title: nemo_rl.algorithms.grpo ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`AdvEstimatorConfig`](#nemo_rl-algorithms-grpo-AdvEstimatorConfig) | Configuration for advantage estimator (GRPO or Reinforce++). | -| [`AsyncGRPOConfig`](#nemo_rl-algorithms-grpo-AsyncGRPOConfig) | - | -| [`GRPOConfig`](#nemo_rl-algorithms-grpo-GRPOConfig) | - | -| [`GRPOLoggerConfig`](#nemo_rl-algorithms-grpo-GRPOLoggerConfig) | - | -| [`GRPOSaveState`](#nemo_rl-algorithms-grpo-GRPOSaveState) | - | -| [`MasterConfig`](#nemo_rl-algorithms-grpo-MasterConfig) | - | -| [`RewardScalingConfig`](#nemo_rl-algorithms-grpo-RewardScalingConfig) | Configure linear reward scaling with clamping. | - -### Functions - -| Name | Description | -|------|-------------| -| [`_create_advantage_estimator`](#nemo_rl-algorithms-grpo-_create_advantage_estimator) | Create and return an advantage estimator based on configuration. | -| [`_default_grpo_save_state`](#nemo_rl-algorithms-grpo-_default_grpo_save_state) | - | -| [`_extract_prompt_only_messages`](#nemo_rl-algorithms-grpo-_extract_prompt_only_messages) | Extract only prompt messages (user/system) from message logs. | -| [`_log_mixed_rewards_and_advantages_information`](#nemo_rl-algorithms-grpo-_log_mixed_rewards_and_advantages_information) | - | -| [`_should_log_nemo_gym_responses`](#nemo_rl-algorithms-grpo-_should_log_nemo_gym_responses) | - | -| [`_should_use_async_rollouts`](#nemo_rl-algorithms-grpo-_should_use_async_rollouts) | Determine if async rollouts should be used based on the configuration. | -| [`_should_use_nemo_gym`](#nemo_rl-algorithms-grpo-_should_use_nemo_gym) | Determine if NeMo-Gym should be used for rollouts and validation based on the configuration. | -| [`async_grpo_train`](#nemo_rl-algorithms-grpo-async_grpo_train) | Run asynchronous GRPO training with replay buffer. | -| [`dynamic_sampling`](#nemo_rl-algorithms-grpo-dynamic_sampling) | Implements the dynamic sampling algorithm to select prompts with non-zero standard deviation. | -| [`grpo_train`](#nemo_rl-algorithms-grpo-grpo_train) | Run GRPO training algorithm. | -| [`refit_policy_generation`](#nemo_rl-algorithms-grpo-refit_policy_generation) | Refit the policy generation interface with the latest policy weights. | -| [`scale_rewards`](#nemo_rl-algorithms-grpo-scale_rewards) | Linearly scales rewards from a source range to a target range. | -| [`setup`](#nemo_rl-algorithms-grpo-setup) | Main entry point for running GRPO algorithm. | -| [`validate`](#nemo_rl-algorithms-grpo-validate) | Run validation on the validation dataset. | - -### Data - -[`TokenizerType`](#nemo_rl-algorithms-grpo-TokenizerType) - -### API - - - - - -```python -class nemo_rl.algorithms.grpo.AdvEstimatorConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - -Configuration for advantage estimator (GRPO or Reinforce++). - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.grpo.AsyncGRPOConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.grpo.GRPOConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.grpo.GRPOLoggerConfig() -``` - - - - - - -**Bases:** [LoggerConfig](/nemo-rl/nemo_rl/utils/logger#nemo_rl-utils-logger-LoggerConfig) - - - - - - - - - -```python -class nemo_rl.algorithms.grpo.GRPOSaveState -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.grpo.MasterConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.grpo.RewardScalingConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - -Configure linear reward scaling with clamping. - -When `enabled` is True, each reward is clamped to the source interval -[source_min, source_max] and linearly mapped to the target interval -[target_min, target_max]. Refer to the scale_rewards function for the implementation. - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.algorithms.grpo._create_advantage_estimator( - master_config: nemo_rl.algorithms.grpo.MasterConfig -) -``` - - - - - - -Create and return an advantage estimator based on configuration. - -**Parameters:** - - -The master configuration dictionary. - - -**Returns:** - -An advantage estimator instance (GRPOAdvantageEstimator or ReinforcePlusPlusAdvantageEstimator). - -**Raises:** - -- `ValueError`: If the advantage estimator name is not recognized. - - - - - - - - -```python -nemo_rl.algorithms.grpo._default_grpo_save_state() -> nemo_rl.algorithms.grpo.GRPOSaveState -``` - - - - - - - - - - - - - -```python -nemo_rl.algorithms.grpo._extract_prompt_only_messages( - message_logs: list -) -> list -``` - - - - - - -Extract only prompt messages (user/system) from message logs. - -This is used to get prompt IDs for advantage estimation, excluding -any assistant responses. - -**Parameters:** - - -List of message logs, where each log is a list of messages. - - -**Returns:** `list` - -List of message logs containing only user and system messages. - - - - - - - - -```python -nemo_rl.algorithms.grpo._log_mixed_rewards_and_advantages_information( - logger: nemo_rl.utils.logger.Logger, - total_steps: int, - metrics: dict[str, typing.Any], - baseline: torch.Tensor, - advantages: torch.Tensor -) -> None -``` - - - - - - - - - - - - - -```python -nemo_rl.algorithms.grpo._should_log_nemo_gym_responses( - master_config: nemo_rl.algorithms.grpo.MasterConfig -) -> bool -``` - - - - - - - - - - - - - -```python -nemo_rl.algorithms.grpo._should_use_async_rollouts( - master_config: nemo_rl.algorithms.grpo.MasterConfig -) -> bool -``` - - - - - - -Determine if async rollouts should be used based on the configuration. - -Returns True if vLLM backend is used with async_engine enabled. - - - - - - - - -```python -nemo_rl.algorithms.grpo._should_use_nemo_gym( - master_config: nemo_rl.algorithms.grpo.MasterConfig -) -> bool -``` - - - - - - -Determine if NeMo-Gym should be used for rollouts and validation based on the configuration. - - - - - - - - -```python -nemo_rl.algorithms.grpo.async_grpo_train( - policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, - policy_generation: typing.Optional[nemo_rl.models.generation.interfaces.GenerationInterface], - dataloader: torchdata.stateful_dataloader.StatefulDataLoader, - val_dataloader: typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], - tokenizer: nemo_rl.algorithms.grpo.TokenizerType, - loss_fn: nemo_rl.algorithms.interfaces.LossFunction, - task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], - val_task_to_env: typing.Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]], - logger: nemo_rl.utils.logger.Logger, - checkpointer: nemo_rl.utils.checkpoint.CheckpointManager, - grpo_save_state: nemo_rl.algorithms.grpo.GRPOSaveState, - master_config: nemo_rl.algorithms.grpo.MasterConfig, - max_trajectory_age_steps: int = 1 -) -> None -``` - - - - - - -Run asynchronous GRPO training with replay buffer. - -**Parameters:** - - -Training policy - - - -Generation interface - - - -Training data loader - - - -Validation data loader - - - -Tokenizer - - - -Loss function - - - -Training environments - - - -Validation environments - - - -Logger - - - -Checkpoint manager - - - -Training state - - - -Master configuration - - - -Maximum age (in training steps) for trajectories to be used in training - - - - - - - - - -```python -nemo_rl.algorithms.grpo.dynamic_sampling( - repeated_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], - std: torch.Tensor, - baseline: torch.Tensor, - dynamic_sampling_num_gen_batches: int, - master_config: nemo_rl.algorithms.grpo.MasterConfig, - timer: nemo_rl.utils.timer.Timer, - batch_cache: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec] = None -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec] -``` - - - - - - -Implements the dynamic sampling algorithm to select prompts with non-zero standard deviation. - -This function filters the current batch to retain only those prompts that have a non-zero standard deviation. -If the current batch has fewer number of prompts with non-zero standard deviation than the required batch size, defined as num_prompts_per_step * num_generations_per_prompt, -we store it in the batch_cache to be used in later iterations. -If the current batch has more number of prompts with non-zero standard deviation than the required batch size, defined as num_prompts_per_step * num_generations_per_prompt, -the batch is sliced to ensure batch size is num_prompts_per_step * num_generations_per_prompt. -is_batch_complete is set to False to indicate that the current batch is not enough to meet the required batch size. This is used as a signal in the GRPO training loop -to continue sampling or proceed to training. -This approach is based on the dynamic sampling algorithm from the DAPO paper: -https://arxiv.org/pdf/2503.14476. - -**Parameters:** - - -The current batch of data containing prompts, responses, rewards, baselines, and std. - - - -Tensor representing the standard deviation for each prompt group. - - - -Baseline values for each prompt group. - - - -Number of generation batches processed at the current step. - - - -Configuration containing GRPO and policy settings. - - - -Cache storing previously selected prompts with non-zero std. - - -**Returns:** `BatchedDataDict[DatumSpec]` - -A tuple containing: -- repeated_batch (BatchedDataDict[DatumSpec]): Updated batch with selected prompts. -- is_batch_complete (bool): Indicates if the batch has enough samples with non-zero std for training. -- batch_cache (BatchedDataDict[DatumSpec]): Updated cache for future iterations. - - - - - - - - -```python -nemo_rl.algorithms.grpo.grpo_train( - policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, - policy_generation: typing.Optional[nemo_rl.models.generation.interfaces.GenerationInterface], - dataloader: torchdata.stateful_dataloader.StatefulDataLoader, - val_dataloader: typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], - tokenizer: nemo_rl.algorithms.grpo.TokenizerType, - loss_fn: nemo_rl.algorithms.interfaces.LossFunction, - task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], - val_task_to_env: typing.Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]], - logger: nemo_rl.utils.logger.Logger, - checkpointer: nemo_rl.utils.checkpoint.CheckpointManager, - grpo_save_state: nemo_rl.algorithms.grpo.GRPOSaveState, - master_config: nemo_rl.algorithms.grpo.MasterConfig -) -> None -``` - - - - - - -Run GRPO training algorithm. - - - - - - - - -```python -nemo_rl.algorithms.grpo.refit_policy_generation( - policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, - policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, - colocated_inference: bool, - _refit_buffer_size_gb: typing.Optional[int] = None, - timer: typing.Optional[nemo_rl.utils.timer.Timer] = None, - kv_scales: typing.Optional[dict[str, float]] = None -) -> None -``` - - - - - - -Refit the policy generation interface with the latest policy weights. - -**Parameters:** - - -The policy to provide weights to the inference engine. - - - -The inference engine to refit. - - - -The size of the buffer to use for refitting. -If it is None, the buffer size will be computed by the remaining memory. -This parameter is primarily used for testing. - - - -Optional Timer used to time the prepare/transfer/update phase - - - -Optional dictionary of KV cache scales for FP8 quantization. - - - - - - - - - -```python -nemo_rl.algorithms.grpo.scale_rewards( - repeated_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], - reward_scaling_cfg: nemo_rl.algorithms.grpo.RewardScalingConfig -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec] -``` - - - - - - -Linearly scales rewards from a source range to a target range. - -If `reward_scaling.enabled` is True, each reward in `repeated_batch["total_reward"]` -is clamped to the configured source interval [source_min, source_max] and then -rescaled to the target interval [target_min, target_max]. - - - - - - - - -```python -nemo_rl.algorithms.grpo.setup( - master_config: nemo_rl.algorithms.grpo.MasterConfig, - tokenizer: nemo_rl.algorithms.grpo.TokenizerType, - dataset: nemo_rl.data.datasets.AllTaskProcessedDataset, - val_dataset: typing.Optional[nemo_rl.data.datasets.AllTaskProcessedDataset], - processor: typing.Optional[transformers.AutoProcessor] = None -) -> tuple[nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, typing.Optional[nemo_rl.models.generation.interfaces.GenerationInterface], tuple[nemo_rl.distributed.virtual_cluster.RayVirtualCluster, nemo_rl.distributed.virtual_cluster.RayVirtualCluster], torchdata.stateful_dataloader.StatefulDataLoader, typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.ClippedPGLossFn, nemo_rl.utils.logger.Logger, nemo_rl.utils.checkpoint.CheckpointManager, nemo_rl.algorithms.grpo.GRPOSaveState, nemo_rl.algorithms.grpo.MasterConfig] -``` - - - - - - -Main entry point for running GRPO algorithm. - -**Returns:** `tuple[ColocatablePolicyInterface, Optional[GenerationInterface], tuple[RayVirtualCluster, RayVirtualCluster], StatefulDataLoader, Optional[StatefulDataLoader], ClippedPGLossFn, Logger, CheckpointManager, GRPOSaveState, MasterConfig]` - -tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, logger, master_config, val_dataloader - - - - - - - - -```python -nemo_rl.algorithms.grpo.validate( - policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, - val_dataloader: typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], - tokenizer, - val_task_to_env: typing.Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]], - step: int, - master_config: nemo_rl.algorithms.grpo.MasterConfig, - logger: typing.Optional[nemo_rl.utils.logger.Logger] = None -) -> tuple[dict[str, typing.Any], dict[str, typing.Any]] -``` - - - - - - -Run validation on the validation dataset. - - - - - - - - -```python -nemo_rl.algorithms.grpo.TokenizerType = TypeVar('TokenizerType', bound=PreTrainedTokenizerBase) -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/interfaces.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/interfaces.mdx deleted file mode 100644 index 7976052..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/interfaces.mdx +++ /dev/null @@ -1,123 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/algorithms/interfaces -title: nemo_rl.algorithms.interfaces ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`LossFunction`](#nemo_rl-algorithms-interfaces-LossFunction) | Signature for loss functions used in reinforcement learning algorithms. | -| [`LossType`](#nemo_rl-algorithms-interfaces-LossType) | - | - -### API - - - - - -```python -class nemo_rl.algorithms.interfaces.LossFunction() -``` - - - - - - -Protocol - -Signature for loss functions used in reinforcement learning algorithms. - -Loss functions compute a scalar loss value and associated metrics from -model logprobs and other data contained in a BatchedDataDict. - - - - - - - - -```python -nemo_rl.algorithms.interfaces.LossFunction.__call__( - next_token_logits: torch.Tensor, - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict, - global_valid_seqs: torch.Tensor, - global_valid_toks: torch.Tensor -) -> tuple[torch.Tensor, dict[str, typing.Any]] -``` - - - - - - -Compute loss and metrics from logprobs and other data. - -**Parameters:** - - -Logits from the model, typically with shape [batch_size, seq_len, vocab_size]. - For each position (b, i), contains the logit distribution over the entire vocabulary - for predicting the next token (at position i+1). For example, if processing "The cat sat on", - then next_token_logits[b, 3] would contain the logits for predicting the word - that follows "on". - - - -Dictionary containing all relevant data for loss computation - such as rewards, values, actions, advantages, masks, and other - algorithm-specific information needed for the particular loss calculation. - - - -torch.Tensor -this tensor should contain the number of valid sequences in the microbatch. -It's used for global normalization for losses/metrics that are computed at the sequence level -and needs to be aggregated across all microbatches. - - - -torch.Tensor -This tensor should contain the number of valid tokens in the microbatch. -It's used for global normalization for losses/metrics that are computed at the token level -and needs to be aggregated across all microbatches. - - -**Returns:** `tuple[torch.Tensor, dict[str, Any]]` - -(loss, metrics) -- loss: A scalar tensor representing the loss value to be minimized during training -- metrics: A dictionary of metrics related to the loss computation, which may include - component losses, statistics about gradients/rewards, and other diagnostic information - - - - - - - - - -```python -class nemo_rl.algorithms.interfaces.LossType -``` - - - - - - -**Bases:** `enum.Enum` - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/loss_functions.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/loss_functions.mdx deleted file mode 100644 index f8307d1..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/loss_functions.mdx +++ /dev/null @@ -1,875 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/algorithms/loss_functions -title: nemo_rl.algorithms.loss_functions ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`ClippedPGLossConfig`](#nemo_rl-algorithms-loss_functions-ClippedPGLossConfig) | - | -| [`ClippedPGLossDataDict`](#nemo_rl-algorithms-loss_functions-ClippedPGLossDataDict) | Required keys for the Clipped Policy Gradient loss function. | -| [`ClippedPGLossFn`](#nemo_rl-algorithms-loss_functions-ClippedPGLossFn) | Generalized Clipped Policy Gradient loss function w/ KL regularization. | -| [`DPOLossConfig`](#nemo_rl-algorithms-loss_functions-DPOLossConfig) | - | -| [`DPOLossDataDict`](#nemo_rl-algorithms-loss_functions-DPOLossDataDict) | Required keys for the DPO loss function. | -| [`DPOLossFn`](#nemo_rl-algorithms-loss_functions-DPOLossFn) | Direct Preference Optimization (DPO) loss function. | -| [`DistillationLossConfig`](#nemo_rl-algorithms-loss_functions-DistillationLossConfig) | - | -| [`DistillationLossDataDict`](#nemo_rl-algorithms-loss_functions-DistillationLossDataDict) | - | -| [`DistillationLossFn`](#nemo_rl-algorithms-loss_functions-DistillationLossFn) | Distillation loss function. | -| [`NLLLoss`](#nemo_rl-algorithms-loss_functions-NLLLoss) | Negative Log Likelihood Loss function. | -| [`PreferenceLoss`](#nemo_rl-algorithms-loss_functions-PreferenceLoss) | Preference Loss function. | -| [`PreferenceLossDataDict`](#nemo_rl-algorithms-loss_functions-PreferenceLossDataDict) | Required keys for the preference loss function. | -| [`SequencePackingLossWrapper`](#nemo_rl-algorithms-loss_functions-SequencePackingLossWrapper) | - | - -### Data - -[`Tensor`](#nemo_rl-algorithms-loss_functions-Tensor) - -### API - - - - - -```python -class nemo_rl.algorithms.loss_functions.ClippedPGLossConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.loss_functions.ClippedPGLossDataDict -``` - - - - - - -**Bases:** `typing.TypedDict` - -Required keys for the Clipped Policy Gradient loss function. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.loss_functions.ClippedPGLossFn( - cfg: nemo_rl.algorithms.loss_functions.ClippedPGLossConfig -) -``` - - - - - - -**Bases:** [LossFunction](/nemo-rl/nemo_rl/algorithms/interfaces#nemo_rl-algorithms-interfaces-LossFunction) - -Generalized Clipped Policy Gradient loss function w/ KL regularization. - -This implements: - -- PPO (Clipped) - https://arxiv.org/abs/1707.06347 -- GRPO - https://arxiv.org/abs/2402.03300 -- REINFORCE/RLOO (set disable_ppo_ratio = True and ignores ratio_clip_min/ratio_clip_max) - https://arxiv.org/abs/2402.14740 -- GSPO (set sequence_level_importance_ratios = True and token_level_loss = False) - https://arxiv.org/abs/2507.18071 -- Truly on-policy (set force_on_policy_ratio = True to force ratio = 1.0, requires one update per rollout) - -Formula: -L(θ) = E_t [ min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t) ] - β * KL(π_θ || π_ref) - -where: -- r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t) is the probability ratio -- A_t is the advantage estimate -- ε is the clip parameter (ratio_clip_min/ratio_clip_max) - - As proposed in the DAPO paper (https://arxiv.org/pdf/2503.14476), - we allow setting a distinct minimum and maximum value for the clip parameter (set to the same value for PPO/GRPO/etc.) - - ratio_clip_min: minimum value for the clip parameter - - ratio_clip_max: maximum value for the clip parameter -- β is the KL penalty coefficient (reference_policy_kl_penalty) -- KL(π_θ || π_ref) is the KL divergence between the current policy and reference policy (Schulman Approx.) - -For REINFORCE/RLOO (when disable_ppo_ratio=True), the formula simplifies to: -L(θ) = E_t [ π_θ(a_t|s_t) * A_t ] - β * KL(π_θ || π_ref) - -Also supports "Dual-Clipping" from https://arxiv.org/pdf/1912.09729, which -imposes an additional upper bound on the probability ratio when advantages are negative. -This prevents excessive policy updates. $rA << 0$ -> $cA$(clipped) -The loss function is modified to the following when A_t < 0: -L(θ) = E_t [ max(min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t), c * A_t) ] - β * KL(π_θ || π_ref) - -where: -- c is the dual-clip parameter (ratio_clip_c), which must be greater than 1 and is - usually set as 3 empirically. - -Due to potential numerical instability, we cast the logits to float32 before computing the loss. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.algorithms.loss_functions.ClippedPGLossFn.__call__( - next_token_logits: nemo_rl.algorithms.loss_functions.Tensor, - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.ClippedPGLossDataDict], - global_valid_seqs: torch.Tensor, - global_valid_toks: torch.Tensor, - vocab_parallel_rank: typing.Optional[int] = None, - vocab_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None -) -> tuple[torch.Tensor, dict] -``` - - - - - - -Clipped Policy Gradient RL loss function. - - - - - - - - - -```python -class nemo_rl.algorithms.loss_functions.DPOLossConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.loss_functions.DPOLossDataDict -``` - - - - - - -**Bases:** `typing.TypedDict` - -Required keys for the DPO loss function. - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.loss_functions.DPOLossFn( - cfg: nemo_rl.algorithms.loss_functions.DPOLossConfig -) -``` - - - - - - -**Bases:** [PreferenceLoss](#nemo_rl-algorithms-loss_functions-PreferenceLoss) - -Direct Preference Optimization (DPO) loss function. - -This loss function implements the DPO algorithm as described in: -"Direct Preference Optimization: Your Language Model is Secretly a Reward Model" -(https://arxiv.org/abs/2305.18290) - -The loss combines two main components: -1. Preference Loss: Optimizes the model to prefer chosen responses over rejected ones -2. SFT Loss (optional): Auxiliary supervised fine-tuning loss on chosen responses - -The total loss is computed as: -L(θ) = w_p * L_pref(θ) + w_s * L_sft(θ) - -where: -- w_p is the preference_loss_weight -- w_s is the sft_loss_weight -- L_pref(θ) is the preference loss term -- L_sft(θ) is the supervised fine-tuning loss term - -The preference loss term is computed as: -L_pref(θ) = -E[log(σ(β * (r_chosen - r_rejected)))] - -where: -- σ is the sigmoid function -- β is the reference_policy_kl_penalty -- r_chosen and r_rejected are the rewards for chosen and rejected responses -- The rewards are computed as the sum of log probability differences between - the current policy and reference policy - -If preference_average_log_probs is True, the rewards are averaged over tokens: -r = (1/n) * Σ_t (log π_θ(a_t|s_t) - log π_ref(a_t|s_t)) - -Otherwise, the rewards are summed over tokens. - -The SFT loss term is a standard negative log likelihood loss on the chosen responses. -If sft_average_log_probs is True, the loss is averaged over tokens. - -**Parameters:** - - -Configuration dictionary containing: -- reference_policy_kl_penalty (float): Strength of the KL penalty term (β) -- preference_loss_weight (float): Weight for the preference loss term (w_p) -- sft_loss_weight (float): Weight for the SFT loss term (w_s) -- preference_average_log_probs (bool): Whether to average log probs across tokens in preference loss -- sft_average_log_probs (bool): Whether to average log probs across tokens in SFT loss - - -**Returns:** - -tuple[torch.Tensor, dict]: A tuple containing: -- The total loss value -- A dictionary with metrics including: - - loss: Total loss value - - sft_loss: SFT loss component - - preference_loss: Preference loss component - - accuracy: Fraction of examples where chosen response has higher reward - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.algorithms.loss_functions.DPOLossFn.__call__( - next_token_logits: nemo_rl.algorithms.loss_functions.Tensor, - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.DPOLossDataDict], - global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor, - global_valid_toks: nemo_rl.algorithms.loss_functions.Tensor | None, - vocab_parallel_rank: typing.Optional[int] = None, - vocab_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None -) -> tuple[torch.Tensor, dict[str, typing.Any]] -``` - - - - - - - - - - - - -```python -nemo_rl.algorithms.loss_functions.DPOLossFn._dpo_loss( - next_token_logits: nemo_rl.algorithms.loss_functions.Tensor, - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.DPOLossDataDict], - global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor, - vocab_parallel_rank: typing.Optional[int] = None, - vocab_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None -) -> tuple[nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor] -``` - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.loss_functions.DistillationLossConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.loss_functions.DistillationLossDataDict -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.loss_functions.DistillationLossFn( - cfg: nemo_rl.algorithms.loss_functions.DistillationLossConfig -) -``` - - - - - - -**Bases:** [LossFunction](/nemo-rl/nemo_rl/algorithms/interfaces#nemo_rl-algorithms-interfaces-LossFunction) - -Distillation loss function. - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.algorithms.loss_functions.DistillationLossFn.__call__( - next_token_logits: torch.Tensor, - data: nemo_rl.algorithms.loss_functions.DistillationLossDataDict, - global_valid_seqs: torch.Tensor, - global_valid_toks: torch.Tensor, - vocab_parallel_rank: typing.Optional[int] = None, - vocab_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None -) -> tuple[torch.Tensor, dict[str, typing.Any]] -``` - - - - - - -Compute distillation loss between teacher and student logits. - - - - - - - - - -```python -class nemo_rl.algorithms.loss_functions.NLLLoss() -``` - - - - - - -**Bases:** [LossFunction](/nemo-rl/nemo_rl/algorithms/interfaces#nemo_rl-algorithms-interfaces-LossFunction) - -Negative Log Likelihood Loss function. - - - - - - - - -```python -nemo_rl.algorithms.loss_functions.NLLLoss.__call__( - next_token_logits: nemo_rl.algorithms.loss_functions.Tensor, - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor | None, - global_valid_toks: nemo_rl.algorithms.loss_functions.Tensor, - vocab_parallel_rank: typing.Optional[int] = None, - vocab_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, - dpo_loss: bool = False, - dpo_average_log_probs: bool = False -) -> tuple[torch.Tensor, dict[str, typing.Any]] -``` - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.loss_functions.PreferenceLoss() -``` - - - - - - -**Bases:** [LossFunction](/nemo-rl/nemo_rl/algorithms/interfaces#nemo_rl-algorithms-interfaces-LossFunction) - -Preference Loss function. - -Optimizes the model to prefer chosen responses over rejected ones - -The preference loss is computed as: -L_pref(θ) = -E[log(σ(β * (r_chosen - r_rejected)))] - -where: -- σ is the sigmoid function -- β is a scaling factor (ex: `reference_policy_kl_penalty` in DPO) -- r_chosen and r_rejected are the rewards for chosen and rejected responses - -**Returns:** - -tuple[torch.Tensor, dict]: A tuple containing: -- The preference loss value -- A dictionary with metrics including: - - loss: Preference loss - - accuracy: Fraction of examples where chosen response has higher reward - - - - - - - - -```python -nemo_rl.algorithms.loss_functions.PreferenceLoss.__call__( - rewards: nemo_rl.algorithms.loss_functions.Tensor, - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.PreferenceLossDataDict], - global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor, - global_valid_toks: nemo_rl.algorithms.loss_functions.Tensor | None -) -> tuple[torch.Tensor, dict[str, typing.Any]] -``` - - - - - - - - - - - - -```python -nemo_rl.algorithms.loss_functions.PreferenceLoss._preference_loss( - rewards: nemo_rl.algorithms.loss_functions.Tensor, - sample_mask: nemo_rl.algorithms.loss_functions.Tensor, - global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor, - beta: float = 1.0 -) -> tuple[nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor] -``` - - - - - - - - - - - - -```python -nemo_rl.algorithms.loss_functions.PreferenceLoss.split_output_tensor( - tensor: nemo_rl.algorithms.loss_functions.Tensor -) -> tuple[nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor] -``` - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.loss_functions.PreferenceLossDataDict -``` - - - - - - -**Bases:** `typing.TypedDict` - -Required keys for the preference loss function. - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.loss_functions.SequencePackingLossWrapper( - loss_fn: nemo_rl.algorithms.interfaces.LossFunction, - cu_seqlens_q: nemo_rl.algorithms.loss_functions.Tensor, - cu_seqlens_q_padded: typing.Optional[nemo_rl.algorithms.loss_functions.Tensor] = None -) -``` - - - - - - - - - - -```python -nemo_rl.algorithms.loss_functions.SequencePackingLossWrapper.__call__( - next_token_logits: nemo_rl.algorithms.loss_functions.Tensor, - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor | None, - global_valid_toks: nemo_rl.algorithms.loss_functions.Tensor | None, - vocab_parallel_rank: typing.Optional[int] = None, - vocab_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None -) -> tuple[nemo_rl.algorithms.loss_functions.Tensor, dict[str, typing.Any]] -``` - - - - - - -Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid excessive padding. - - - - - - - - - -```python -nemo_rl.algorithms.loss_functions.Tensor = TypeVar('Tensor', bound=(torch.Tensor)) -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/reward_functions.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/reward_functions.mdx deleted file mode 100644 index ffcae23..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/reward_functions.mdx +++ /dev/null @@ -1,102 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/algorithms/reward_functions -title: nemo_rl.algorithms.reward_functions ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`RewardShapingConfig`](#nemo_rl-algorithms-reward_functions-RewardShapingConfig) | Configuration for reward function processing. | - -### Functions - -| Name | Description | -|------|-------------| -| [`apply_reward_shaping`](#nemo_rl-algorithms-reward_functions-apply_reward_shaping) | Process rewards by applying penalties for responses exceeding max_response_length. Currently, this function only supports DAPO reward shaping as illustrated in the DAPO paper : https://arxiv.org/pdf/2503.14476. | - -### Data - -[`Tensor`](#nemo_rl-algorithms-reward_functions-Tensor) - -### API - - - - - -```python -class nemo_rl.algorithms.reward_functions.RewardShapingConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - -Configuration for reward function processing. - -This configuration enables custom reward shaping, currently supporting DAPO-style -penalties for responses that exceed the maximum response length threshold. - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.algorithms.reward_functions.apply_reward_shaping( - batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict, - cfg: nemo_rl.algorithms.reward_functions.RewardShapingConfig -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict -``` - - - - - - -Process rewards by applying penalties for responses exceeding max_response_length. Currently, this function only supports DAPO reward shaping as illustrated in the DAPO paper : https://arxiv.org/pdf/2503.14476. - -Nonetheless, it can be potentially extended to support any custom reward logic. - - - - - - - - -```python -nemo_rl.algorithms.reward_functions.Tensor = TypeVar('Tensor', bound=(torch.Tensor)) -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/rm.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/rm.mdx deleted file mode 100644 index ed41f3a..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/rm.mdx +++ /dev/null @@ -1,320 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/algorithms/rm -title: nemo_rl.algorithms.rm ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`MasterConfig`](#nemo_rl-algorithms-rm-MasterConfig) | - | -| [`RMConfig`](#nemo_rl-algorithms-rm-RMConfig) | - | -| [`RMSaveState`](#nemo_rl-algorithms-rm-RMSaveState) | - | -| [`RMValMetrics`](#nemo_rl-algorithms-rm-RMValMetrics) | - | - -### Functions - -| Name | Description | -|------|-------------| -| [`_default_rm_save_state`](#nemo_rl-algorithms-rm-_default_rm_save_state) | - | -| [`rm_train`](#nemo_rl-algorithms-rm-rm_train) | - | -| [`setup`](#nemo_rl-algorithms-rm-setup) | Main entry point for running RM algorithm. | -| [`validate`](#nemo_rl-algorithms-rm-validate) | - | -| [`validate_one_dataset`](#nemo_rl-algorithms-rm-validate_one_dataset) | Run validation on one validation dataset. | - -### API - - - - - -```python -class nemo_rl.algorithms.rm.MasterConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.rm.RMConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.rm.RMSaveState -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.rm.RMValMetrics -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.algorithms.rm._default_rm_save_state() -> nemo_rl.algorithms.rm.RMSaveState -``` - - - - - - - - - - - - - -```python -nemo_rl.algorithms.rm.rm_train( - policy, - train_dataloader, - val_dataloader, - tokenizer, - loss_fn, - master_config, - logger, - checkpointer, - rm_save_state -) -``` - - - - - - - - - - - - - -```python -nemo_rl.algorithms.rm.setup( - master_config: nemo_rl.algorithms.rm.MasterConfig, - tokenizer: transformers.AutoTokenizer, - train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset, - val_dataset: dict[str, nemo_rl.data.datasets.AllTaskProcessedDataset] -) -> tuple[nemo_rl.models.policy.lm_policy.Policy, nemo_rl.distributed.virtual_cluster.RayVirtualCluster, torchdata.stateful_dataloader.StatefulDataLoader, dict[str, torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.PreferenceLoss, nemo_rl.algorithms.rm.MasterConfig, nemo_rl.utils.logger.Logger, nemo_rl.data.interfaces.TaskDataSpec, nemo_rl.algorithms.rm.RMSaveState] -``` - - - - - - -Main entry point for running RM algorithm. - -**Returns:** `tuple[Policy, RayVirtualCluster, StatefulDataLoader, dict[str, StatefulDataLoader], PreferenceLoss, MasterConfig, Logger, TaskDataSpec, RMSaveState]` - -Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, master_config, logger - - - - - - - - -```python -nemo_rl.algorithms.rm.validate( - policy: nemo_rl.models.policy.interfaces.PolicyInterface, - val_dataloader: dict[str, torchdata.stateful_dataloader.StatefulDataLoader], - tokenizer, - loss_fn, - step: int, - master_config: nemo_rl.algorithms.rm.MasterConfig, - val_batches: int, - val_batch_size: int, - val_mbs: int, - logger: nemo_rl.utils.logger.Logger -) -``` - - - - - - - - - - - - - -```python -nemo_rl.algorithms.rm.validate_one_dataset( - policy: nemo_rl.models.policy.interfaces.PolicyInterface, - val_dataloader: torchdata.stateful_dataloader.StatefulDataLoader, - loss_fn, - step: int, - master_config: nemo_rl.algorithms.rm.MasterConfig, - val_batches: int, - val_batch_size: int, - val_mbs: int, - dataset_name: str -) -``` - - - - - - -Run validation on one validation dataset. - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/sft.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/sft.mdx deleted file mode 100644 index d9a3bd6..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/sft.mdx +++ /dev/null @@ -1,258 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/algorithms/sft -title: nemo_rl.algorithms.sft ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`MasterConfig`](#nemo_rl-algorithms-sft-MasterConfig) | - | -| [`SFTConfig`](#nemo_rl-algorithms-sft-SFTConfig) | - | -| [`SFTSaveState`](#nemo_rl-algorithms-sft-SFTSaveState) | - | - -### Functions - -| Name | Description | -|------|-------------| -| [`_default_sft_save_state`](#nemo_rl-algorithms-sft-_default_sft_save_state) | - | -| [`setup`](#nemo_rl-algorithms-sft-setup) | Main entry point for running SFT algorithm. | -| [`sft_train`](#nemo_rl-algorithms-sft-sft_train) | - | -| [`validate`](#nemo_rl-algorithms-sft-validate) | Run validation on the validation dataset. | - -### API - - - - - -```python -class nemo_rl.algorithms.sft.MasterConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.sft.SFTConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.algorithms.sft.SFTSaveState -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.algorithms.sft._default_sft_save_state() -> nemo_rl.algorithms.sft.SFTSaveState -``` - - - - - - - - - - - - - -```python -nemo_rl.algorithms.sft.setup( - master_config: nemo_rl.algorithms.sft.MasterConfig, - tokenizer: transformers.AutoTokenizer, - train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset, - val_dataset: typing.Optional[nemo_rl.data.datasets.AllTaskProcessedDataset] -) -> tuple[nemo_rl.models.policy.lm_policy.Policy, nemo_rl.distributed.virtual_cluster.RayVirtualCluster, torchdata.stateful_dataloader.StatefulDataLoader, typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.NLLLoss, nemo_rl.utils.logger.Logger, nemo_rl.utils.checkpoint.CheckpointManager, nemo_rl.algorithms.sft.SFTSaveState, nemo_rl.algorithms.sft.MasterConfig] -``` - - - - - - -Main entry point for running SFT algorithm. - -**Returns:** `tuple[Policy, RayVirtualCluster, StatefulDataLoader, Optional[StatefulDataLoader], NLLLoss, Logger, CheckpointManager, SFTSaveState, MasterConfig]` - -Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, master_config, logger - - - - - - - - -```python -nemo_rl.algorithms.sft.sft_train( - policy, - train_dataloader, - val_dataloader, - tokenizer, - loss_fn, - master_config, - logger, - checkpointer, - sft_save_state: nemo_rl.algorithms.sft.SFTSaveState -) -> None -``` - - - - - - - - - - - - - -```python -nemo_rl.algorithms.sft.validate( - policy: nemo_rl.models.policy.interfaces.PolicyInterface, - val_dataloader: typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], - tokenizer, - loss_fn, - step: int, - master_config: nemo_rl.algorithms.sft.MasterConfig, - val_batches: int, - val_batch_size: int, - val_mbs: int -) -``` - - - - - - -Run validation on the validation dataset. - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/utils.mdx deleted file mode 100644 index 200ecb0..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/utils.mdx +++ /dev/null @@ -1,379 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/algorithms/utils -title: nemo_rl.algorithms.utils ---- - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`calculate_baseline_and_std_per_prompt`](#nemo_rl-algorithms-utils-calculate_baseline_and_std_per_prompt) | Function to compute a baseline for each (prompt, response) pair in the batch. | -| [`calculate_kl`](#nemo_rl-algorithms-utils-calculate_kl) | Calculates a per-token estimate of the KL Divergence between two logprobs. | -| [`get_tokenizer`](#nemo_rl-algorithms-utils-get_tokenizer) | Get the tokenizer and set pad token to eos token if it is not already set. | -| [`log_generation_metrics_to_wandb`](#nemo_rl-algorithms-utils-log_generation_metrics_to_wandb) | Log generation metrics to wandb. | -| [`masked_mean`](#nemo_rl-algorithms-utils-masked_mean) | Computes the mean of a microbatch, using a global statistic as the normalization factor. | -| [`maybe_pad_last_batch`](#nemo_rl-algorithms-utils-maybe_pad_last_batch) | Pads the given batch so that its size is divisible by (mbs * dp_size). | -| [`print_performance_metrics`](#nemo_rl-algorithms-utils-print_performance_metrics) | Print performance metrics for GRPO. | -| [`set_seed`](#nemo_rl-algorithms-utils-set_seed) | Sets the seed for python, numpy, and pytorch. | -| [`surpress_user_warnings`](#nemo_rl-algorithms-utils-surpress_user_warnings) | - | - -### API - - - - - -```python -nemo_rl.algorithms.utils.calculate_baseline_and_std_per_prompt( - prompts: torch.Tensor, - rewards: torch.Tensor, - valid_mask: torch.Tensor, - leave_one_out_baseline: bool = True -) -> tuple[torch.Tensor, torch.Tensor] -``` - - - - - - -Function to compute a baseline for each (prompt, response) pair in the batch. - -The same baseline is calculated for each prompt. Samples set to 0 in 'valid_mask' -are not included in the baseline calculation. - -prompts: tensor (b, s) Tensor of prompts the model used. May be on any device -rewards: tensor (b,) Float-valued rewards. May be on any device -valid_mask: tensor (b,) Vector of 0/1, where 0 is to ignore and 1 is to keep -leave_one_out_baseline: bool Compute an unbiased baseline by leaving out the sample that - the baseline is for (from RLOO https://arxiv.org/abs/2402.14740) - -Returns: -tensor (b,), tensor (b,) of baselines and std on the same device as 'rewards' - - - - - - - - -```python -nemo_rl.algorithms.utils.calculate_kl( - logprobs: torch.Tensor, - logprobs_reference: torch.Tensor, - kl_type: str = 'k3', - input_clamp_value: float | None = 20.0, - output_clamp_value: float | None = 10.0 -) -> torch.Tensor -``` - - - - - - -Calculates a per-token estimate of the KL Divergence between two logprobs. - -From Schulman 2020, http://joschu.net/blog/kl-approx.html. - -**Parameters:** - - -torch.Tensor (b, s) - - - -torch.Tensor (b, s) - - - -Type of KL approximation to use. Valid values: "k1", "k2", "k3". - - - -Optional clamping value for logr to prevent numerical instability. - If None, no clamping is applied. - - - -Optional clamping value for kl to prevent numerical instability. - If None, no clamping is applied. - - -**Returns:** `torch.Tensor` - -torch.Tensor: Per-token KL penalty values (b, s) - - - - - - - - -```python -nemo_rl.algorithms.utils.get_tokenizer( - tokenizer_config: nemo_rl.models.policy.TokenizerConfig, - get_processor: bool = False -) -> transformers.PreTrainedTokenizerBase -``` - - - - - - -Get the tokenizer and set pad token to eos token if it is not already set. - -This function initializes a tokenizer from the Hugging Face transformers library -and configures it with appropriate chat templates and padding tokens. - -**Parameters:** - - -A dictionary containing tokenizer configuration. -Required keys: - - name: The name or path of the pretrained tokenizer -Optional keys: - - chat_template: The chat template to use. Can be: - - None: Uses a passthrough template that just returns message content - - "default": Uses the tokenizer's default template - - A custom jinja2 template string - If not specified, the tokenizer's default template will be used. - - - -Whether to return a processor (via AutoProcessor) instead of a tokenizer. - - -**Returns:** `PreTrainedTokenizerBase` - -The configured tokenizer instance - -**Examples:** - - - -```python ->>> from transformers import AutoTokenizer ->>> from nemo_rl.algorithms.utils import get_tokenizer ->>> # not specifying a chat template uses the tokenizer's default ->>> config = {"name": "meta-llama/Llama-3.2-1B-Instruct"} ->>> tokenizer = get_tokenizer(config) -No chat template provided, using tokenizer's default ->>> messages = [ -... {"role": "system", "content": "You are a helpful AI assistant."}, -... {"role": "user", "content": "Hello!"} -... ] ->>> formatted = tokenizer.apply_chat_template(messages, tokenize=False) ->>> assert formatted == AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct").apply_chat_template(messages, tokenize=False) - ->>> # Using a passthrough template ->>> config = { -... "name": "meta-llama/Llama-3.2-1B-Instruct", -... "chat_template": None -... } ->>> tokenizer = get_tokenizer(config) -Using passthrough chat template ->>> formatted = tokenizer.apply_chat_template(messages, tokenize=False) ->>> assert formatted == "".join(msg["content"] for msg in messages) - ->>> # Using a custom template ->>> config = { -... "name": "meta-llama/Llama-3.2-1B-Instruct", -... "chat_template": "{% for message in messages %}{{ ' START: ' + message['content'] + ' END.' }}{% endfor %}" -... } ->>> tokenizer = get_tokenizer(config) -Using custom chat template ->>> formatted = tokenizer.apply_chat_template(messages, tokenize=False) ->>> assert formatted == " START: You are a helpful AI assistant. END. START: Hello! END." - ->>> # Requesting a processor (for multimodal models like Qwen-VL) ->>> config = {"name": "Qwen/Qwen2.5-VL-3B-Instruct"} ->>> processor = get_tokenizer(config, get_processor=True) -No chat template provided, using tokenizer's default ->>> messages = [ -... {"role": "system", "content": "You are a helpful AI assistant."}, -... {"role": "user", "content": "Hello!"} -... ] ->>> formatted = processor.tokenizer.apply_chat_template(messages, tokenize=False) ->>> assert formatted == AutoTokenizer.from_pretrained( -... "Qwen/Qwen2.5-VL-3B-Instruct", trust_remote_code=True -... ).apply_chat_template(messages, tokenize=False) ->>> assert processor.pad_token_id == processor.tokenizer.pad_token_id ->>> -``` - - - - - - - - - - -```python -nemo_rl.algorithms.utils.log_generation_metrics_to_wandb( - generation_logger_metrics: dict[str, dict[int, list[typing.Any]]], - step: int, - timeline_interval: float, - logger: nemo_rl.utils.logger.Logger -) -> None -``` - - - - - - -Log generation metrics to wandb. - -**Parameters:** - - -Dictionary of generation logger metrics - - - -Global step value - - - -Interval between timeline points (in seconds) - - - -Logger instance - - - - - - - - - -```python -nemo_rl.algorithms.utils.masked_mean( - values: torch.Tensor, - mask: torch.Tensor, - dim: typing.Optional[int] = None, - global_normalization_factor: typing.Optional[torch.Tensor | float] = None -) -``` - - - - - - -Computes the mean of a microbatch, using a global statistic as the normalization factor. - - - - - - - - -```python -nemo_rl.algorithms.utils.maybe_pad_last_batch( - batch: dict, - dp_size: int, - mbs: int -) -> dict -``` - - - - - - -Pads the given batch so that its size is divisible by (mbs * dp_size). - -**Parameters:** - - -The batch to pad. - - - -Data parallel size. - - - -Micro batch size. - - -**Returns:** `dict` - -The padded batch. - - - - - - - - -```python -nemo_rl.algorithms.utils.print_performance_metrics( - train_results: dict[str, float], - metrics: dict[str, typing.Any], - timing_metrics: dict[str, float], - master_config: dict -) -> dict[str, float] -``` - - - - - - -Print performance metrics for GRPO. - - - - - - - - -```python -nemo_rl.algorithms.utils.set_seed( - seed: int -) -> None -``` - - - - - - -Sets the seed for python, numpy, and pytorch. - - - - - - - - -```python -nemo_rl.algorithms.utils.surpress_user_warnings( - f -) -``` - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data.mdx deleted file mode 100644 index 3cafa95..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data.mdx +++ /dev/null @@ -1,466 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data -title: nemo_rl.data ---- - -## Subpackages - -- **[`nemo_rl.data.datasets`](/nemo-rl/nemo_rl/data/datasets)** -- **[`nemo_rl.data.packing`](/nemo-rl/nemo_rl/data/packing)** - -## Submodules - -- **[`nemo_rl.data.chat_templates`](/nemo-rl/nemo_rl/data/chat_templates)** -- **[`nemo_rl.data.collate_fn`](/nemo-rl/nemo_rl/data/collate_fn)** -- **[`nemo_rl.data.interfaces`](/nemo-rl/nemo_rl/data/interfaces)** -- **[`nemo_rl.data.llm_message_utils`](/nemo-rl/nemo_rl/data/llm_message_utils)** -- **[`nemo_rl.data.multimodal_utils`](/nemo-rl/nemo_rl/data/multimodal_utils)** -- **[`nemo_rl.data.processors`](/nemo-rl/nemo_rl/data/processors)** -- **[`nemo_rl.data.utils`](/nemo-rl/nemo_rl/data/utils)** - -## Package Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`AIMEEvalDataConfig`](#nemo_rl-data-AIMEEvalDataConfig) | Config for AIME datasets. | -| [`DataConfig`](#nemo_rl-data-DataConfig) | - | -| [`GPQAEvalDataConfig`](#nemo_rl-data-GPQAEvalDataConfig) | Config for GPQA datasets. | -| [`LocalMathEvalDataConfig`](#nemo_rl-data-LocalMathEvalDataConfig) | Config for local math datasets loaded from files. | -| [`MMLUEvalDataConfig`](#nemo_rl-data-MMLUEvalDataConfig) | Config for MMLU and multilingual MMLU datasets. | -| [`MMLUProEvalDataConfig`](#nemo_rl-data-MMLUProEvalDataConfig) | Config for MMLU Pro dataset. | -| [`MathEvalDataConfig`](#nemo_rl-data-MathEvalDataConfig) | Config for Math datasets. | -| [`PreferenceDatasetConfig`](#nemo_rl-data-PreferenceDatasetConfig) | - | -| [`ResponseDatasetConfig`](#nemo_rl-data-ResponseDatasetConfig) | - | - -### Data - -[`EvalDataConfigType`](#nemo_rl-data-EvalDataConfigType) - -### API - - - - - -```python -class nemo_rl.data.AIMEEvalDataConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - -Config for AIME datasets. - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.data.DataConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.data.GPQAEvalDataConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - -Config for GPQA datasets. - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.data.LocalMathEvalDataConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - -Config for local math datasets loaded from files. - -dataset_name can be a URL or local file path. -Requires additional fields: problem_key, solution_key, file_format, split. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.data.MMLUEvalDataConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - -Config for MMLU and multilingual MMLU datasets. - -Supports dataset_name: "mmlu" or "mmlu_{language}" where language is one of: -AR-XY, BN-BD, DE-DE, EN-US, ES-LA, FR-FR, HI-IN, ID-ID, IT-IT, JA-JP, -KO-KR, PT-BR, ZH-CN, SW-KE, YO-NG - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.data.MMLUProEvalDataConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - -Config for MMLU Pro dataset. - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.data.MathEvalDataConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - -Config for Math datasets. - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.data.PreferenceDatasetConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.data.ResponseDatasetConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.data.EvalDataConfigType = MMLUEvalDataConfig | MMLUProEvalDataConfig | AIMEEvalDataConfig | GPQAEvalDataCo... -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/chat_templates.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/chat_templates.mdx deleted file mode 100644 index 11e5f15..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/chat_templates.mdx +++ /dev/null @@ -1,35 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/chat_templates -title: nemo_rl.data.chat_templates ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`COMMON_CHAT_TEMPLATES`](#nemo_rl-data-chat_templates-COMMON_CHAT_TEMPLATES) | - | - -### API - - - - - -```python -class nemo_rl.data.chat_templates.COMMON_CHAT_TEMPLATES() -``` - - - - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/collate_fn.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/collate_fn.mdx deleted file mode 100644 index 56b6bb7..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/collate_fn.mdx +++ /dev/null @@ -1,166 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/collate_fn -title: nemo_rl.data.collate_fn ---- - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`eval_collate_fn`](#nemo_rl-data-collate_fn-eval_collate_fn) | Collate function for evaluation. | -| [`preference_collate_fn`](#nemo_rl-data-collate_fn-preference_collate_fn) | Collate function for preference data training. | -| [`rl_collate_fn`](#nemo_rl-data-collate_fn-rl_collate_fn) | Collate function for RL training. | - -### Data - -[`TokenizerType`](#nemo_rl-data-collate_fn-TokenizerType) - -### API - - - - - -```python -nemo_rl.data.collate_fn.eval_collate_fn( - data_batch: list[nemo_rl.data.interfaces.DatumSpec] -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] -``` - - - - - - -Collate function for evaluation. - -Takes a list of data samples and combines them into a single batched dictionary -for model evaluation. - -Examples: - - -```python ->>> import torch ->>> from nemo_rl.data.collate_fn import eval_collate_fn ->>> from nemo_rl.data.interfaces import DatumSpec ->>> data_batch = [ -... DatumSpec( -... message_log=[{"role": "user", "content": "Hello", "token_ids": torch.tensor([1, 2, 3])}], -... extra_env_info={'ground_truth': '1'}, -... idx=0, -... ), -... DatumSpec( -... message_log=[{"role": "assistant", "content": "Hi there", "token_ids": torch.tensor([4, 5, 6, 7])}], -... extra_env_info={'ground_truth': '2'}, -... idx=1, -... ), -... ] ->>> output = eval_collate_fn(data_batch) ->>> output['message_log'][0] -[{'role': 'user', 'content': 'Hello', 'token_ids': tensor([1, 2, 3])}] ->>> output['message_log'][1] -[{'role': 'assistant', 'content': 'Hi there', 'token_ids': tensor([4, 5, 6, 7])}] ->>> output['extra_env_info'] -[{'ground_truth': '1'}, {'ground_truth': '2'}] ->>> output['idx'] -[0, 1] -``` - - - -**Parameters:** - - -List of data samples with message_log, extra_env_info, and idx fields. - - -**Returns:** `BatchedDataDict[Any]` - -BatchedDataDict with message_log, extra_env_info, and idx fields. - - - - - - - - -```python -nemo_rl.data.collate_fn.preference_collate_fn( - data_batch: list[nemo_rl.data.interfaces.PreferenceDatumSpec], - tokenizer: nemo_rl.data.collate_fn.TokenizerType, - make_sequence_length_divisible_by: int, - add_loss_mask: bool -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] -``` - - - - - - -Collate function for preference data training. - -This function separates the chosen and rejected responses to create -two examples per prompt. The chosen and rejected examples are interleaved -along the batch dimension, resulting in a batch size of 2 * len(data_batch). - -Returns: - BatchedDataDict with input_ids, input_lengths, token_mask (optional), and sample_mask fields. - -**Parameters:** - - -List of data samples with message_log_chosen, message_log_rejected, length_chosen, length_rejected, loss_multiplier, idx, and task_name fields. - - - -Tokenizer for text processing - - - -Make the sequence length divisible by this value - - - -Whether to add a token_mask to the returned data - - - - - - - - - -```python -nemo_rl.data.collate_fn.rl_collate_fn( - data_batch: list[nemo_rl.data.interfaces.DatumSpec] -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] -``` - - - - - - -Collate function for RL training. - - - - - - - - -```python -nemo_rl.data.collate_fn.TokenizerType = Union[PreTrainedTokenizerBase, AutoProcessor] -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets.mdx deleted file mode 100644 index 88450e5..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets.mdx +++ /dev/null @@ -1,37 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets -title: nemo_rl.data.datasets ---- - -## Subpackages - -- **[`nemo_rl.data.datasets.eval_datasets`](/nemo-rl/nemo_rl/data/datasets/eval_datasets)** -- **[`nemo_rl.data.datasets.preference_datasets`](/nemo-rl/nemo_rl/data/datasets/preference_datasets)** -- **[`nemo_rl.data.datasets.response_datasets`](/nemo-rl/nemo_rl/data/datasets/response_datasets)** - -## Submodules - -- **[`nemo_rl.data.datasets.processed_dataset`](/nemo-rl/nemo_rl/data/datasets/processed_dataset)** -- **[`nemo_rl.data.datasets.raw_dataset`](/nemo-rl/nemo_rl/data/datasets/raw_dataset)** -- **[`nemo_rl.data.datasets.utils`](/nemo-rl/nemo_rl/data/datasets/utils)** - -## Package Contents - -### Data - -[`__all__`](#nemo_rl-data-datasets-__all__) - -### API - - - - - -```python -nemo_rl.data.datasets.__all__ = ['AllTaskProcessedDataset', 'load_eval_dataset', 'load_preference_dataset', 'loa... -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets.mdx deleted file mode 100644 index 433590f..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets.mdx +++ /dev/null @@ -1,60 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/eval_datasets -title: nemo_rl.data.datasets.eval_datasets ---- - -## Submodules - -- **[`nemo_rl.data.datasets.eval_datasets.aime`](/nemo-rl/nemo_rl/data/datasets/eval_datasets/aime)** -- **[`nemo_rl.data.datasets.eval_datasets.gpqa`](/nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa)** -- **[`nemo_rl.data.datasets.eval_datasets.local_math_dataset`](/nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset)** -- **[`nemo_rl.data.datasets.eval_datasets.math`](/nemo-rl/nemo_rl/data/datasets/eval_datasets/math)** -- **[`nemo_rl.data.datasets.eval_datasets.mmlu`](/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu)** -- **[`nemo_rl.data.datasets.eval_datasets.mmlu_pro`](/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro)** - -## Package Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`load_eval_dataset`](#nemo_rl-data-datasets-eval_datasets-load_eval_dataset) | Loads evaluation dataset. | - -### Data - -[`__all__`](#nemo_rl-data-datasets-eval_datasets-__all__) - -### API - - - - - -```python -nemo_rl.data.datasets.eval_datasets.load_eval_dataset( - data_config -) -``` - - - - - - -Loads evaluation dataset. - - - - - - - - -```python -nemo_rl.data.datasets.eval_datasets.__all__ = ['AIMEDataset', 'GPQADataset', 'LocalMathDataset', 'MathDataset', 'MMLUDataset',... -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/aime.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/aime.mdx deleted file mode 100644 index 155c936..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/aime.mdx +++ /dev/null @@ -1,64 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/aime -title: nemo_rl.data.datasets.eval_datasets.aime ---- - -AIME dataset. - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`AIMEDataset`](#nemo_rl-data-datasets-eval_datasets-aime-AIMEDataset) | - | - -### API - - - - - -```python -class nemo_rl.data.datasets.eval_datasets.aime.AIMEDataset( - variant: typing.Literal['2024', '2025'] = '2025', - prompt_file: typing.Optional[str] = None, - system_prompt_file: typing.Optional[str] = None -) -``` - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.data.datasets.eval_datasets.aime.AIMEDataset._rekey( - data: dict[str, typing.Any] -) -``` - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa.mdx deleted file mode 100644 index d1ca3a9..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa.mdx +++ /dev/null @@ -1,64 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa -title: nemo_rl.data.datasets.eval_datasets.gpqa ---- - -GPQA dataset and its variants. - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`GPQADataset`](#nemo_rl-data-datasets-eval_datasets-gpqa-GPQADataset) | - | - -### API - - - - - -```python -class nemo_rl.data.datasets.eval_datasets.gpqa.GPQADataset( - variant: typing.Literal['diamond', 'main'] = 'diamond', - prompt_file: typing.Optional[str] = None, - system_prompt_file: typing.Optional[str] = None -) -``` - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.data.datasets.eval_datasets.gpqa.GPQADataset._rekey( - data: dict[str, typing.Any] -) -``` - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset.mdx deleted file mode 100644 index e6d6754..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset.mdx +++ /dev/null @@ -1,65 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset -title: nemo_rl.data.datasets.eval_datasets.local_math_dataset ---- - -Local math dataset. - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`LocalMathDataset`](#nemo_rl-data-datasets-eval_datasets-local_math_dataset-LocalMathDataset) | - | - -### API - - - - - -```python -class nemo_rl.data.datasets.eval_datasets.local_math_dataset.LocalMathDataset( - data_path: str, - problem_key: str, - solution_key: str, - split: typing.Optional[str] = None, - file_format: typing.Literal['csv', 'json'] = 'csv', - prompt_file: typing.Optional[str] = None, - system_prompt_file: typing.Optional[str] = None -) -``` - - - - - - - - - - - - - - - - - - -```python -nemo_rl.data.datasets.eval_datasets.local_math_dataset.LocalMathDataset._rekey( - data: dict[str, typing.Any] -) -``` - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/math.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/math.mdx deleted file mode 100644 index c00f375..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/math.mdx +++ /dev/null @@ -1,61 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/math -title: nemo_rl.data.datasets.eval_datasets.math ---- - -Math dataset and its variants. - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`MathDataset`](#nemo_rl-data-datasets-eval_datasets-math-MathDataset) | - | - -### API - - - - - -```python -class nemo_rl.data.datasets.eval_datasets.math.MathDataset( - variant: typing.Literal['math_test', 'math_500_test'] = 'math_test', - prompt_file: typing.Optional[str] = None, - system_prompt_file: typing.Optional[str] = None -) -``` - - - - - - - - - - - - - - - - - - -```python -nemo_rl.data.datasets.eval_datasets.math.MathDataset._rekey( - data: dict[str, typing.Any] -) -``` - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu.mdx deleted file mode 100644 index 1114133..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu.mdx +++ /dev/null @@ -1,61 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu -title: nemo_rl.data.datasets.eval_datasets.mmlu ---- - -MMLU dataset and its variants. - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`MMLUDataset`](#nemo_rl-data-datasets-eval_datasets-mmlu-MMLUDataset) | - | - -### API - - - - - -```python -class nemo_rl.data.datasets.eval_datasets.mmlu.MMLUDataset( - language: typing.Literal['AR-XY', 'BN-BD', 'DE-DE', 'EN-US', 'ES-LA', 'FR-FR', 'HI-IN', 'ID-ID', 'IT-IT', 'JA-JP', 'KO-KR', 'PT-BR', 'ZH-CN', 'SW-KE', 'YO-NG'] = 'EN-US', - prompt_file: typing.Optional[str] = None, - system_prompt_file: typing.Optional[str] = None -) -``` - - - - - - - - - - - - - - - - - - -```python -nemo_rl.data.datasets.eval_datasets.mmlu.MMLUDataset._rekey( - data: dict[str, typing.Any] -) -``` - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro.mdx deleted file mode 100644 index 998a593..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro.mdx +++ /dev/null @@ -1,60 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro -title: nemo_rl.data.datasets.eval_datasets.mmlu_pro ---- - -MMLU-Pro dataset. - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`MMLUProDataset`](#nemo_rl-data-datasets-eval_datasets-mmlu_pro-MMLUProDataset) | - | - -### API - - - - - -```python -class nemo_rl.data.datasets.eval_datasets.mmlu_pro.MMLUProDataset( - prompt_file: str, - system_prompt_file: typing.Optional[str] = None -) -``` - - - - - - - - - - - - - - - - - - -```python -nemo_rl.data.datasets.eval_datasets.mmlu_pro.MMLUProDataset._rekey( - data: dict[str, typing.Any] -) -``` - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets.mdx deleted file mode 100644 index 1b101aa..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets.mdx +++ /dev/null @@ -1,72 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/preference_datasets -title: nemo_rl.data.datasets.preference_datasets ---- - -## Submodules - -- **[`nemo_rl.data.datasets.preference_datasets.binary_preference_dataset`](/nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset)** -- **[`nemo_rl.data.datasets.preference_datasets.helpsteer3`](/nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3)** -- **[`nemo_rl.data.datasets.preference_datasets.preference_dataset`](/nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset)** -- **[`nemo_rl.data.datasets.preference_datasets.tulu3`](/nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3)** - -## Package Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`load_preference_dataset`](#nemo_rl-data-datasets-preference_datasets-load_preference_dataset) | Loads preference dataset. | - -### Data - -[`DATASET_REGISTRY`](#nemo_rl-data-datasets-preference_datasets-DATASET_REGISTRY) - -[`__all__`](#nemo_rl-data-datasets-preference_datasets-__all__) - -### API - - - - - -```python -nemo_rl.data.datasets.preference_datasets.load_preference_dataset( - data_config: nemo_rl.data.PreferenceDatasetConfig -) -``` - - - - - - -Loads preference dataset. - - - - - - - - -```python -nemo_rl.data.datasets.preference_datasets.DATASET_REGISTRY = {'HelpSteer3': HelpSteer3Dataset, 'Tulu3Preference': Tulu3PreferenceDataset, 'Bi... -``` - - - - - - - - - -```python -nemo_rl.data.datasets.preference_datasets.__all__ = ['BinaryPreferenceDataset', 'HelpSteer3Dataset', 'PreferenceDataset', 'Tulu3Pref... -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.mdx deleted file mode 100644 index 762ddd7..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.mdx +++ /dev/null @@ -1,102 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset -title: nemo_rl.data.datasets.preference_datasets.binary_preference_dataset ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`BinaryPreferenceDataset`](#nemo_rl-data-datasets-preference_datasets-binary_preference_dataset-BinaryPreferenceDataset) | Dataset class for binary preference data which can be loaded from a JSON file. | - -### API - - - - - -```python -class nemo_rl.data.datasets.preference_datasets.binary_preference_dataset.BinaryPreferenceDataset( - data_path: str, - prompt_key: str = 'prompt', - chosen_key: str = 'chosen', - rejected_key: str = 'rejected', - subset: typing.Optional[str] = None, - split: typing.Optional[str] = None, - kwargs = {} -) -``` - - - - - - -**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) - -Dataset class for binary preference data which can be loaded from a JSON file. - -This class handles loading of preference data for DPO and RM training. -It will be converted to the format of PreferenceDataset through the `to_preference_data_format` function. - -The input JSONL files should contain valid JSON objects formatted like this: -{ - prompt_key: str, # The input prompt/context - chosen_key: str, # The preferred/winning response - rejected_key: str, # The non-preferred/losing response -} -Please refer to https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/dpo.md#datasets for more details. - -**Parameters:** - - -Path to the dataset JSON file - - - -Key for the input prompt/context, default is "prompt" - - - -Key for the preferred/winning response, default is "chosen" - - - -Key for the non-preferred/losing response, default is "rejected" - - - -Optional subset name for the dataset, used for HuggingFace datasets - - - -Optional split name for the dataset, used for HuggingFace datasets - - - - - - - - - - - - -```python -nemo_rl.data.datasets.preference_datasets.binary_preference_dataset.BinaryPreferenceDataset.format_data( - data: dict[str, typing.Any] -) -> dict[str, typing.Any] -``` - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3.mdx deleted file mode 100644 index a88c29e..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3.mdx +++ /dev/null @@ -1,66 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3 -title: nemo_rl.data.datasets.preference_datasets.helpsteer3 ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`HelpSteer3Dataset`](#nemo_rl-data-datasets-preference_datasets-helpsteer3-HelpSteer3Dataset) | HelpSteer3 preference dataset for DPO training. | - -### API - - - - - -```python -class nemo_rl.data.datasets.preference_datasets.helpsteer3.HelpSteer3Dataset( - split: str = 'train', - kwargs = {} -) -``` - - - - - - -**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) - -HelpSteer3 preference dataset for DPO training. - -**Parameters:** - - -Split name for the dataset, default is "train" - - - - - - - - - - - - -```python -nemo_rl.data.datasets.preference_datasets.helpsteer3.HelpSteer3Dataset.format_data( - data: dict[str, typing.Any] -) -> dict[str, typing.Any] -``` - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset.mdx deleted file mode 100644 index 0264eec..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset.mdx +++ /dev/null @@ -1,77 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset -title: nemo_rl.data.datasets.preference_datasets.preference_dataset ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`PreferenceDataset`](#nemo_rl-data-datasets-preference_datasets-preference_dataset-PreferenceDataset) | Dataset class for preference data which can be loaded from a JSON file. | - -### API - - - - - -```python -class nemo_rl.data.datasets.preference_datasets.preference_dataset.PreferenceDataset( - data_path: str, - subset: typing.Optional[str] = None, - split: typing.Optional[str] = None, - kwargs = {} -) -``` - - - - - - -**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) - -Dataset class for preference data which can be loaded from a JSON file. - -This class handles loading of preference data for DPO and RM training. -The input JSONL files should contain valid JSON objects formatted like this: -{ - "context": list[dict], # The prompt message (including previous turns, if any) - "completions": [ # The list of completions - { - "rank": 0, # The rank of the completion (lower rank is preferred) - "completion": list[dict], # The completion message(s) - }, - { - "rank": 1, # The rank of the completion (lower rank is preferred) - "completion": list[dict], # The completion message(s) - }, - ... # More completions can be added if needed - ] -} -Please refer to https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/dpo.md#datasets for more details. - -**Parameters:** - - -Path to the dataset JSON file - - - -Optional subset name for the dataset, used for HuggingFace datasets - - - -Optional split name for the dataset, used for HuggingFace datasets - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3.mdx deleted file mode 100644 index 0a7c89c..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3.mdx +++ /dev/null @@ -1,59 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3 -title: nemo_rl.data.datasets.preference_datasets.tulu3 ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`Tulu3PreferenceDataset`](#nemo_rl-data-datasets-preference_datasets-tulu3-Tulu3PreferenceDataset) | Tulu3 preference dataset for DPO training. | - -### API - - - - - -```python -class nemo_rl.data.datasets.preference_datasets.tulu3.Tulu3PreferenceDataset( - kwargs = {} -) -``` - - - - - - -**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) - -Tulu3 preference dataset for DPO training. - - - - - - - - - - - -```python -nemo_rl.data.datasets.preference_datasets.tulu3.Tulu3PreferenceDataset.format_data( - data: dict[str, typing.Any] -) -> dict[str, typing.Any] -``` - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/processed_dataset.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/processed_dataset.mdx deleted file mode 100644 index 130991c..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/processed_dataset.mdx +++ /dev/null @@ -1,135 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/processed_dataset -title: nemo_rl.data.datasets.processed_dataset ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`AllTaskProcessedDataset`](#nemo_rl-data-datasets-processed_dataset-AllTaskProcessedDataset) | Dataset for processing single or multi-task data with task-specific tokenization and processing. | - -### Data - -[`TokenizerType`](#nemo_rl-data-datasets-processed_dataset-TokenizerType) - -### API - - - - - -```python -class nemo_rl.data.datasets.processed_dataset.AllTaskProcessedDataset( - dataset: datasets.Dataset | typing.Any, - tokenizer: nemo_rl.data.datasets.processed_dataset.TokenizerType, - default_task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, - task_data_processors: dict[str, tuple[nemo_rl.data.interfaces.TaskDataSpec, nemo_rl.data.interfaces.TaskDataProcessFnCallable]] | nemo_rl.data.interfaces.TaskDataProcessFnCallable, - max_seq_length: typing.Optional[int] = None -) -``` - - - - - - -Dataset for processing single or multi-task data with task-specific tokenization and processing. - -**Parameters:** - - -Input dataset containing raw data - - - -Tokenizer for text processing - - - -Default task processing specifications. -In the case of single-task, this is the spec used for processing all entries. -In the case of multi-task, any values not specified in the task-specific specs will be taken from the default spec. - - - -Either a single TaskDataProcessFnCallable for single-task, -or a dict mapping task names to (TaskDataSpec, TaskDataProcessFnCallable) for multi-task - - - -Maximum sequence length for tokenized outputs - - - - - - - -```python -nemo_rl.data.datasets.processed_dataset.AllTaskProcessedDataset.__getitem__( - idx: int -) -> nemo_rl.data.interfaces.DatumSpec -``` - - - - - - -Return a single prompt. - - - - - - - -```python -nemo_rl.data.datasets.processed_dataset.AllTaskProcessedDataset.__len__() -> int -``` - - - - - - - - - - - - -```python -nemo_rl.data.datasets.processed_dataset.AllTaskProcessedDataset.encode_single( - text: typing.Union[str, list[str]] -) -> tuple[list[int] | torch.Tensor, int] -``` - - - - - - -Takes either a single string or a list of strings that represent multiple turns for the same conversation. - -Returns a single (concatenated) list of tokenized ids and the length of the tokenized ids. - - - - - - - - - -```python -nemo_rl.data.datasets.processed_dataset.TokenizerType = Union[PreTrainedTokenizerBase, AutoProcessor] -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/raw_dataset.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/raw_dataset.mdx deleted file mode 100644 index af7a37b..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/raw_dataset.mdx +++ /dev/null @@ -1,94 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/raw_dataset -title: nemo_rl.data.datasets.raw_dataset ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`RawDataset`](#nemo_rl-data-datasets-raw_dataset-RawDataset) | - | - -### API - - - - - -```python -class nemo_rl.data.datasets.raw_dataset.RawDataset() -``` - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.data.datasets.raw_dataset.RawDataset.set_processor() -``` - - - - - - - - - - - - -```python -nemo_rl.data.datasets.raw_dataset.RawDataset.set_task_spec( - data_config: nemo_rl.data.ResponseDatasetConfig | nemo_rl.data.PreferenceDatasetConfig -) -``` - - - - - - - - - - - - -```python -nemo_rl.data.datasets.raw_dataset.RawDataset.split_train_validation( - test_size: float, - seed: int -) -``` - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets.mdx deleted file mode 100644 index 3892488..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets.mdx +++ /dev/null @@ -1,82 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/response_datasets -title: nemo_rl.data.datasets.response_datasets ---- - -## Submodules - -- **[`nemo_rl.data.datasets.response_datasets.aime24`](/nemo-rl/nemo_rl/data/datasets/response_datasets/aime24)** -- **[`nemo_rl.data.datasets.response_datasets.clevr`](/nemo-rl/nemo_rl/data/datasets/response_datasets/clevr)** -- **[`nemo_rl.data.datasets.response_datasets.dapo_math`](/nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math)** -- **[`nemo_rl.data.datasets.response_datasets.deepscaler`](/nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler)** -- **[`nemo_rl.data.datasets.response_datasets.geometry3k`](/nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k)** -- **[`nemo_rl.data.datasets.response_datasets.helpsteer3`](/nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3)** -- **[`nemo_rl.data.datasets.response_datasets.nemogym_dataset`](/nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset)** -- **[`nemo_rl.data.datasets.response_datasets.oai_format_dataset`](/nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset)** -- **[`nemo_rl.data.datasets.response_datasets.oasst`](/nemo-rl/nemo_rl/data/datasets/response_datasets/oasst)** -- **[`nemo_rl.data.datasets.response_datasets.openmathinstruct2`](/nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2)** -- **[`nemo_rl.data.datasets.response_datasets.refcoco`](/nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco)** -- **[`nemo_rl.data.datasets.response_datasets.response_dataset`](/nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset)** -- **[`nemo_rl.data.datasets.response_datasets.squad`](/nemo-rl/nemo_rl/data/datasets/response_datasets/squad)** -- **[`nemo_rl.data.datasets.response_datasets.tulu3`](/nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3)** - -## Package Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`load_response_dataset`](#nemo_rl-data-datasets-response_datasets-load_response_dataset) | Loads response dataset. | - -### Data - -[`DATASET_REGISTRY`](#nemo_rl-data-datasets-response_datasets-DATASET_REGISTRY) - -[`__all__`](#nemo_rl-data-datasets-response_datasets-__all__) - -### API - - - - - -```python -nemo_rl.data.datasets.response_datasets.load_response_dataset( - data_config: nemo_rl.data.ResponseDatasetConfig -) -``` - - - - - - -Loads response dataset. - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.DATASET_REGISTRY = {'AIME2024': AIME2024Dataset, 'clevr-cogent': CLEVRCoGenTDataset, 'DAPOMath17K':... -``` - - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.__all__ = ['AIME2024Dataset', 'CLEVRCoGenTDataset', 'DAPOMath17KDataset', 'DAPOMathAIME202... -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/aime24.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/aime24.mdx deleted file mode 100644 index 334fee4..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/aime24.mdx +++ /dev/null @@ -1,66 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/response_datasets/aime24 -title: nemo_rl.data.datasets.response_datasets.aime24 ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`AIME2024Dataset`](#nemo_rl-data-datasets-response_datasets-aime24-AIME2024Dataset) | Simple wrapper around the AIME2024 dataset with train split. | - -### API - - - - - -```python -class nemo_rl.data.datasets.response_datasets.aime24.AIME2024Dataset( - repeat: int = 16, - kwargs = {} -) -``` - - - - - - -**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) - -Simple wrapper around the AIME2024 dataset with train split. - -**Parameters:** - - -Number of times to repeat the dataset, default is 16 - - - - - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.aime24.AIME2024Dataset.format_data( - data: dict[str, typing.Any] -) -> dict[str, typing.Any] -``` - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/clevr.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/clevr.mdx deleted file mode 100644 index 2bf7236..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/clevr.mdx +++ /dev/null @@ -1,97 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/response_datasets/clevr -title: nemo_rl.data.datasets.response_datasets.clevr ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`CLEVRCoGenTDataset`](#nemo_rl-data-datasets-response_datasets-clevr-CLEVRCoGenTDataset) | Simple wrapper around the CLEVR-CoGenT dataset. | - -### Functions - -| Name | Description | -|------|-------------| -| [`format_answer_fromtags`](#nemo_rl-data-datasets-response_datasets-clevr-format_answer_fromtags) | Extract content between <answer> tags and strip whitespace. | -| [`format_clevr_cogent_dataset`](#nemo_rl-data-datasets-response_datasets-clevr-format_clevr_cogent_dataset) | Format the CLEVR-CoGenT dataset into an OpenAI-API-like message log. | - -### API - - - - - -```python -class nemo_rl.data.datasets.response_datasets.clevr.CLEVRCoGenTDataset( - split: str = 'train', - kwargs = {} -) -``` - - - - - - -**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) - -Simple wrapper around the CLEVR-CoGenT dataset. - -**Parameters:** - - -Split name for the dataset, default is "train" - - - - - - - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.clevr.format_answer_fromtags( - answer: str -) -> str -``` - - - - - - -Extract content between <answer> tags and strip whitespace. - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.clevr.format_clevr_cogent_dataset( - example: dict[str, typing.Any], - return_pil: bool = False -) -> dict[str, typing.Any] -``` - - - - - - -Format the CLEVR-CoGenT dataset into an OpenAI-API-like message log. - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math.mdx deleted file mode 100644 index 6866067..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math.mdx +++ /dev/null @@ -1,84 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math -title: nemo_rl.data.datasets.response_datasets.dapo_math ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`DAPOMath17KDataset`](#nemo_rl-data-datasets-response_datasets-dapo_math-DAPOMath17KDataset) | Simple wrapper around the DAPO Math 17K dataset with train split. | -| [`DAPOMathAIME2024Dataset`](#nemo_rl-data-datasets-response_datasets-dapo_math-DAPOMathAIME2024Dataset) | - | - -### API - - - - - -```python -class nemo_rl.data.datasets.response_datasets.dapo_math.DAPOMath17KDataset( - kwargs = {} -) -``` - - - - - - -**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) - -Simple wrapper around the DAPO Math 17K dataset with train split. - - - - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.dapo_math.DAPOMath17KDataset.format_data( - data: dict[str, typing.Any] -) -> dict[str, typing.Any] -``` - - - - - - - - - - - - - - -```python -class nemo_rl.data.datasets.response_datasets.dapo_math.DAPOMathAIME2024Dataset( - kwargs = {} -) -``` - - - - - - -**Bases:** [DAPOMath17KDataset](#nemo_rl-data-datasets-response_datasets-dapo_math-DAPOMath17KDataset) - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler.mdx deleted file mode 100644 index e1a6e7d..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler.mdx +++ /dev/null @@ -1,59 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler -title: nemo_rl.data.datasets.response_datasets.deepscaler ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`DeepScalerDataset`](#nemo_rl-data-datasets-response_datasets-deepscaler-DeepScalerDataset) | Simple wrapper around the DeepScaler dataset with train split. | - -### API - - - - - -```python -class nemo_rl.data.datasets.response_datasets.deepscaler.DeepScalerDataset( - kwargs = {} -) -``` - - - - - - -**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) - -Simple wrapper around the DeepScaler dataset with train split. - - - - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.deepscaler.DeepScalerDataset.format_data( - data: dict[str, typing.Any] -) -> dict[str, typing.Any] -``` - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k.mdx deleted file mode 100644 index a4be5a2..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k.mdx +++ /dev/null @@ -1,76 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k -title: nemo_rl.data.datasets.response_datasets.geometry3k ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`Geometry3KDataset`](#nemo_rl-data-datasets-response_datasets-geometry3k-Geometry3KDataset) | Simple wrapper around the Geometry3K dataset. | - -### Functions - -| Name | Description | -|------|-------------| -| [`format_geometry3k_dataset`](#nemo_rl-data-datasets-response_datasets-geometry3k-format_geometry3k_dataset) | Format the Geometry3K dataset into an OpenAI-API-like message log. | - -### API - - - - - -```python -class nemo_rl.data.datasets.response_datasets.geometry3k.Geometry3KDataset( - split: str = 'train', - kwargs = {} -) -``` - - - - - - -**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) - -Simple wrapper around the Geometry3K dataset. - -**Parameters:** - - -Split name for the dataset, default is "train" - - - - - - - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.geometry3k.format_geometry3k_dataset( - example: dict[str, typing.Any], - return_pil: bool = False -) -> dict[str, typing.Any] -``` - - - - - - -Format the Geometry3K dataset into an OpenAI-API-like message log. - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3.mdx deleted file mode 100644 index 7176bf4..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3.mdx +++ /dev/null @@ -1,66 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3 -title: nemo_rl.data.datasets.response_datasets.helpsteer3 ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`HelpSteer3Dataset`](#nemo_rl-data-datasets-response_datasets-helpsteer3-HelpSteer3Dataset) | Simple wrapper around the HelpSteer3 dataset with preference subset. | - -### API - - - - - -```python -class nemo_rl.data.datasets.response_datasets.helpsteer3.HelpSteer3Dataset( - split: str = 'train', - kwargs = {} -) -``` - - - - - - -**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) - -Simple wrapper around the HelpSteer3 dataset with preference subset. - -**Parameters:** - - -Split name for the dataset, default is "train" - - - - - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.helpsteer3.HelpSteer3Dataset.format_data( - data: dict[str, typing.Any] -) -> dict[str, typing.Any] -``` - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset.mdx deleted file mode 100644 index 54915b0..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset.mdx +++ /dev/null @@ -1,54 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset -title: nemo_rl.data.datasets.response_datasets.nemogym_dataset ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`NemoGymDataset`](#nemo_rl-data-datasets-response_datasets-nemogym_dataset-NemoGymDataset) | Simple wrapper around the Nemo Gym dataset. | - -### API - - - - - -```python -class nemo_rl.data.datasets.response_datasets.nemogym_dataset.NemoGymDataset( - data_path: str, - repeat: int = 1, - kwargs = {} -) -``` - - - - - - -**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) - -Simple wrapper around the Nemo Gym dataset. - -**Parameters:** - - -Path to the dataset JSONL file - - - -Number of times to repeat the dataset, default is 1 - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset.mdx deleted file mode 100644 index f1c75e2..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset.mdx +++ /dev/null @@ -1,214 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset -title: nemo_rl.data.datasets.response_datasets.oai_format_dataset ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`OpenAIFormatDataset`](#nemo_rl-data-datasets-response_datasets-oai_format_dataset-OpenAIFormatDataset) | This class is used to load an SFT dataset in the OpenAI format. | -| [`PreservingDataset`](#nemo_rl-data-datasets-response_datasets-oai_format_dataset-PreservingDataset) | A dataset wrapper that preserves original dict structure without None-filling. | - -### API - - - - - -```python -class nemo_rl.data.datasets.response_datasets.oai_format_dataset.OpenAIFormatDataset( - data_path: str, - chat_key: str = 'messages', - system_key: str | None = None, - system_prompt: str | None = None, - tool_key: str | None = 'tools', - use_preserving_dataset: bool = False, - kwargs = {} -) -``` - - - - - - -**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) - -This class is used to load an SFT dataset in the OpenAI format. - -The dataset should be in the following format: -{ - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is the capital of France?"}, - {"role": "assistant", "content": "The capital of France is Paris."} - ] -} -Please refer to https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#openai-format-datasets-with-tool-calling-support for more details. - -**Parameters:** - - -Path to the dataset JSON file - - - -Key for the messages list in the dataset (default: "messages") - - - -Optional key for system prompt in the dataset - - - -Optional system prompt to add if not in the dataset - - - -Key for tools in the dataset (default: "tools") - - - -If True, uses PreservingDataset to maintain -heterogeneous schemas (e.g., for tool calls with varying argument -structures). If False, uses standard HuggingFace dataset loading. -Default is False for backward compatibility. - - - - - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.oai_format_dataset.OpenAIFormatDataset.format_data( - data: dict[str, typing.Any] -) -> dict[str, typing.Any] -``` - - - - - - - - - - - - - - -```python -class nemo_rl.data.datasets.response_datasets.oai_format_dataset.PreservingDataset( - data: list[dict[str, typing.Any]] -) -``` - - - - - - -A dataset wrapper that preserves original dict structure without None-filling. - -Unlike HuggingFace's Dataset class which enforces schema uniformity across all samples -(filling missing keys with None), this class maintains the exact structure of each sample. -This is critical for heterogeneous data like tool calls where different samples may have -different argument structures. - - - - - - -```python -nemo_rl.data.datasets.response_datasets.oai_format_dataset.PreservingDataset.__getitem__( - idx: typing.Union[int, slice, list] -) -> typing.Union[dict[str, typing.Any], list[dict[str, typing.Any]]] -``` - - - - - - -Support integer indexing, slicing, and list indexing. - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.oai_format_dataset.PreservingDataset.__iter__() -``` - - - - - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.oai_format_dataset.PreservingDataset.__len__() -> int -``` - - - - - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.oai_format_dataset.PreservingDataset.map( - function: typing.Callable, - args = (), - kwargs = {} -) -> nemo_rl.data.datasets.response_datasets.oai_format_dataset.PreservingDataset -``` - - - - - - -Apply a function to each sample in the dataset. - -**Parameters:** - - -Function to apply to each sample - - - -If True, pass index as second argument to function - - -**Returns:** `PreservingDataset` - -New PreservingDataset with transformed samples - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oasst.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oasst.mdx deleted file mode 100644 index 8a92408..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oasst.mdx +++ /dev/null @@ -1,127 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/response_datasets/oasst -title: nemo_rl.data.datasets.response_datasets.oasst ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`OasstDataset`](#nemo_rl-data-datasets-response_datasets-oasst-OasstDataset) | Simple wrapper around the OASST dataset. | - -### Functions - -| Name | Description | -|------|-------------| -| [`get_data_records`](#nemo_rl-data-datasets-response_datasets-oasst-get_data_records) | - | -| [`parse_conversations`](#nemo_rl-data-datasets-response_datasets-oasst-parse_conversations) | Recusive function that returns all the sub converstaions in a list starting from node tree_obj. | - -### Data - -[`SYSTEM_PROMPT`](#nemo_rl-data-datasets-response_datasets-oasst-SYSTEM_PROMPT) - -### API - - - - - -```python -class nemo_rl.data.datasets.response_datasets.oasst.OasstDataset( - split_validation_size: float = 0.05, - seed: int = 42, - kwargs = {} -) -``` - - - - - - -**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) - -Simple wrapper around the OASST dataset. - -**Parameters:** - - -Size of the validation data, default is 0.05 - - - -Seed for train/validation split when split_validation_size > 0, default is 42 - - - - - - - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.oasst.get_data_records( - objs, - task_name: str = 'oasst' -) -``` - - - - - - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.oasst.parse_conversations( - tree_obj, - first: bool = False -) -``` - - - - - - -Recusive function that returns all the sub converstaions in a list starting from node tree_obj. - -**Parameters:** - - -current conversation node - - -**Returns:** - -a list of sub conversation threads including the current conversation node - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.oasst.SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. The ass... -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2.mdx deleted file mode 100644 index 6124a92..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2.mdx +++ /dev/null @@ -1,84 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2 -title: nemo_rl.data.datasets.response_datasets.openmathinstruct2 ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`OpenMathInstruct2Dataset`](#nemo_rl-data-datasets-response_datasets-openmathinstruct2-OpenMathInstruct2Dataset) | Simple wrapper around the OpenMathInstruct2 dataset. | - -### API - - - - - -```python -class nemo_rl.data.datasets.response_datasets.openmathinstruct2.OpenMathInstruct2Dataset( - output_key: str = 'expected_answer', - split: str = 'train_1M', - split_validation_size: float = 0.05, - seed: int = 42, - kwargs = {} -) -``` - - - - - - -**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) - -Simple wrapper around the OpenMathInstruct2 dataset. - -**Parameters:** - - -Key for the output text, default is "expected_answer" - - - -Split name for the dataset, default is "train_1M" - - - -Size of the validation data, default is 0.05 - - - -Seed for train/validation split when split_validation_size > 0, default is 42 - - - - - - - - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.openmathinstruct2.OpenMathInstruct2Dataset.format_data( - data: dict[str, typing.Any] -) -> dict[str, typing.Any] -``` - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco.mdx deleted file mode 100644 index 34e0b8d..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco.mdx +++ /dev/null @@ -1,160 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco -title: nemo_rl.data.datasets.response_datasets.refcoco ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`RefCOCODataset`](#nemo_rl-data-datasets-response_datasets-refcoco-RefCOCODataset) | Simple wrapper around the RefCOCO dataset. | - -### Functions - -| Name | Description | -|------|-------------| -| [`download_and_unzip`](#nemo_rl-data-datasets-response_datasets-refcoco-download_and_unzip) | Downloads a zip file from a given URL to a target directory and unzips it into a specified subdirectory within the target directory, showing download progress. | -| [`format_refcoco_dataset`](#nemo_rl-data-datasets-response_datasets-refcoco-format_refcoco_dataset) | Format the RefCOCO dataset from huggingface. | - -### API - - - - - -```python -class nemo_rl.data.datasets.response_datasets.refcoco.RefCOCODataset( - split: str = 'train', - download_dir: str = './coco_images', - kwargs = {} -) -``` - - - - - - -**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) - -Simple wrapper around the RefCOCO dataset. - -**Parameters:** - - -Split name for the dataset, default is "train" - - - -Directory to download the dataset to, default is "./coco_images" - - - - - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.refcoco.RefCOCODataset.format_data( - data: dict[str, typing.Any] -) -> dict[str, typing.Any] -``` - - - - - - - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.refcoco.download_and_unzip( - url: str, - target_directory: str, - subdir_name: str = '.' -) -``` - - - - - - -Downloads a zip file from a given URL to a target directory and unzips it into a specified subdirectory within the target directory, showing download progress. - -**Parameters:** - - -The URL of the zip file to download. - - - -The directory where the zip file will be downloaded - and unzipped. - - - -The name of the subdirectory within the target_directory - where the contents of the zip file will be unzipped. - Defaults to "train". - - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.refcoco.format_refcoco_dataset( - example: dict[str, typing.Any], - width: int = 256, - height: int = 256, - caption_type: str = 'random' -) -> dict[str, typing.Any] -``` - - - - - - -Format the RefCOCO dataset from huggingface. - -This should be replaced with our own curated RefCOCO/+/g dataset soon - -**Parameters:** - - -The example to format. - - - -The width of the resized image. - - - -The height of the resized image. - - - -The type of caption to use. - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset.mdx deleted file mode 100644 index 09bbf07..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset.mdx +++ /dev/null @@ -1,104 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset -title: nemo_rl.data.datasets.response_datasets.response_dataset ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`ResponseDataset`](#nemo_rl-data-datasets-response_datasets-response_dataset-ResponseDataset) | Dataset class for response data which can be loaded from a JSON file. | - -### API - - - - - -```python -class nemo_rl.data.datasets.response_datasets.response_dataset.ResponseDataset( - data_path: str, - input_key: str = 'input', - output_key: str = 'output', - subset: typing.Optional[str] = None, - split: typing.Optional[str] = None, - split_validation_size: float = 0, - seed: int = 42, - kwargs = {} -) -``` - - - - - - -**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) - -Dataset class for response data which can be loaded from a JSON file. - -This class handles loading of response data for SFT and RL training. -The input JSONL files should contain valid JSON objects formatted like this: -{ - input_key: str, # The input prompt/context - output_key: str, # The output response/answer -} -Please refer to https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets for more details. - -**Parameters:** - - -Path to the dataset JSON file - - - -Key for the input text, default is "input" - - - -Key for the output text, default is "output" - - - -Optional subset name for the dataset, used for HuggingFace datasets - - - -Optional split name for the dataset, used for HuggingFace datasets - - - -Size of the validation data, default is 0 - - - -Seed for train/validation split when split_validation_size > 0, default is 42 - - - - - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.response_dataset.ResponseDataset.format_data( - data: dict[str, typing.Any] -) -> dict[str, typing.Any] -``` - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/squad.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/squad.mdx deleted file mode 100644 index f201b41..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/squad.mdx +++ /dev/null @@ -1,66 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/response_datasets/squad -title: nemo_rl.data.datasets.response_datasets.squad ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`SquadDataset`](#nemo_rl-data-datasets-response_datasets-squad-SquadDataset) | Simple wrapper around the squad dataset. | - -### API - - - - - -```python -class nemo_rl.data.datasets.response_datasets.squad.SquadDataset( - split: str = 'train', - kwargs = {} -) -``` - - - - - - -**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) - -Simple wrapper around the squad dataset. - -**Parameters:** - - -Split name for the dataset, default is "train" - - - - - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.squad.SquadDataset.format_data( - data: dict[str, typing.Any] -) -> dict[str, typing.Any] -``` - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3.mdx deleted file mode 100644 index 68dfa11..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3.mdx +++ /dev/null @@ -1,76 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3 -title: nemo_rl.data.datasets.response_datasets.tulu3 ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`Tulu3SftMixtureDataset`](#nemo_rl-data-datasets-response_datasets-tulu3-Tulu3SftMixtureDataset) | Simple wrapper around the Tulu3 SFT mixture dataset with train split. | - -### API - - - - - -```python -class nemo_rl.data.datasets.response_datasets.tulu3.Tulu3SftMixtureDataset( - split_validation_size: float = 0.05, - seed: int = 42, - max_samples: int | None = None, - kwargs = {} -) -``` - - - - - - -**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) - -Simple wrapper around the Tulu3 SFT mixture dataset with train split. - -**Parameters:** - - -Size of the validation data, default is 0.05 - - - -Seed for train/validation split when split_validation_size > 0, default is 42 - - - -Optional maximum number of samples to use from the dataset - - - - - - - - - - - - -```python -nemo_rl.data.datasets.response_datasets.tulu3.Tulu3SftMixtureDataset.format_data( - data: dict[str, typing.Any] -) -> dict[str, typing.Any] -``` - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/utils.mdx deleted file mode 100644 index d5a02db..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/utils.mdx +++ /dev/null @@ -1,191 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/datasets/utils -title: nemo_rl.data.datasets.utils ---- - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`assert_no_double_bos`](#nemo_rl-data-datasets-utils-assert_no_double_bos) | Assert that there are no double starting BOS tokens in the message. | -| [`extract_necessary_env_names`](#nemo_rl-data-datasets-utils-extract_necessary_env_names) | Extract the necessary environment names from the data config. | -| [`load_dataset_from_path`](#nemo_rl-data-datasets-utils-load_dataset_from_path) | Load a dataset from a local file, huggingface dataset, or Arrow dataset (saved with save_to_disk). | -| [`pil_to_base64`](#nemo_rl-data-datasets-utils-pil_to_base64) | Converts a PIL Image object to a base64 encoded string. | -| [`update_single_dataset_config`](#nemo_rl-data-datasets-utils-update_single_dataset_config) | Fill the single dataset config with default dataset config. | - -### Data - -[`TokenizerType`](#nemo_rl-data-datasets-utils-TokenizerType) - -### API - - - - - -```python -nemo_rl.data.datasets.utils.assert_no_double_bos( - token_ids: torch.Tensor, - tokenizer: nemo_rl.data.datasets.utils.TokenizerType -) -> None -``` - - - - - - -Assert that there are no double starting BOS tokens in the message. - -**Parameters:** - - -List of token IDs - - - -Tokenizer - - - - - - - - - -```python -nemo_rl.data.datasets.utils.extract_necessary_env_names( - data_config: dict -) -> list[str] -``` - - - - - - -Extract the necessary environment names from the data config. - -Some environments are set in env_configs but not used in the data config. -This function extracts the necessary environment names from the data config. - -**Parameters:** - - -The data config. - - -**Returns:** `list[str]` - -The necessary environment names. - - - - - - - - -```python -nemo_rl.data.datasets.utils.load_dataset_from_path( - data_path: str, - data_subset: typing.Optional[str] = None, - data_split: typing.Optional[str] = 'train' -) -``` - - - - - - -Load a dataset from a local file, huggingface dataset, or Arrow dataset (saved with save_to_disk). - -**Parameters:** - - -The path to the dataset. - - - -The subset to load from the dataset. Only supported for huggingface datasets. - - - -The split to load from the dataset. - - - - - - - - - -```python -nemo_rl.data.datasets.utils.pil_to_base64( - image: PIL.Image.Image, - format: str = 'PNG' -) -> str -``` - - - - - - -Converts a PIL Image object to a base64 encoded string. - -**Parameters:** - - -The PIL Image object to convert. - - - -The image format (e.g., "PNG", "JPEG"). Defaults to "PNG". - - -**Returns:** `str` - -A base64 encoded string representation of the image. - - - - - - - - -```python -nemo_rl.data.datasets.utils.update_single_dataset_config( - data_config: dict, - default_data_config: dict -) -> None -``` - - - - - - -Fill the single dataset config with default dataset config. - - - - - - - - -```python -nemo_rl.data.datasets.utils.TokenizerType = Union[PreTrainedTokenizerBase, AutoProcessor] -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/interfaces.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/interfaces.mdx deleted file mode 100644 index b435173..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/interfaces.mdx +++ /dev/null @@ -1,284 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/interfaces -title: nemo_rl.data.interfaces ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`DatumSpec`](#nemo_rl-data-interfaces-DatumSpec) | - | -| [`PreferenceDatumSpec`](#nemo_rl-data-interfaces-PreferenceDatumSpec) | - | -| [`TaskDataProcessFnCallable`](#nemo_rl-data-interfaces-TaskDataProcessFnCallable) | A callable that processes a loaded datum dictionary into a DatumSpec. | -| [`TaskDataSpec`](#nemo_rl-data-interfaces-TaskDataSpec) | - | - -### Data - -[`FlatMessagesType`](#nemo_rl-data-interfaces-FlatMessagesType) - -[`LLMMessageLogType`](#nemo_rl-data-interfaces-LLMMessageLogType) - -[`PathLike`](#nemo_rl-data-interfaces-PathLike) - -[`TokenizerType`](#nemo_rl-data-interfaces-TokenizerType) - -[`VLMMessageLogType`](#nemo_rl-data-interfaces-VLMMessageLogType) - -### API - - - - - -```python -class nemo_rl.data.interfaces.DatumSpec -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.data.interfaces.PreferenceDatumSpec -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.data.interfaces.TaskDataProcessFnCallable() -``` - - - - - - -Protocol - -A callable that processes a loaded datum dictionary into a DatumSpec. - - - - - - -```python -nemo_rl.data.interfaces.TaskDataProcessFnCallable.__call__( - datum_dict: dict[str, typing.Any], - task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, - tokenizer: nemo_rl.data.interfaces.TokenizerType, - max_seq_length: int | None, - idx: int -) -> nemo_rl.data.interfaces.DatumSpec -``` - - - - - - - - - - - - - - -```python -class nemo_rl.data.interfaces.TaskDataSpec( - task_name: typing.Optional[str] = None, - prompt_file: typing.Optional[nemo_rl.data.interfaces.PathLike] = None, - system_prompt_file: typing.Optional[nemo_rl.data.interfaces.PathLike] = None -) -``` - - - - - - -Dataclass - - - - - - - - - - - - - -```python -nemo_rl.data.interfaces.TaskDataSpec.__post_init__() -> None -``` - - - - - - - - - - - - -```python -nemo_rl.data.interfaces.TaskDataSpec.copy_defaults( - from_spec: nemo_rl.data.interfaces.TaskDataSpec -) -> None -``` - - - - - - -Apply default values from another Task instance for any None attributes. - - - - - - - - - -```python -nemo_rl.data.interfaces.FlatMessagesType = dict[str, Union[list[str], torch.Tensor]] -``` - - - - - - - - - -```python -nemo_rl.data.interfaces.LLMMessageLogType = list[dict[str, Union[str, torch.Tensor]]] -``` - - - - - - - - - -```python -nemo_rl.data.interfaces.PathLike = Union[str, 'os.PathLike[Any]'] -``` - - - - - - - - - -```python -nemo_rl.data.interfaces.TokenizerType = PreTrainedTokenizerBase -``` - - - - - - - - - -```python -nemo_rl.data.interfaces.VLMMessageLogType = list[dict[str, Union[str, torch.Tensor, PackedTensor]]] -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/llm_message_utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/llm_message_utils.mdx deleted file mode 100644 index 49a124a..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/llm_message_utils.mdx +++ /dev/null @@ -1,548 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/llm_message_utils -title: nemo_rl.data.llm_message_utils ---- - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`_pad_tensor`](#nemo_rl-data-llm_message_utils-_pad_tensor) | Pad a tensor to the specified length. | -| [`_validate_tensor_consistency`](#nemo_rl-data-llm_message_utils-_validate_tensor_consistency) | Validate that all tensors have consistent dtypes and devices. | -| [`add_loss_mask_to_message_log`](#nemo_rl-data-llm_message_utils-add_loss_mask_to_message_log) | Add token-level loss masks to each message in a message log. | -| [`batched_message_log_to_flat_message`](#nemo_rl-data-llm_message_utils-batched_message_log_to_flat_message) | Process and pad a batch of message logs for model input. | -| [`get_first_index_that_differs`](#nemo_rl-data-llm_message_utils-get_first_index_that_differs) | Get the first index that differs between two strings. | -| [`get_formatted_message_log`](#nemo_rl-data-llm_message_utils-get_formatted_message_log) | Format and tokenize chat messages using the specified template. | -| [`get_images_from_message`](#nemo_rl-data-llm_message_utils-get_images_from_message) | Get all images from a message log item. | -| [`get_keys_from_message_log`](#nemo_rl-data-llm_message_utils-get_keys_from_message_log) | Return a new LLMMessageLogType containing only the specified keys from each message. | -| [`message_log_shape`](#nemo_rl-data-llm_message_utils-message_log_shape) | Get the shape of the tensors in the message log. | -| [`message_log_to_flat_messages`](#nemo_rl-data-llm_message_utils-message_log_to_flat_messages) | Converts a message log (sequence of message turns) into a flattened representation. | -| [`remap_dataset_keys`](#nemo_rl-data-llm_message_utils-remap_dataset_keys) | Remap dataset keys as per mapping. | - -### Data - -[`Tensor`](#nemo_rl-data-llm_message_utils-Tensor) - -[`TokenizerType`](#nemo_rl-data-llm_message_utils-TokenizerType) - -### API - - - - - -```python -nemo_rl.data.llm_message_utils._pad_tensor( - tensor: nemo_rl.data.llm_message_utils.Tensor, - max_len: int, - pad_side: str, - pad_value: int = 0 -) -> nemo_rl.data.llm_message_utils.Tensor -``` - - - - - - -Pad a tensor to the specified length. - -**Parameters:** - - -Tensor to pad - - - -Length to pad to - - - -Whether to pad on the 'left' or 'right' - - - -Value to use for padding - - -**Returns:** `Tensor` - -torch.Tensor: Padded tensor - - - - - - - - -```python -nemo_rl.data.llm_message_utils._validate_tensor_consistency( - tensors: list[nemo_rl.data.llm_message_utils.Tensor] -) -> None -``` - - - - - - -Validate that all tensors have consistent dtypes and devices. - -**Parameters:** - - -List of tensors to validate - - -**Raises:** - -- `RuntimeError`: If tensors have different dtypes or devices - - - - - - - - -```python -nemo_rl.data.llm_message_utils.add_loss_mask_to_message_log( - batch_message_log: list[nemo_rl.data.interfaces.LLMMessageLogType], - roles_to_train_on: list[str] = ['assistant'], - only_unmask_final: bool = False -) -> None -``` - - - - - - -Add token-level loss masks to each message in a message log. - -**Parameters:** - - -List of message dictionaries containing token IDs and metadata - - - -List of strings indicating which speakers to unmask. Default: ["assistant"] - - - -If True, only unmask the final message in the log. Default: False - - - - - - - - - -```python -nemo_rl.data.llm_message_utils.batched_message_log_to_flat_message( - message_log_batch: list[nemo_rl.data.interfaces.LLMMessageLogType], - pad_value_dict: typing.Optional[dict[str, int]] = None, - make_sequence_length_divisible_by: int = 1 -) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.FlatMessagesType], nemo_rl.data.llm_message_utils.Tensor] -``` - - - - - - -Process and pad a batch of message logs for model input. - -For each message log in the batch: -1. Converts it to a flat representation using message_log_to_flat_messages -2. Pads all resulting tensors to the same length for batching -3. Returns a BatchedDataDict and sequence lengths tensor - -Padding is always applied to the right side of sequences. - -Examples: - - -```python ->>> import torch ->>> from nemo_rl.data.llm_message_utils import batched_message_log_to_flat_message ->>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict ->>> # Create a batch of two message logs with different lengths ->>> message_log_batch = [ -... # First conversation -... [ -... {'role': 'user', 'content': 'What is 2+2?', 'token_ids': torch.tensor([1, 2, 3, 4, 5])}, -... {'role': 'assistant', 'content': '4', 'token_ids': torch.tensor([6, 7])} -... ], -... # Second conversation -... [ -... {'role': 'user', 'content': 'Solve x+10=15', 'token_ids': torch.tensor([1, 8, 9, 10, 11, 12])}, -... {'role': 'assistant', 'content': 'x=5', 'token_ids': torch.tensor([13, 14, 15])} -... ] -... ] ->>> pad_value_dict = {'token_ids': 0} ->>> batched_flat, input_lengths = batched_message_log_to_flat_message(message_log_batch, pad_value_dict) ->>> batched_flat['token_ids'][0].tolist() -[1, 2, 3, 4, 5, 6, 7, 0, 0] ->>> batched_flat['token_ids'][1].tolist() -[1, 8, 9, 10, 11, 12, 13, 14, 15] ->>> batched_flat['content'][0] -['What is 2+2?', '4'] ->>> batched_flat['content'][1] -['Solve x+10=15', 'x=5'] ->>> batched_flat['role'] -[['user', 'assistant'], ['user', 'assistant']] ->>> input_lengths -tensor([7, 9], dtype=torch.int32) ->>> ->>> # Multimodal example: include images on both conversations and verify packing ->>> from nemo_rl.data.multimodal_utils import PackedTensor ->>> mm_batch = [ -... [ -... {'role': 'user', 'content': 'look', 'token_ids': torch.tensor([1, 2, 3]), 'images': PackedTensor(torch.randn(2, 3, 4, 4), dim_to_pack=0)}, -... {'role': 'assistant', 'content': 'ok', 'token_ids': torch.tensor([4])} -... ], -... [ -... {'role': 'user', 'content': 'again', 'token_ids': torch.tensor([5, 6]), 'images': PackedTensor(torch.randn(1, 3, 4, 4), dim_to_pack=0)}, -... {'role': 'assistant', 'content': 'fine', 'token_ids': torch.tensor([7, 8])} -... ] -... ] ->>> mm_flat, mm_lengths = batched_message_log_to_flat_message(mm_batch, pad_value_dict={'token_ids': 0}) ->>> isinstance(mm_flat['images'], PackedTensor) -True ->>> tuple(mm_flat['images'].as_tensor().shape) # 2 + 1 images -(3, 3, 4, 4) ->>> mm_lengths -tensor([4, 4], dtype=torch.int32) ->>> -``` - - - -**Parameters:** - - -List of LLMMessageLogType (each a conversation with multiple turns) - - - -Dictionary mapping keys to padding values (default is 0) - - - -forces the data to be divisible by this value - - -**Returns:** `BatchedDataDict[FlatMessagesType]` - -BatchedDataDict[FlatMessagesType]: Dictionary containing padded stacked tensors - -**Raises:** - -- `RuntimeError`: If tensors have different dtypes or devices - - - - - - - - -```python -nemo_rl.data.llm_message_utils.get_first_index_that_differs( - str1: str, - str2: str -) -> int -``` - - - - - - -Get the first index that differs between two strings. - - - - - - - - -```python -nemo_rl.data.llm_message_utils.get_formatted_message_log( - message_log: nemo_rl.data.interfaces.LLMMessageLogType, - tokenizer: nemo_rl.data.llm_message_utils.TokenizerType, - task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, - add_bos_token: bool = True, - add_eos_token: bool = True, - add_generation_prompt: bool = False, - tools: typing.Optional[list[dict[str, typing.Any]]] = None -) -> nemo_rl.data.interfaces.LLMMessageLogType -``` - - - - - - -Format and tokenize chat messages using the specified template. - -Returns: - The message log with updated 'token_ids' and 'content' fields. - -**Parameters:** - - -List of message dicts with 'role' and 'content' keys - - - -Tokenizer for converting text to token IDs - - - -Task spec for this dataset. - - - -Whether to add bos token to first message if it is not already present. Default: True - - - -Whether to add eos token to last message if it is not already present. Default: True - - - -Whether to include assistant's generation prompt in user messages. Default: False - - - -Optional list of tool/function definitions to pass to the chat template. Default: None - - - - - - - - - -```python -nemo_rl.data.llm_message_utils.get_images_from_message( - message: dict[str, typing.Any] -) -> list[typing.Any] -``` - - - - - - -Get all images from a message log item. - - - - - - - - -```python -nemo_rl.data.llm_message_utils.get_keys_from_message_log( - message_log: nemo_rl.data.interfaces.LLMMessageLogType, - keys: list[str] -) -> nemo_rl.data.interfaces.LLMMessageLogType -``` - - - - - - -Return a new LLMMessageLogType containing only the specified keys from each message. - -**Parameters:** - - -Original message log to extract keys from - - - -List of keys to keep in each message - - -**Returns:** `LLMMessageLogType` - -New list with only specified keys - - - - - - - - -```python -nemo_rl.data.llm_message_utils.message_log_shape( - message_log: nemo_rl.data.interfaces.LLMMessageLogType -) -> list[dict[str, torch.Size]] -``` - - - - - - -Get the shape of the tensors in the message log. - -This utility function examines each message in the message log and reports -the shape of tensor values or recursively processes list values. - -**Parameters:** - - -The message log to analyze - - -**Returns:** `list[dict[str, torch.Size]]` - -List of dictionaries containing tensor shapes for each key in messages - - - - - - - - -```python -nemo_rl.data.llm_message_utils.message_log_to_flat_messages( - message_log: nemo_rl.data.interfaces.LLMMessageLogType -) -> nemo_rl.data.interfaces.FlatMessagesType -``` - - - - - - -Converts a message log (sequence of message turns) into a flattened representation. - -This function takes a message log (list of dict messages with 'role', 'content', 'token_ids', etc.) -and converts it to a flat dictionary where all tensors of the same key are concatenated and -all strings of the same key are put into lists. - -Examples: - - -```python ->>> import torch ->>> from nemo_rl.data.llm_message_utils import message_log_to_flat_messages ->>> # Create a simple message log with two messages ->>> message_log = [ -... {'role': 'user', 'content': 'Hello', 'token_ids': torch.tensor([1, 2, 3])}, -... {'role': 'assistant', 'content': 'Hi there', 'token_ids': torch.tensor([4, 5, 6, 7])} -... ] ->>> flat_msgs = message_log_to_flat_messages(message_log) ->>> flat_msgs['role'] -['user', 'assistant'] ->>> flat_msgs['content'] -['Hello', 'Hi there'] ->>> flat_msgs['token_ids'] -tensor([1, 2, 3, 4, 5, 6, 7]) ->>> ->>> # Multimodal example: ->>> from nemo_rl.data.multimodal_utils import PackedTensor ->>> img1 = torch.randn(2, 3, 4, 4) ->>> img2 = torch.randn(3, 3, 4, 4) ->>> mm_log = [ -... {'role': 'user', 'content': 'see', 'token_ids': torch.tensor([1]), 'images': PackedTensor(img1, dim_to_pack=0)}, -... {'role': 'assistant', 'content': 'ok', 'token_ids': torch.tensor([2, 3]), 'images': PackedTensor(img2, dim_to_pack=0)}, -... ] ->>> flat_mm = message_log_to_flat_messages(mm_log) ->>> tuple(flat_mm['images'].as_tensor().shape) -(5, 3, 4, 4) ->>> -``` - - - -**Parameters:** - - -List of message dictionaries with 'role', 'content', and potentially 'token_ids' - - -**Returns:** `FlatMessagesType` - -Dictionary mapping keys to concatenated tensors and string lists - - - - - - - - -```python -nemo_rl.data.llm_message_utils.remap_dataset_keys( - dataset: datasets.Dataset, - mapping_dict: dict[str, str] -) -> datasets.Dataset -``` - - - - - - -Remap dataset keys as per mapping. - -**Parameters:** - - -The input dataset to remap keys in - - - -A dictionary mapping input keys to output keys - - -**Returns:** `Dataset` - -A new dataset with remapped keys - - - - - - - - -```python -nemo_rl.data.llm_message_utils.Tensor = torch.Tensor -``` - - - - - - - - - -```python -nemo_rl.data.llm_message_utils.TokenizerType = PreTrainedTokenizerBase -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/multimodal_utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/multimodal_utils.mdx deleted file mode 100644 index 89f71d0..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/multimodal_utils.mdx +++ /dev/null @@ -1,298 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/multimodal_utils -title: nemo_rl.data.multimodal_utils ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`PackedTensor`](#nemo_rl-data-multimodal_utils-PackedTensor) | Wrapper around a list of torch tensors and a dimension along which to pack the tensors. | - -### Functions - -| Name | Description | -|------|-------------| -| [`get_dim_to_pack_along`](#nemo_rl-data-multimodal_utils-get_dim_to_pack_along) | Special considerations for packing certain keys from certain processors. | -| [`get_multimodal_keys_from_processor`](#nemo_rl-data-multimodal_utils-get_multimodal_keys_from_processor) | Get keys of the multimodal data that can be used as model inputs. | -| [`resolve_to_image`](#nemo_rl-data-multimodal_utils-resolve_to_image) | Resolve the image path to a PIL.Image object. | - -### API - - - - - -```python -class nemo_rl.data.multimodal_utils.PackedTensor( - tensors: typing.Union[torch.Tensor, list[typing.Optional[torch.Tensor]], list[None]], - dim_to_pack: int -) -``` - - - - - - -Wrapper around a list of torch tensors and a dimension along which to pack the tensors. - -This class is used to wrap a list of tensors along with a `dim_to_pack` parameter. -It can be used for data that can be packed along different dimensions (such as multimodal data). - -`dim_to_pack` is used to specify the dimension along which to pack the tensors. - -The list of tensors can be returned as a single packed tensor by calling `as_tensor` which will concatenate the tensors along the `dim_to_pack` dimension. - - - - - - - - -```python -nemo_rl.data.multimodal_utils.PackedTensor.__len__() -> int -``` - - - - - - - - - - - - -```python -nemo_rl.data.multimodal_utils.PackedTensor.as_tensor( - device: typing.Optional[torch.device] = None -) -> typing.Optional[torch.Tensor] -``` - - - - - - - - - - - - -```python -nemo_rl.data.multimodal_utils.PackedTensor.concat( - from_packed_tensors: list[nemo_rl.data.multimodal_utils.PackedTensor] -) -> nemo_rl.data.multimodal_utils.PackedTensor -``` - - - - - - -classmethod - -Concatenate a list of PackedTensor objects into a single PackedTensor. - -The underlying tensors from the PackedTensors are combined into a single list of tensors and used to create a new PackedTensor. - -Each batch must have the same dim_to_pack. - -Example: - - -```python ->>> import torch ->>> from nemo_rl.data.multimodal_utils import PackedTensor ->>> p1 = PackedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])], dim_to_pack=0) ->>> p2 = PackedTensor([torch.tensor([7, 8, 9])], dim_to_pack=0) ->>> p3 = PackedTensor.concat([p1, p2]) ->>> p3.tensors -[tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9])] ->>> p3.as_tensor() -tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]) ->>> -``` - - - - - - - - - -```python -nemo_rl.data.multimodal_utils.PackedTensor.empty_like( - other: nemo_rl.data.multimodal_utils.PackedTensor -) -> nemo_rl.data.multimodal_utils.PackedTensor -``` - - - - - - -classmethod - -Return a new PackedTensor with same length and dim_to_pack as `other`, with all entries None. - - - - - - - -```python -nemo_rl.data.multimodal_utils.PackedTensor.flattened_concat( - from_packed_tensors: list[nemo_rl.data.multimodal_utils.PackedTensor] -) -> nemo_rl.data.multimodal_utils.PackedTensor -``` - - - - - - -classmethod - -Given a list of PackedTensor objects, flattens each PackedTensor and then concatenates them into a single PackedTensor. - -Each PackedTensor is first flattened by packing along the PackedTensor's `dim_to_pack` dimension. Then, the resulting flattened tensors are used to create a new PackedTensor. - -This is different from `PackedTensor.concat` which simply extends the underlying list of tensors. This is important because the `slice` and `__len__` methods operate on the underlying list of tensors. Note, however, that calling `as_tensor` on the resulting PackedTensor will result in the same tensor as `concat`. - -Each batch must have the same dim_to_pack. - -Example: - - -```python ->>> import torch ->>> from nemo_rl.data.multimodal_utils import PackedTensor ->>> p1 = PackedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])], dim_to_pack=0) ->>> p2 = PackedTensor([torch.tensor([7, 8, 9])], dim_to_pack=0) ->>> p3 = PackedTensor.flattened_concat([p1, p2]) ->>> p3.tensors -[tensor([1, 2, 3, 4, 5, 6]), tensor([7, 8, 9])] ->>> p3.as_tensor() -tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]) ->>> -``` - - - - - - - - - -```python -nemo_rl.data.multimodal_utils.PackedTensor.slice( - indices: typing.Union[list[int], torch.Tensor] -) -> nemo_rl.data.multimodal_utils.PackedTensor -``` - - - - - - - - - - - - -```python -nemo_rl.data.multimodal_utils.PackedTensor.to( - device: str | torch.device -) -> nemo_rl.data.multimodal_utils.PackedTensor -``` - - - - - - - - - - - - - - -```python -nemo_rl.data.multimodal_utils.get_dim_to_pack_along( - processor, - key: str -) -> int -``` - - - - - - -Special considerations for packing certain keys from certain processors. - -In most cases, the packed items are along dim 0 - - - - - - - - -```python -nemo_rl.data.multimodal_utils.get_multimodal_keys_from_processor( - processor -) -> list[str] -``` - - - - - - -Get keys of the multimodal data that can be used as model inputs. - -This will be used in the data_processor function to determine which keys to use as model inputs. - - - - - - - - -```python -nemo_rl.data.multimodal_utils.resolve_to_image( - image_path_or_image: str | PIL.Image.Image -) -> PIL.Image.Image -``` - - - - - - -Resolve the image path to a PIL.Image object. - -image_path can be either: -- path to local file -- url to image -- base64 encoded image - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing.mdx deleted file mode 100644 index 8161181..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing.mdx +++ /dev/null @@ -1,30 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/packing -title: nemo_rl.data.packing ---- - -## Submodules - -- **[`nemo_rl.data.packing.algorithms`](/nemo-rl/nemo_rl/data/packing/algorithms)** -- **[`nemo_rl.data.packing.metrics`](/nemo-rl/nemo_rl/data/packing/metrics)** - -## Package Contents - -### Data - -[`__all__`](#nemo_rl-data-packing-__all__) - -### API - - - - - -```python -nemo_rl.data.packing.__all__ = ['PackingAlgorithm', 'SequencePacker', 'ConcatenativePacker', 'FirstFitDecreasin... -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/algorithms.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/algorithms.mdx deleted file mode 100644 index 7337f16..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/algorithms.mdx +++ /dev/null @@ -1,791 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/packing/algorithms -title: nemo_rl.data.packing.algorithms ---- - -Sequence packing algorithms for efficient batching of variable-length sequences. - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`ConcatenativePacker`](#nemo_rl-data-packing-algorithms-ConcatenativePacker) | Concatenative packing algorithm. | -| [`FirstFitDecreasingPacker`](#nemo_rl-data-packing-algorithms-FirstFitDecreasingPacker) | First-Fit Decreasing (FFD) algorithm for sequence packing. | -| [`FirstFitPacker`](#nemo_rl-data-packing-algorithms-FirstFitPacker) | Base class for First-Fit algorithms. | -| [`FirstFitShufflePacker`](#nemo_rl-data-packing-algorithms-FirstFitShufflePacker) | First-Fit Shuffle algorithm for sequence packing. | -| [`ModifiedFirstFitDecreasingPacker`](#nemo_rl-data-packing-algorithms-ModifiedFirstFitDecreasingPacker) | Modified First-Fit Decreasing (MFFD) algorithm for sequence packing. | -| [`PackingAlgorithm`](#nemo_rl-data-packing-algorithms-PackingAlgorithm) | Enum for supported sequence packing algorithms. | -| [`SequencePacker`](#nemo_rl-data-packing-algorithms-SequencePacker) | Abstract base class for sequence packing algorithms. | - -### Functions - -| Name | Description | -|------|-------------| -| [`get_packer`](#nemo_rl-data-packing-algorithms-get_packer) | Factory function to get a sequence packer based on the algorithm. | - -### API - - - - - -```python -class nemo_rl.data.packing.algorithms.ConcatenativePacker() -``` - - - - - - -**Bases:** [SequencePacker](#nemo_rl-data-packing-algorithms-SequencePacker) - -Concatenative packing algorithm. - -This algorithm simply concatenates sequences in order until reaching the bin capacity, -then starts a new bin. It doesn't try to optimize the packing in any way. - -Time complexity: O(n) where n is the number of sequences. - -Example: - - -```python ->>> examples = { -... "sequence_lengths": [4, 1, 3, 2, 1, 3, 4, 5] -... } ->>> # If packed with seq_length=5: -... {"bins": [ [0, 1], [2, 3], [4, 5], [6], [7] ]} ->>> # If packed with seq_length=8: -... {"bins": [ [0, 1, 2], [3, 4, 5], [6], [7] ]} -``` - - - - - - - - - - -```python -nemo_rl.data.packing.algorithms.ConcatenativePacker._pack_implementation( - sequence_lengths: typing.List[int] -) -> typing.List[typing.List[int]] -``` - - - - - - -Pack sequences using the Concatenative algorithm. - -**Parameters:** - - -A list of sequence lengths to pack. - - -**Returns:** `List[List[int]]` - -A list of bins, where each bin is a list of indices into the original - - - - - - - - - -```python -class nemo_rl.data.packing.algorithms.FirstFitDecreasingPacker() -``` - - - - - - -**Bases:** [FirstFitPacker](#nemo_rl-data-packing-algorithms-FirstFitPacker) - -First-Fit Decreasing (FFD) algorithm for sequence packing. - -This algorithm sorts sequences by length in descending order and then -places each sequence into the first bin where it fits. - -Time complexity: O(n log n) for sorting + O(n * m) for packing, -where n is the number of sequences and m is the number of bins. - - - - - - -```python -nemo_rl.data.packing.algorithms.FirstFitDecreasingPacker._prepare_sequences( - sequence_lengths: typing.List[int] -) -> typing.List[typing.Tuple[int, int]] -``` - - - - - - -Prepare sequences for packing by sorting them in descending order. - -**Parameters:** - - -A list of sequence lengths to pack. - - -**Returns:** `List[Tuple[int, int]]` - -A list of (length, index) pairs sorted by length in descending order. - - - - - - - - - -```python -class nemo_rl.data.packing.algorithms.FirstFitPacker() -``` - - - - - - -**Bases:** [SequencePacker](#nemo_rl-data-packing-algorithms-SequencePacker) - -Base class for First-Fit algorithms. - -First-Fit algorithms place each sequence into the first bin where it fits. -If no bin can fit the sequence, a new bin is created. - -This is an abstract base class that provides the common implementation for -First-Fit variants. Subclasses must implement the _prepare_sequences method -to determine the order in which sequences are processed. - - - - - - -```python -nemo_rl.data.packing.algorithms.FirstFitPacker._pack_implementation( - sequence_lengths: typing.List[int] -) -> typing.List[typing.List[int]] -``` - - - - - - -Pack sequences using the First-Fit algorithm. - -**Parameters:** - - -A list of sequence lengths to pack. - - -**Returns:** `List[List[int]]` - -A list of bins, where each bin is a list of indices into the original - - - - - - - -```python -nemo_rl.data.packing.algorithms.FirstFitPacker._prepare_sequences( - sequence_lengths: typing.List[int] -) -> typing.List[typing.Tuple[int, int]] -``` - - - - - - -Prepare sequences for packing. - -This method determines the order in which sequences are processed. -Subclasses must override this method. - -**Parameters:** - - -A list of sequence lengths to pack. - - -**Returns:** `List[Tuple[int, int]]` - -A list of (length, index) pairs. - - - - - - - - - -```python -class nemo_rl.data.packing.algorithms.FirstFitShufflePacker() -``` - - - - - - -**Bases:** [FirstFitPacker](#nemo_rl-data-packing-algorithms-FirstFitPacker) - -First-Fit Shuffle algorithm for sequence packing. - -This algorithm randomly shuffles the sequences and then places each -sequence into the first bin where it fits. - -Time complexity: O(n * m) for packing, where n is the number of sequences -and m is the number of bins. - - - - - - -```python -nemo_rl.data.packing.algorithms.FirstFitShufflePacker._prepare_sequences( - sequence_lengths: typing.List[int] -) -> typing.List[typing.Tuple[int, int]] -``` - - - - - - -Prepare sequences for packing by randomly shuffling them. - -**Parameters:** - - -A list of sequence lengths to pack. - - -**Returns:** `List[Tuple[int, int]]` - -A list of (length, index) pairs in random order. - - - - - - - - - -```python -class nemo_rl.data.packing.algorithms.ModifiedFirstFitDecreasingPacker() -``` - - - - - - -**Bases:** [SequencePacker](#nemo_rl-data-packing-algorithms-SequencePacker) - -Modified First-Fit Decreasing (MFFD) algorithm for sequence packing. - -This algorithm implements the Johnson & Garey (1985) Modified First-Fit-Decreasing -heuristic. It classifies items into four categories (large, medium, small, tiny) -and uses a sophisticated 5-phase packing strategy to achieve better bin utilization -than standard First-Fit Decreasing. - -The algorithm phases: -1. Classify items by size relative to bin capacity -2. Create one bin per large item -3. Add medium items to large bins (forward pass) -4. Add pairs of small items to bins with medium items (backward pass) -5. Greedily fit remaining items -6. Apply FFD to any leftovers - -Time complexity: O(n log n) for sorting + O(n * m) for packing, -where n is the number of sequences and m is the number of bins. - - - - - - -```python -nemo_rl.data.packing.algorithms.ModifiedFirstFitDecreasingPacker._classify_items( - items: typing.List[typing.Tuple[int, int]] -) -> typing.Tuple[typing.List[typing.Tuple[int, int]], typing.List[typing.Tuple[int, int]], typing.List[typing.Tuple[int, int]], typing.List[typing.Tuple[int, int]]] -``` - - - - - - -Split items into large / medium / small / tiny classes. - -Follows the classification used by Johnson & Garey: - large : (C/2, C] - medium : (C/3, C/2] - small : (C/6, C/3] - tiny : (0 , C/6] - -**Parameters:** - - -List of (index, size) tuples - - -**Returns:** `Tuple[List[Tuple[int, int]], List[Tuple[int, int]], List[Tuple[int, int]], List[Tuple[int, int]]]` - -Tuple of four lists (large, medium, small, tiny) without additional sorting. - - - - - - - -```python -nemo_rl.data.packing.algorithms.ModifiedFirstFitDecreasingPacker._pack_implementation( - sequence_lengths: typing.List[int] -) -> typing.List[typing.List[int]] -``` - - - - - - -Pack sequences using the Modified First-Fit Decreasing algorithm. - -**Parameters:** - - -A list of sequence lengths to pack. - - -**Returns:** `List[List[int]]` - -A list of bins, where each bin is a list of indices into the original - - - - - - - - - -```python -class nemo_rl.data.packing.algorithms.PackingAlgorithm -``` - - - - - - -**Bases:** `enum.Enum` - -Enum for supported sequence packing algorithms. - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.data.packing.algorithms.SequencePacker( - bin_capacity: int, - collect_metrics: bool = False, - min_bin_count: typing.Optional[int] = None, - bin_count_multiple: typing.Optional[int] = None -) -``` - - - - - - -Abstract - -Abstract base class for sequence packing algorithms. - -Sequence packing is the process of efficiently arranging sequences of different -lengths into fixed-capacity bins (batches) to maximize computational efficiency. - - - - - - -```python -nemo_rl.data.packing.algorithms.SequencePacker._adjust_bin_count( - bins: typing.List[typing.List[int]] -) -> typing.List[typing.List[int]] -``` - - - - - - -Adjust the number of bins to meet minimum and multiple constraints. - -This method preserves the existing bin packing as much as possible and only -moves sequences one at a time to create additional bins when needed. - -**Parameters:** - - -The original bins from the packing algorithm. - - -**Returns:** `List[List[int]]` - -Adjusted bins with minimal changes to meet constraints. - -**Raises:** - -- `ValueError`: If there aren't enough sequences to fill the required number of bins. - - - - - - - -```python -nemo_rl.data.packing.algorithms.SequencePacker._create_indexed_lengths( - sequence_lengths: typing.List[int], - reverse: bool = False -) -> typing.List[typing.Tuple[int, int]] -``` - - - - - - -Create a list of (length, index) pairs from sequence lengths. - -**Parameters:** - - -A list of sequence lengths. - - - -Whether to sort in descending order (True) or ascending order (False). - - -**Returns:** `List[Tuple[int, int]]` - -A list of (length, index) pairs, optionally sorted. - - - - - - - -```python -nemo_rl.data.packing.algorithms.SequencePacker._estimate_bins_needed( - sequence_lengths: typing.List[int] -) -> int -``` - - - - - - -Estimate the number of bins needed based on total length. - -**Parameters:** - - -A list of sequence lengths. - - -**Returns:** `int` - -Estimated number of bins needed. - - - - - - - -```python -nemo_rl.data.packing.algorithms.SequencePacker._pack_implementation( - sequence_lengths: typing.List[int] -) -> typing.List[typing.List[int]] -``` - - - - - - -abstract - -Implementation of the packing algorithm. - -**Parameters:** - - -A list of sequence lengths to pack. - - -**Returns:** `List[List[int]]` - -A list of bins, where each bin is a list of indices into the original - - - - - - - -```python -nemo_rl.data.packing.algorithms.SequencePacker._validate_sequence_lengths( - sequence_lengths: typing.List[int] -) -> None -``` - - - - - - -Validate that all sequence lengths are within bin capacity. - -**Parameters:** - - -A list of sequence lengths to validate. - - -**Raises:** - -- `ValueError`: If any sequence length exceeds bin capacity. - - - - - - - -```python -nemo_rl.data.packing.algorithms.SequencePacker.compute_metrics( - sequence_lengths: typing.List[int], - bins: typing.List[typing.List[int]] -) -> typing.Dict[str, float] -``` - - - - - - -Calculate metrics for a packing solution without updating the metrics tracker. - -**Parameters:** - - -List of sequence lengths - - - -List of bins, where each bin is a list of indices - - -**Returns:** `Dict[str, float]` - -Dictionary of packing metrics - - - - - - - -```python -nemo_rl.data.packing.algorithms.SequencePacker.get_aggregated_metrics() -> typing.Dict[str, float] -``` - - - - - - -Get aggregated metrics across all packing operations. - -**Returns:** `Dict[str, float]` - -Dictionary of aggregated metrics, or empty dict if not collecting - - - - - - - -```python -nemo_rl.data.packing.algorithms.SequencePacker.pack( - sequence_lengths: typing.List[int] -) -> typing.List[typing.List[int]] -``` - - - - - - -Pack sequences into bins and update metrics if enabled. - -**Parameters:** - - -A list of sequence lengths to pack. - - -**Returns:** `List[List[int]]` - -A list of bins, where each bin is a list of indices into the original - - - - - - - -```python -nemo_rl.data.packing.algorithms.SequencePacker.print_metrics() -> None -``` - - - - - - -Print the current metrics in a formatted way. - - - - - - - -```python -nemo_rl.data.packing.algorithms.SequencePacker.reset_metrics() -> None -``` - - - - - - -Reset collected metrics. - - - - - - - - - -```python -nemo_rl.data.packing.algorithms.get_packer( - algorithm: typing.Union[nemo_rl.data.packing.algorithms.PackingAlgorithm, str], - bin_capacity: int, - collect_metrics: bool = False, - min_bin_count: typing.Optional[int] = None, - bin_count_multiple: typing.Optional[int] = None -) -> nemo_rl.data.packing.algorithms.SequencePacker -``` - - - - - - -Factory function to get a sequence packer based on the algorithm. - -**Parameters:** - - -The packing algorithm to use. Can be either a PackingAlgorithm enum value - or a string (case-insensitive) matching one of the enum names. - - - -The maximum capacity of each bin. - - - -Whether to collect metrics across multiple packing operations. - - - -Minimum number of bins to create, even if fewer would suffice. - If None, no minimum is enforced. - - - -The total number of bins must be a multiple of this value. - If None, no multiple constraint is enforced. - - -**Returns:** `SequencePacker` - -A SequencePacker instance for the specified algorithm. - -**Raises:** - -- `ValueError`: If the algorithm is not recognized. - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/metrics.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/metrics.mdx deleted file mode 100644 index 3f1b4d0..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/metrics.mdx +++ /dev/null @@ -1,177 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/packing/metrics -title: nemo_rl.data.packing.metrics ---- - -Metrics for evaluating sequence packing algorithms. - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`PackingMetrics`](#nemo_rl-data-packing-metrics-PackingMetrics) | Class for tracking and computing metrics for sequence packing algorithms. | - -### API - - - - - -```python -class nemo_rl.data.packing.metrics.PackingMetrics() -``` - - - - - - -Class for tracking and computing metrics for sequence packing algorithms. - -This class provides methods to calculate various metrics that evaluate the -efficiency and effectiveness of sequence packing algorithms, such as bin -utilization, waste, and imbalance. - - - - - - -```python -nemo_rl.data.packing.metrics.PackingMetrics.calculate_stats_only( - sequence_lengths: typing.List[int], - bins: typing.List[typing.List[int]], - bin_capacity: int -) -> typing.Dict[str, float] -``` - - - - - - -Calculate metrics for a packing solution without updating the tracker. - -**Parameters:** - - -List of sequence lengths - - - -List of bins, where each bin is a list of indices - - - -Maximum capacity of each bin - - -**Returns:** `Dict[str, float]` - -Dictionary of metrics for this packing solution - - - - - - - -```python -nemo_rl.data.packing.metrics.PackingMetrics.get_aggregated_stats() -> typing.Dict[str, float] -``` - - - - - - -Get aggregated metrics across all packing operations. - -**Returns:** `Dict[str, float]` - -Dictionary of aggregated metrics - - - - - - - -```python -nemo_rl.data.packing.metrics.PackingMetrics.print_aggregated_stats() -> None -``` - - - - - - -Print the aggregated metrics in a formatted way. - - - - - - - -```python -nemo_rl.data.packing.metrics.PackingMetrics.reset() -> None -``` - - - - - - -Reset all metrics. - - - - - - - -```python -nemo_rl.data.packing.metrics.PackingMetrics.update( - sequence_lengths: typing.List[int], - bins: typing.List[typing.List[int]], - bin_capacity: int, - packing_time: typing.Optional[float] = None -) -> typing.Dict[str, float] -``` - - - - - - -Update metrics with a new packing solution. - -**Parameters:** - - -List of sequence lengths - - - -List of bins, where each bin is a list of indices - - - -Maximum capacity of each bin - - - -Optional time taken to compute the packing solution - - -**Returns:** `Dict[str, float]` - -Dictionary of metrics for this packing solution - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/processors.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/processors.mdx deleted file mode 100644 index 7660a0d..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/processors.mdx +++ /dev/null @@ -1,353 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/processors -title: nemo_rl.data.processors ---- - -Contains data processors for evaluation. - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`_construct_multichoice_prompt`](#nemo_rl-data-processors-_construct_multichoice_prompt) | Construct prompt from question and options. | -| [`helpsteer3_data_processor`](#nemo_rl-data-processors-helpsteer3_data_processor) | Process a HelpSteer3 preference datum into a DatumSpec for GRPO training. | -| [`math_data_processor`](#nemo_rl-data-processors-math_data_processor) | Process a datum dictionary (directly loaded from dataset) into a DatumSpec for the Math Environment. | -| [`math_hf_data_processor`](#nemo_rl-data-processors-math_hf_data_processor) | Process a datum dictionary (directly loaded from data/hf_datasets/openmathinstruct2.py) into a DatumSpec for the Reward Model Environment. | -| [`multichoice_qa_processor`](#nemo_rl-data-processors-multichoice_qa_processor) | Process a datum dictionary (directly loaded from dataset) into a DatumSpec for multiple-choice problems. | -| [`nemo_gym_data_processor`](#nemo_rl-data-processors-nemo_gym_data_processor) | Process a datum dictionary (directly loaded from dataset) into a DatumSpec for Nemo Gym. | -| [`preference_preprocessor`](#nemo_rl-data-processors-preference_preprocessor) | Process a datum dictionary for RM/DPO training. | -| [`register_processor`](#nemo_rl-data-processors-register_processor) | - | -| [`sft_processor`](#nemo_rl-data-processors-sft_processor) | Process a datum dictionary for SFT training. | -| [`vlm_hf_data_processor`](#nemo_rl-data-processors-vlm_hf_data_processor) | Process a datum dictionary (directly loaded from response_datasets/<dataset_name>.py) into a DatumSpec for the VLM Environment. | - -### Data - -[`PROCESSOR_REGISTRY`](#nemo_rl-data-processors-PROCESSOR_REGISTRY) - -[`TokenizerType`](#nemo_rl-data-processors-TokenizerType) - -### API - - - - - -```python -nemo_rl.data.processors._construct_multichoice_prompt( - prompt: str, - question: str, - options: dict[str, str] -) -> str -``` - - - - - - -Construct prompt from question and options. - - - - - - - - -```python -nemo_rl.data.processors.helpsteer3_data_processor( - datum_dict: dict[str, typing.Any], - task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, - tokenizer: nemo_rl.data.processors.TokenizerType, - max_seq_length: int, - idx: int -) -> nemo_rl.data.interfaces.DatumSpec -``` - - - - - - -Process a HelpSteer3 preference datum into a DatumSpec for GRPO training. - -This function converts HelpSteer3 preference data to work with GRPO by: -1. Using the context as the prompt -2. Using the preferred completion as the target response -3. Creating a reward signal based on preference scores - - - - - - - - -```python -nemo_rl.data.processors.math_data_processor( - datum_dict: dict[str, typing.Any], - task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, - tokenizer: nemo_rl.data.processors.TokenizerType, - max_seq_length: int, - idx: int -) -> nemo_rl.data.interfaces.DatumSpec -``` - - - - - - -Process a datum dictionary (directly loaded from dataset) into a DatumSpec for the Math Environment. - - - - - - - - -```python -nemo_rl.data.processors.math_hf_data_processor( - datum_dict: dict[str, typing.Any], - task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, - tokenizer: nemo_rl.data.processors.TokenizerType, - max_seq_length: int, - idx: int -) -> nemo_rl.data.interfaces.DatumSpec -``` - - - - - - -Process a datum dictionary (directly loaded from data/hf_datasets/openmathinstruct2.py) into a DatumSpec for the Reward Model Environment. - - - - - - - - -```python -nemo_rl.data.processors.multichoice_qa_processor( - datum_dict: dict[str, typing.Any], - task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, - tokenizer: nemo_rl.data.processors.TokenizerType, - max_seq_length: int, - idx: int -) -> nemo_rl.data.interfaces.DatumSpec -``` - - - - - - -Process a datum dictionary (directly loaded from dataset) into a DatumSpec for multiple-choice problems. - - - - - - - - -```python -nemo_rl.data.processors.nemo_gym_data_processor( - datum_dict: dict[str, typing.Any], - task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, - tokenizer: nemo_rl.data.processors.TokenizerType, - max_seq_length: int | None, - idx: int -) -> nemo_rl.data.interfaces.DatumSpec -``` - - - - - - -Process a datum dictionary (directly loaded from dataset) into a DatumSpec for Nemo Gym. - - - - - - - - -```python -nemo_rl.data.processors.preference_preprocessor( - datum_dict: dict[str, typing.Any], - task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, - tokenizer, - max_seq_length: int, - idx: int -) -> nemo_rl.data.interfaces.PreferenceDatumSpec -``` - - - - - - -Process a datum dictionary for RM/DPO training. - -**Examples:** - - - -```python ->>> from transformers import AutoTokenizer ->>> from nemo_rl.data.interfaces import TaskDataSpec ->>> from nemo_rl.data.processors import preference_preprocessor ->>> ->>> # Initialize tokenizer and task spec ->>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct") ->>> ## set a passthrough chat template for simplicity ->>> tokenizer.chat_template = "{% for message in messages %}{{ message['content'] }}{% endfor %}" ->>> task_spec = TaskDataSpec(task_name="test_preference") ->>> ->>> datum = { -... "context": [{"role": "user", "content": "What is 2+2?"}], -... "completions": [ -... {"rank": 0, "completion": [{"role": "assistant", "content": "4"}]}, -... {"rank": 1, "completion": [{"role": "assistant", "content": "5"}]} -... ] -... } ->>> ->>> processed = preference_preprocessor(datum, task_spec, tokenizer, max_seq_length=128, idx=0) # doctest: +ELLIPSIS - -... ->>> len(processed["message_log_chosen"]) -2 ->>> processed["message_log_chosen"][0]["content"] -'<|begin_of_text|>What is 2+2?' ->>> processed["message_log_chosen"][-1]["content"] -'4<|eot_id|>' ->>> processed["message_log_rejected"][-1]["content"] -'5<|eot_id|>' ->>> ->>> # context can also contain multiple turns ->>> datum = { -... "context": [{"role": "user", "content": "I have a question."}, {"role": "assistant", "content": "Sure!"}, {"role": "user", "content": "What is 2+2?"}], -... "completions": [ -... {"rank": 0, "completion": [{"role": "assistant", "content": "4"}]}, -... {"rank": 1, "completion": [{"role": "assistant", "content": "5"}]} -... ] -... } ->>> processed = preference_preprocessor(datum, task_spec, tokenizer, max_seq_length=128, idx=0) ->>> len(processed["message_log_chosen"]) -4 ->>> processed["message_log_chosen"][1]["content"] -'Sure!' ->>> processed["message_log_chosen"][-1]["content"] -'4<|eot_id|>' ->>> processed["message_log_rejected"][-1]["content"] -'5<|eot_id|>' -``` - - - - - - - - - - -```python -nemo_rl.data.processors.register_processor( - processor_name: str, - processor_function: nemo_rl.data.interfaces.TaskDataProcessFnCallable -) -> None -``` - - - - - - - - - - - - - -```python -nemo_rl.data.processors.sft_processor( - datum_dict: dict[str, typing.Any], - task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, - tokenizer, - max_seq_length: int, - idx: int, - add_bos: bool = True, - add_eos: bool = True, - add_generation_prompt: bool = False -) -> nemo_rl.data.interfaces.DatumSpec -``` - - - - - - -Process a datum dictionary for SFT training. - - - - - - - - -```python -nemo_rl.data.processors.vlm_hf_data_processor( - datum_dict: dict[str, typing.Any], - task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, - processor: transformers.AutoProcessor, - max_seq_length: int, - idx: int -) -> nemo_rl.data.interfaces.DatumSpec -``` - - - - - - -Process a datum dictionary (directly loaded from response_datasets/<dataset_name>.py) into a DatumSpec for the VLM Environment. - - - - - - - - -```python -nemo_rl.data.processors.PROCESSOR_REGISTRY: Dict[str, TaskDataProcessFnCallable] = cast(Dict[str, TaskDataProcessFnCallable], {'default': math_hf_data_processor, '... -``` - - - - - - - - - -```python -nemo_rl.data.processors.TokenizerType = PreTrainedTokenizerBase -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/utils.mdx deleted file mode 100644 index 386880c..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/data/utils.mdx +++ /dev/null @@ -1,104 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/data/utils -title: nemo_rl.data.utils ---- - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`setup_preference_data`](#nemo_rl-data-utils-setup_preference_data) | Setup preference data. | -| [`setup_response_data`](#nemo_rl-data-utils-setup_response_data) | Setup data with environments. | - -### API - - - - - -```python -nemo_rl.data.utils.setup_preference_data( - tokenizer: transformers.AutoTokenizer, - data_config: nemo_rl.data.DataConfig -) -``` - - - - - - -Setup preference data. - -This function is used to setup the preference data for the training and validation datasets. - -**Parameters:** - - -Tokenizer. - - - -Data config for preference dataset. - - -**Returns:** - -A tuple of (train dataset, validation dataset). - - - - - - - - -```python -nemo_rl.data.utils.setup_response_data( - tokenizer: transformers.AutoProcessor | transformers.AutoTokenizer, - data_config: nemo_rl.data.DataConfig, - env_configs: typing.Optional[dict[str, typing.Any]] = None, - is_vlm: bool = False -) -> typing.Union[tuple[nemo_rl.data.datasets.AllTaskProcessedDataset, typing.Optional[nemo_rl.data.datasets.AllTaskProcessedDataset]], tuple[nemo_rl.data.datasets.AllTaskProcessedDataset, typing.Optional[nemo_rl.data.datasets.AllTaskProcessedDataset], dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]]] -``` - - - - - - -Setup data with environments. - -This function is used to setup the data and environments for the training and validation datasets. - -**Parameters:** - - -Tokenizer or processor. - - - -Data config. - - - -Environment configs. -If None, no environments will be created. This is used for: -- Algorithms like SFT which do not need environments. -- Environments like NeMo-Gym which need to handle the environment creation outside of this function. - - - -Whether to use VLM training or not. - - -**Returns:** `Union[tuple[AllTaskProcessedDataset, Optional[AllTaskProcessedDataset]], tuple[AllTaskProcessedDataset, Optional[AllTaskProcessedDataset], dict[str, EnvironmentInterface], dict[str, EnvironmentInterface]]]` - -If env_configs is not None: -A tuple of (train dataset, validation dataset, task to environment, task to validation environment). - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed.mdx deleted file mode 100644 index d615e36..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed.mdx +++ /dev/null @@ -1,17 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/distributed -title: nemo_rl.distributed ---- - -## Submodules - -- **[`nemo_rl.distributed.batched_data_dict`](/nemo-rl/nemo_rl/distributed/batched_data_dict)** -- **[`nemo_rl.distributed.collectives`](/nemo-rl/nemo_rl/distributed/collectives)** -- **[`nemo_rl.distributed.model_utils`](/nemo-rl/nemo_rl/distributed/model_utils)** -- **[`nemo_rl.distributed.named_sharding`](/nemo-rl/nemo_rl/distributed/named_sharding)** -- **[`nemo_rl.distributed.ray_actor_environment_registry`](/nemo-rl/nemo_rl/distributed/ray_actor_environment_registry)** -- **[`nemo_rl.distributed.stateless_process_group`](/nemo-rl/nemo_rl/distributed/stateless_process_group)** -- **[`nemo_rl.distributed.virtual_cluster`](/nemo-rl/nemo_rl/distributed/virtual_cluster)** -- **[`nemo_rl.distributed.worker_group_utils`](/nemo-rl/nemo_rl/distributed/worker_group_utils)** -- **[`nemo_rl.distributed.worker_groups`](/nemo-rl/nemo_rl/distributed/worker_groups)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/batched_data_dict.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/batched_data_dict.mdx deleted file mode 100644 index c134df2..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/batched_data_dict.mdx +++ /dev/null @@ -1,671 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/distributed/batched_data_dict -title: nemo_rl.distributed.batched_data_dict ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`BatchedDataDict`](#nemo_rl-distributed-batched_data_dict-BatchedDataDict) | - | -| [`DynamicBatchingArgs`](#nemo_rl-distributed-batched_data_dict-DynamicBatchingArgs) | Configuration settings for dynamic batching. | -| [`SequencePackingArgs`](#nemo_rl-distributed-batched_data_dict-SequencePackingArgs) | Configuration settings for sequence packing. | -| [`SlicedDataDict`](#nemo_rl-distributed-batched_data_dict-SlicedDataDict) | A specialized subclass of BatchedDataDict that represents a slice or shard of a larger batch. | - -### Data - -[`DictT`](#nemo_rl-distributed-batched_data_dict-DictT) - -### API - - - - - -```python -class nemo_rl.distributed.batched_data_dict.BatchedDataDict( - args = (), - kwargs = {} -) -``` - - - - - - -**Bases:** `UserDict`, `Generic[DictT]` - - - - - -Get the batch size of the batch. - - - - - -```python -nemo_rl.distributed.batched_data_dict.BatchedDataDict.all_gather( - group: torch.distributed.ProcessGroup -) -> typing_extensions.Self -``` - - - - - - -Gathers batches with possibly jagged leading dimensions across the DP ranks. - -If using reshard, it will treat PP as DP ranks. -Works with data that is either tensors or string lists. - - - - - - - -```python -nemo_rl.distributed.batched_data_dict.BatchedDataDict.chunk( - rank: int, - chunks: int -) -> nemo_rl.distributed.batched_data_dict.SlicedDataDict -``` - - - - - - -Chunks a global batch into 'chunks' splits and returns the 'rank'th split batch=[A A A B B B D D E], rank=2, chunks=3 -> [D D E]. - -Requires all leading dimensions of tensors and lengths of lists to be the same over the batch -and the chunks must divide batch size. - - - - - - - -```python -nemo_rl.distributed.batched_data_dict.BatchedDataDict.from_batches( - batches: typing.Sequence[typing.Mapping[typing.Any, typing.Any]], - pad_value_dict: typing.Optional[dict[str, int | float]] = None -) -> typing_extensions.Self -``` - - - - - - -classmethod - -Given a list of batches, stack the tensors/lists within and put them in a single dictionary. - -Pad sequences to the max length in the batch using either 0(default) or a non-default value for a given key provided in pad_value_dict. - -**Parameters:** - - -A list of dictionaries, each containing a batch of data. - - - -An optional dict mapping keys to non-default(0) padding values. - - -**Returns:** `Self` - -A new BatchedDataDict containing the stacked data. - - - - - - - -```python -nemo_rl.distributed.batched_data_dict.BatchedDataDict.get_batch( - batch_idx, - batch_size = None -) -> nemo_rl.distributed.batched_data_dict.SlicedDataDict -``` - - - - - - -Slices a subbatch from the batch. - -**Parameters:** - - -the batch index to slice - - - -the size of the batch to be sliced - - -**Returns:** `SlicedDataDict` - -A new BatchedDataDict containing the sliced data - - - - - - - -```python -nemo_rl.distributed.batched_data_dict.BatchedDataDict.get_dict() -> dict[typing.Any, typing.Any] -``` - - - - - - -Get the underlying data dictionary. - - - - - - - -```python -nemo_rl.distributed.batched_data_dict.BatchedDataDict.get_microbatch_iterator_dynamic_shapes_len() -> int -``` - - - - - - -Get the length of the microbatch iterator for dynamic shapes. - - - - - - - -```python -nemo_rl.distributed.batched_data_dict.BatchedDataDict.get_microbatch_iterator_for_packable_sequences_len() -> tuple[int, int] -``` - - - - - - -Get the length of the microbatch iterator for sequence packing and the max packed seqlen. - - - - - - - -```python -nemo_rl.distributed.batched_data_dict.BatchedDataDict.get_multimodal_dict( - as_tensors: bool = False, - device: typing.Optional[torch.device] = None -) -> dict[str, typing.Any] -``` - - - - - - -Return a regular dict of tensors or packed multimodal data items. - - - - - - - -```python -nemo_rl.distributed.batched_data_dict.BatchedDataDict.make_microbatch_iterator( - microbatch_size: int -) -> typing.Iterator[nemo_rl.distributed.batched_data_dict.SlicedDataDict] -``` - - - - - - -Make an iterator over the batch that yields microbatches of size microbatch_size. - - - - - - - -```python -nemo_rl.distributed.batched_data_dict.BatchedDataDict.make_microbatch_iterator_for_packable_sequences() -> typing.Iterator[nemo_rl.distributed.batched_data_dict.SlicedDataDict] -``` - - - - - - -Make an iterator over the batch that yields microbatches that can be packed into a given max_tokens_per_microbatch. - - - - - - - -```python -nemo_rl.distributed.batched_data_dict.BatchedDataDict.make_microbatch_iterator_with_dynamic_shapes( - sequence_dim: int = 1 -) -> typing.Iterator[nemo_rl.distributed.batched_data_dict.SlicedDataDict] -``` - - - - - - -Makes an iterator that yields microbatchs of dynamic batch and sequence sizes. - -**Parameters:** - - -the index of the sequence dim for all tensors in the data dict - - -**Returns:** `Iterator[SlicedDataDict]` - -Iterator["SlicedDataDict"]: An iterator that yield dynamic microbatches - - - - - - - -```python -nemo_rl.distributed.batched_data_dict.BatchedDataDict.reorder_data( - reorded_indices: list[int] -) -``` - - - - - - -Reorders the data along the batch dimension by the given indices. - - - - - - - -```python -nemo_rl.distributed.batched_data_dict.BatchedDataDict.repeat_interleave( - num_repeats: int -) -> typing_extensions.Self -``` - - - - - - -Repeats the batch num_repeats times. - -For each element in the batch, repeat each value num_repeats times. -i.e: -{"key": torch.tensor([1, 2, 3]), "other_key": [1, 2, 3]} -> {"key": torch.tensor([1, 1, 2, 2, 3, 3]), "other_key": [1, 1, 2, 2, 3, 3]} - - - - - - - -```python -nemo_rl.distributed.batched_data_dict.BatchedDataDict.select_indices( - indices: typing.Union[list[int], torch.Tensor] -) -> typing_extensions.Self -``` - - - - - - -Selects specific rows from the batch based on indices. - -**Parameters:** - - -A list or tensor of integer indices to select. - - -**Returns:** `Self` - -A new BatchedDataDict containing only the selected rows. - - - - - - - -```python -nemo_rl.distributed.batched_data_dict.BatchedDataDict.shard_by_batch_size( - shards: int, - batch_size: typing.Optional[int] = None, - allow_uneven_shards: bool = False, - dynamic_batching_args: typing.Optional[nemo_rl.distributed.batched_data_dict.DynamicBatchingArgs] = None, - sequence_packing_args: typing.Optional[nemo_rl.distributed.batched_data_dict.SequencePackingArgs] = None -) -> list[nemo_rl.distributed.batched_data_dict.SlicedDataDict] | tuple[list[nemo_rl.distributed.batched_data_dict.SlicedDataDict], list[int]] -``` - - - - - - -Shards a batch by first dividing it into chunks of size batch_size, then further dividing each chunk into shards equal parts. Finally aggregates the sub-shards by their position. - -If batch_size is None, there will be no chunking beforehand (will default to the total batch size). - -For example, with data [A A B B C C D D], batch_size=2, shards=2: -- Element 0: [A B C D] (first elements from each chunk) -- Element 1: [A B C D] (second elements from each chunk) - -Examples: - - -```python ->>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict ->>> # Create a batch of two message logs with different lengths ->>> batch = BatchedDataDict({ -... 'problem_id': [0, 0, 1, 1, 2, 2, 3, 3], -... 'arbitrary_data': [1, 2, 3, 4, 5, 6, 7, 8] -... }) ->>> shards = batch.shard_by_batch_size(shards=2) ->>> shards -[{'problem_id': [0, 0, 1, 1], 'arbitrary_data': [1, 2, 3, 4]}, {'problem_id': [2, 2, 3, 3], 'arbitrary_data': [5, 6, 7, 8]}] ->>> # Now say that I'm training with a GBS of 4 and I want to take gradients steps on problems 0 and 1 before 2 and 3 (problems are repeated because GRPO) ->>> # In the current case, problems 0 and 2 will be trained on first since they're the first elements in each DP rank's batch. ->>> # So, we'll use the batch_size argument to split the batch into chunks of size 4 first. ->>> shards = batch.shard_by_batch_size(shards=2, batch_size=4) ->>> shards -[{'problem_id': [0, 0, 2, 2], 'arbitrary_data': [1, 2, 5, 6]}, {'problem_id': [1, 1, 3, 3], 'arbitrary_data': [3, 4, 7, 8]}] ->>> # Now, the ranks have 0 and 1 first so when they split their batches into microbatches (of size 2 since GBS=4 and DP=2), they'll train on 0 and 1 first. ->>> # Another way to use this function is with the 'allow_uneven_shards' flag, which allows the last shard to be smaller than the others when necessary. ->>> # This is necessary in multi-turn rollouts when some sequences terminate early, leaving unclean batch sizes. ->>> batch = BatchedDataDict({ -... 'problem_id': [0, 1, 2, 3, 4], -... 'arbitrary_data': [10, 11, 12, 13, 14] -... }) ->>> shards = batch.shard_by_batch_size(shards=2, allow_uneven_shards=True) ->>> shards -[{'problem_id': [0, 1, 2], 'arbitrary_data': [10, 11, 12]}, {'problem_id': [3, 4], 'arbitrary_data': [13, 14]}] ->>> # This is incompatible with the batch_size argument -``` - - - -**Parameters:** - - -The number of shards to divide each batch_size chunk into. - - - -The size of each initial chunk. - - - -Whether to allow shards to be unevenly sized. - If True, the last shard may be smaller than the others. - - - -If passed, preprocess batch for dynamic batching. This - dict requires four keys: - 1. max_tokens_per_microbatch (int): the maximum - number of tokens in a microbatch - 2. sequence_length_round (int): round each all - sequence lengths to this multiple - 3. input_key (str): the key in the batch - which holds input ids. - 4. input_lengths_key (str): the key in the batch - which holds the sequence length per value. - The sequence dim index is assumed to be 1. - Cannot be passed with sequence_packing_args. - - - -If passed, preprocess batch for sequence packing. This - dict requires five keys: - 1. max_tokens_per_microbatch (int): the maximum - number of tokens in a microbatch - 2. input_key (str): the key in the batch - which holds input ids. - 3. input_lengths_key (str): the key in the batch - which holds the sequence length per value. - The sequence dim index is assumed to be 1. - 4. algorithm (str): the algorithm to use for sequence packing. - 5. sequence_length_pad_multiple (int): the multiple to pad each sequence to. - With CP enabled, this should be set to a multiple of 2*CP and SP. - Cannot be passed with dynamic_batching_args. - - -**Returns:** `list[SlicedDataDict] | tuple[list[SlicedDataDict], list[int]]` - -list[BatchedDataDict]: A list of BatchedDataDicts, length equal to shards. - - - - - - - -```python -nemo_rl.distributed.batched_data_dict.BatchedDataDict.slice( - start: int, - end: int -) -> nemo_rl.distributed.batched_data_dict.SlicedDataDict -``` - - - - - - -Slices the batch from start to end. - -**Parameters:** - - -Starting index (inclusive) - - - -Ending index (exclusive) - - -**Returns:** `SlicedDataDict` - -A new BatchedDataDict containing the sliced data - - - - - - - -```python -nemo_rl.distributed.batched_data_dict.BatchedDataDict.to( - device: str | torch.device -) -> typing_extensions.Self -``` - - - - - - -Move tensors in batched dict to device. - - - - - - - -```python -nemo_rl.distributed.batched_data_dict.BatchedDataDict.truncate_tensors( - dim: int, - truncated_len: int -) -``` - - - - - - -Truncates tensors in this dict of a given dim to a given length. - - - - - - - - - -```python -class nemo_rl.distributed.batched_data_dict.DynamicBatchingArgs -``` - - - - - - -**Bases:** `typing.TypedDict` - -Configuration settings for dynamic batching. - -Pass this to 'shard_by_batch_size()' to preprocess batches for dynamic batching. - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.distributed.batched_data_dict.SequencePackingArgs -``` - - - - - - -**Bases:** `typing.TypedDict` - -Configuration settings for sequence packing. - -Pass this to 'shard_by_batch_size()' to preprocess batches for sequence packing. - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.distributed.batched_data_dict.SlicedDataDict() -``` - - - - - - -**Bases:** [BatchedDataDict](#nemo_rl-distributed-batched_data_dict-BatchedDataDict) - -A specialized subclass of BatchedDataDict that represents a slice or shard of a larger batch. - -This class provides a distinct type to differentiate between full batches and sliced/sharded batches, which can be helpful for -type checking. - - - - - - - - -```python -nemo_rl.distributed.batched_data_dict.DictT = TypeVar('DictT', bound=(Mapping[str, Any])) -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/collectives.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/collectives.mdx deleted file mode 100644 index 5d756ce..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/collectives.mdx +++ /dev/null @@ -1,108 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/distributed/collectives -title: nemo_rl.distributed.collectives ---- - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`gather_jagged_object_lists`](#nemo_rl-distributed-collectives-gather_jagged_object_lists) | Gathers jagged lists of picklable objects from all ranks and flattens them into a single list. | -| [`rebalance_nd_tensor`](#nemo_rl-distributed-collectives-rebalance_nd_tensor) | Takes tensors with variable leading sizes (at dim=0) and stacks them into a single tensor. | - -### Data - -[`T`](#nemo_rl-distributed-collectives-T) - -### API - - - - - -```python -nemo_rl.distributed.collectives.gather_jagged_object_lists( - local_objects: list[nemo_rl.distributed.collectives.T], - group: typing.Optional[torch.distributed.ProcessGroup] = None -) -> list[nemo_rl.distributed.collectives.T] -``` - - - - - - -Gathers jagged lists of picklable objects from all ranks and flattens them into a single list. - -This function handles the case where different GPUs have lists of different lengths -and combines them into a single list containing all objects from all ranks. - -For example, with 3 GPUs: - GPU0: [obj0, obj1] - GPU1: [obj2, obj3, obj4] - GPU2: [obj5] - -WARNING: synchronous - -**Parameters:** - - -List of objects to gather from current rank - - - -Optional process group - - -**Returns:** `list[T]` - -Flattened list of all objects from all ranks in order [rank0, rank1, ...] - - - - - - - - -```python -nemo_rl.distributed.collectives.rebalance_nd_tensor( - tensor: torch.Tensor, - group: typing.Optional[torch.distributed.ProcessGroup] = None -) -> torch.Tensor -``` - - - - - - -Takes tensors with variable leading sizes (at dim=0) and stacks them into a single tensor. - -This function handles the case where different GPUs have tensors with different batch sizes -and combines them into a single balanced tensor across all ranks. - -For example, with 3 GPUs: - GPU0: tensor of shape [3, D] - GPU1: tensor of shape [5, D] - GPU2: tensor of shape [2, D] - -NOTE: assumes all other (i.e., non-zero) dimensions are equal. - - - - - - - - -```python -nemo_rl.distributed.collectives.T = TypeVar('T') -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/model_utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/model_utils.mdx deleted file mode 100644 index 57539ab..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/model_utils.mdx +++ /dev/null @@ -1,851 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/distributed/model_utils -title: nemo_rl.distributed.model_utils ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`AllGatherCPTensor`](#nemo_rl-distributed-model_utils-AllGatherCPTensor) | - | -| [`ChunkedDistributedEntropy`](#nemo_rl-distributed-model_utils-ChunkedDistributedEntropy) | Compute H_all = sum_v p_v log p_v across TP with chunking over sequence. | -| [`ChunkedDistributedGatherLogprob`](#nemo_rl-distributed-model_utils-ChunkedDistributedGatherLogprob) | Compute distributed log-softmax once and gather logprobs at given global indices. | -| [`ChunkedDistributedLogprob`](#nemo_rl-distributed-model_utils-ChunkedDistributedLogprob) | Custom autograd function for computing log probabilities in a distributed setting. | -| [`DistributedLogprob`](#nemo_rl-distributed-model_utils-DistributedLogprob) | Custom autograd function for computing log probabilities in a distributed setting. | - -### Functions - -| Name | Description | -|------|-------------| -| [`_compute_distributed_log_softmax`](#nemo_rl-distributed-model_utils-_compute_distributed_log_softmax) | Compute a stable distributed log softmax across tensor parallel workers. | -| [`_get_tokens_on_this_cp_rank`](#nemo_rl-distributed-model_utils-_get_tokens_on_this_cp_rank) | Get tokens on this context parallelism rank. | -| [`allgather_cp_sharded_tensor`](#nemo_rl-distributed-model_utils-allgather_cp_sharded_tensor) | - | -| [`distributed_vocab_topk`](#nemo_rl-distributed-model_utils-distributed_vocab_topk) | Compute global top-k over TP-sharded vocabulary logits. | -| [`dtensor_from_parallel_logits_to_logprobs`](#nemo_rl-distributed-model_utils-dtensor_from_parallel_logits_to_logprobs) | Get log probabilities from TP+CP sharded vocab logits. | -| [`from_parallel_logits_to_logprobs`](#nemo_rl-distributed-model_utils-from_parallel_logits_to_logprobs) | Get log probabilities from TP+CP sharded vocab logits. | -| [`from_parallel_logits_to_logprobs_packed_sequences`](#nemo_rl-distributed-model_utils-from_parallel_logits_to_logprobs_packed_sequences) | Get log probabilities from TP sharded vocab logits for packed sequences. | -| [`gather_logits_at_global_indices`](#nemo_rl-distributed-model_utils-gather_logits_at_global_indices) | Gather student logits at given global token indices under TP+CP sharding. | -| [`get_logprobs_from_vocab_parallel_logits`](#nemo_rl-distributed-model_utils-get_logprobs_from_vocab_parallel_logits) | Computes log probabilities from vocabulary-parallel logits. | - -### API - - - - - -```python -class nemo_rl.distributed.model_utils.AllGatherCPTensor() -``` - - - - - - -**Bases:** `Function` - - - - - -```python -nemo_rl.distributed.model_utils.AllGatherCPTensor.backward( - ctx, - grad_output -) -``` - - - - - - - - - - - - -```python -nemo_rl.distributed.model_utils.AllGatherCPTensor.forward( - ctx, - tensor, - cp_group: torch.distributed.ProcessGroup, - seq_dim = 1 -) -``` - - - - - - - - - - - - - - -```python -class nemo_rl.distributed.model_utils.ChunkedDistributedEntropy() -``` - - - - - - -**Bases:** `Function` - -Compute H_all = sum_v p_v log p_v across TP with chunking over sequence. - -Forward returns [B, S] tensor of global entropy; backward propagates through logits. - - - - - - -```python -nemo_rl.distributed.model_utils.ChunkedDistributedEntropy.backward( - ctx: typing.Any, - grad_outputs: torch.Tensor = () -) -> tuple[torch.Tensor, None, None, None] -``` - - - - - - -staticmethod - - - - - - - -```python -nemo_rl.distributed.model_utils.ChunkedDistributedEntropy.forward( - ctx: typing.Any, - vocab_parallel_logits: torch.Tensor, - chunk_size: int, - tp_group: torch.distributed.ProcessGroup, - inference_only: bool = False -) -> torch.Tensor -``` - - - - - - -staticmethod - - - - - - - - - -```python -class nemo_rl.distributed.model_utils.ChunkedDistributedGatherLogprob() -``` - - - - - - -**Bases:** `Function` - -Compute distributed log-softmax once and gather logprobs at given global indices. - -Forward computes per-chunk distributed log-softmax across TP, gathers selected -log probabilities at the provided global indices (shape [B, S, K]), and returns -a tensor of shape [B, S, K]. - - - - - - -```python -nemo_rl.distributed.model_utils.ChunkedDistributedGatherLogprob.backward( - ctx: typing.Any, - grad_outputs: torch.Tensor = () -) -> tuple[torch.Tensor, None, None, None, None, None, None] -``` - - - - - - -staticmethod - - - - - - - -```python -nemo_rl.distributed.model_utils.ChunkedDistributedGatherLogprob.forward( - ctx: typing.Any, - vocab_parallel_logits: torch.Tensor, - global_indices: torch.Tensor, - vocab_start_index: int, - vocab_end_index: int, - chunk_size: int, - tp_group: torch.distributed.ProcessGroup, - inference_only: bool = False -) -> torch.Tensor -``` - - - - - - -staticmethod - - - - - - - - - -```python -class nemo_rl.distributed.model_utils.ChunkedDistributedLogprob() -``` - - - - - - -**Bases:** `Function` - -Custom autograd function for computing log probabilities in a distributed setting. - -The log probabilities computation is chunked in the sequence dimension -to mitigate GPU OOM (especially during backward pass). -In addition, logits casting from float16 or bfloat16 -> float32 is performed -inside the chunk loop to avoid materializing a whole float32 logits tensor. - -Adapted from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286 - - - - - - -```python -nemo_rl.distributed.model_utils.ChunkedDistributedLogprob.backward( - ctx: typing.Any, - grad_outputs: torch.Tensor = () -) -> tuple[torch.Tensor, None, None, None, None, None, None] -``` - - - - - - -staticmethod - - - - - - - -```python -nemo_rl.distributed.model_utils.ChunkedDistributedLogprob.forward( - ctx: typing.Any, - vocab_parallel_logits: torch.Tensor, - target: torch.Tensor, - vocab_start_index: int, - vocab_end_index: int, - chunk_size: int, - tp_group: torch.distributed.ProcessGroup, - inference_only: bool = False -) -> torch.Tensor -``` - - - - - - -staticmethod - - - - - - - - - -```python -class nemo_rl.distributed.model_utils.DistributedLogprob() -``` - - - - - - -**Bases:** `Function` - -Custom autograd function for computing log probabilities in a distributed setting. - -Taken from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286 - - - - - - -```python -nemo_rl.distributed.model_utils.DistributedLogprob.backward( - ctx: typing.Any, - grad_outputs: torch.Tensor = () -) -> tuple[torch.Tensor, None, None, None, None, None, None] -``` - - - - - - -staticmethod - - - - - - - -```python -nemo_rl.distributed.model_utils.DistributedLogprob.forward( - ctx: typing.Any, - vocab_parallel_logits: torch.Tensor, - target: torch.Tensor, - vocab_start_index: int, - vocab_end_index: int, - group: torch.distributed.ProcessGroup, - inference_only: bool = False -) -> torch.Tensor -``` - - - - - - -staticmethod - - - - - - - - - -```python -nemo_rl.distributed.model_utils._compute_distributed_log_softmax( - vocab_parallel_logits: torch.Tensor, - group: torch.distributed.ProcessGroup -) -> torch.Tensor -``` - - - - - - -Compute a stable distributed log softmax across tensor parallel workers. - -Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L265 - -**Parameters:** - - -Logits tensor with shape [batch_size, seq_length, vocab_size//TP] -where TP is the tensor parallel size. - - - -Process group for the all-reduce operations. - - -**Returns:** `torch.Tensor` - -torch.Tensor: Log softmax output with the same shape as input, but values represent -log probabilities normalized across the full vocabulary dimension. - - - - - - - - -```python -nemo_rl.distributed.model_utils._get_tokens_on_this_cp_rank( - input_ids: torch.Tensor, - cp_rank: int, - cp_size: int, - seq_dim: int = 1 -) -> torch.Tensor -``` - - - - - - -Get tokens on this context parallelism rank. - -Assumes that input_ids are already padded to a multiple of cp_size * 2 or cp_size == 1. - -**Parameters:** - - -Input token IDs [seq_length, ] - - - -Context parallelism rank - - - -Context parallelism size - - -**Returns:** `torch.Tensor` - -Tokens on this context parallelism rank [1, seq_length // cp_size] - - - - - - - - -```python -nemo_rl.distributed.model_utils.allgather_cp_sharded_tensor( - tensor, - cp_group, - seq_dim = 1 -) -``` - - - - - - - - - - - - - -```python -nemo_rl.distributed.model_utils.distributed_vocab_topk( - vocab_parallel_logits: torch.Tensor, - k: int, - tp_group: torch.distributed.ProcessGroup, - vocab_start_index: int, - vocab_end_index: int, - chunk_size: typing.Optional[int] = None -) -> tuple[torch.Tensor, torch.Tensor] -``` - - - - - - -Compute global top-k over TP-sharded vocabulary logits. - -**Parameters:** - - -[B, S, V_local] - - - -number of top tokens to select globally - - - -tensor-parallel process group - - - -global vocab start for this rank (inclusive) - - - -global vocab end for this rank (exclusive) - - - -optional chunk along sequence dim to bound memory - - -**Returns:** `torch.Tensor` - -[B, S, k] - - - - - - - - -```python -nemo_rl.distributed.model_utils.dtensor_from_parallel_logits_to_logprobs( - vocab_parallel_logits: torch.Tensor, - target: torch.distributed.tensor.DTensor | torch.Tensor, - vocab_start_index: int, - vocab_end_index: int, - tp_group: torch.distributed.ProcessGroup, - inference_only: bool = False, - seq_index: typing.Optional[torch.Tensor] = None, - chunk_size: typing.Optional[int] = None -) -> torch.Tensor -``` - - - - - - -Get log probabilities from TP+CP sharded vocab logits. - -**Parameters:** - - -Logits distributed across tensor parallel workers, -with shape [batch_size, seq_len, vocab_size/tp_size]. - - - -Target token indices with shape [batch_size, seq_len]. -NOTE: Must be the unmodified targets as this function will shift them internally. - - - -Starting vocabulary index for this worker's partition. - - - -Ending vocabulary index for this worker's partition. - - - -Process group for distributed communication. - - - -If True, tensors won't be saved for backward pass. Defaults to False. - - - -Sequence index tensor with shape [seq_len]. -It is only provided for cp sharded logits. It represents how tensor is sharded across the sequence dimension. - - - -Sequence dimension chunk size for computing the log probabilities. - - -**Returns:** `torch.Tensor` - -torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1]. -The sequence dimension is reduced by 1 due to the target shifting. - - - - - - - - -```python -nemo_rl.distributed.model_utils.from_parallel_logits_to_logprobs( - vocab_parallel_logits: torch.Tensor, - target: torch.Tensor, - vocab_start_index: int, - vocab_end_index: int, - tp_group: torch.distributed.ProcessGroup, - inference_only: bool = False, - cp_group: typing.Optional[torch.distributed.ProcessGroup] = None, - chunk_size: typing.Optional[int] = None -) -> torch.Tensor -``` - - - - - - -Get log probabilities from TP+CP sharded vocab logits. - -Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L354 - -**Parameters:** - - -Logits tensor with shape [batch_size, seq_len // CP, vocab_size // TP] -where TP is the tensor parallel size. - - - -Target token indices with shape [batch_size, seq_len]. -NOTE: Must be the unmodified targets as this function will shift them internally. - - - -Starting vocabulary index for this worker's partition. - - - -Ending vocabulary index for this worker's partition. - - - -Process group for distributed communication. - - - -If True, tensors won't be saved for backward pass. Defaults to False. - - - -Context parallelism process group. Defaults to None. - - - -Sequence dimension chunk size for computing the log probabilities. - - -**Returns:** `torch.Tensor` - -torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1]. -The sequence dimension is reduced by 1 due to the target shifting. - - - - - - - - -```python -nemo_rl.distributed.model_utils.from_parallel_logits_to_logprobs_packed_sequences( - vocab_parallel_logits: torch.Tensor, - target: torch.Tensor, - cu_seqlens_padded: torch.Tensor, - unpacked_seqlen: int, - vocab_start_index: int, - vocab_end_index: int, - group: torch.distributed.ProcessGroup, - inference_only: bool = False, - cp_group: typing.Optional[torch.distributed.ProcessGroup] = None, - chunk_size: typing.Optional[int] = None -) -> torch.Tensor -``` - - - - - - -Get log probabilities from TP sharded vocab logits for packed sequences. - -**Parameters:** - - -Packed logits tensor with shape [1, T // CP, vocab_size//TP] -where T is the total number of tokens across all packed sequences. - - - -Packed target token indices with shape [1, T]. -NOTE: Must be the unmodified targets as this function will shift them internally. - - - -Cumulative sequence lengths tensor with shape [batch_size + 1]. -cu_seqlens[i] indicates the start position of sequence i in the packed format. - - - -The length of the unpacked sequence tensor. - - - -Starting vocabulary index for this worker's partition. - - - -Ending vocabulary index for this worker's partition. - - - -Process group for distributed communication. - - - -If True, tensors won't be saved for backward pass. Defaults to False. - - - -Context parallelism process group. Defaults to None. - - - -Sequence dimension chunk size for computing the log probabilities. - - -**Returns:** `torch.Tensor` - -torch.Tensor: Unpacked log probabilities tensor with shape [batch_size, unpacked_seqlen-1]. -The total length is reduced by batch_size due to target shifting (one token per sequence). - - - - - - - - -```python -nemo_rl.distributed.model_utils.gather_logits_at_global_indices( - vocab_parallel_logits: torch.Tensor, - global_indices: torch.Tensor, - tp_group: typing.Optional[torch.distributed.ProcessGroup] = None, - cp_group: typing.Optional[torch.distributed.ProcessGroup] = None, - vocab_start_index: int, - vocab_end_index: int, - chunk_size: typing.Optional[int] = None -) -> torch.Tensor -``` - - - - - - -Gather student logits at given global token indices under TP+CP sharding. - -Differentiable w.r.t. vocab_parallel_logits. - -**Parameters:** - - -[B, S_cp, V_local] where S_cp is CP sharded sequence length - - - -[B, S_full, k] where S_full is full sequence length - - - -Optional tensor-parallel process group. If None, treats logits as full-vocab (no TP) and skips TP all-reduce. - - - -global vocab start for this rank (inclusive) - - - -global vocab end for this rank (exclusive) - - - -optional chunk along sequence dim to bound memory - - - -Optional context-parallel process group - - -**Returns:** `torch.Tensor` - -[B, S_full, k] - - - - - - - - -```python -nemo_rl.distributed.model_utils.get_logprobs_from_vocab_parallel_logits( - vocab_parallel_logits: torch.distributed.tensor.DTensor, - input_ids: torch.Tensor | torch.distributed.tensor.DTensor, - seq_index: typing.Optional[torch.Tensor] = None, - chunk_size: typing.Optional[int] = None -) -``` - - - - - - -Computes log probabilities from vocabulary-parallel logits. - -This function takes logits that are sharded across the vocabulary dimension (tensor parallel) -and computes the log probabilities for the given input IDs. - -**Parameters:** - - -Logits distributed across tensor parallel workers, -with shape [batch_size, seq_len, vocab_size/tp_size]. - - - -Input token IDs for which to compute log probabilities, -with shape [batch_size, seq_len]. - - - -Sequence index for the input IDs, -with shape [sequence_length]. - - - -Sequence dimension chunk size for computing log probabilities. - - -**Returns:** - -torch.Tensor: Log probabilities for the given input IDs. - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/named_sharding.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/named_sharding.mdx deleted file mode 100644 index 20d0bde..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/named_sharding.mdx +++ /dev/null @@ -1,236 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/distributed/named_sharding -title: nemo_rl.distributed.named_sharding ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`NamedSharding`](#nemo_rl-distributed-named_sharding-NamedSharding) | Represents an N-dimensional arrangement of ranks with named axes, facilitating data sharding, replication, and collection based on these axes. | - -### API - - - - - -```python -class nemo_rl.distributed.named_sharding.NamedSharding( - layout: typing.Sequence[typing.Any] | numpy.ndarray, - names: list[str] -) -``` - - - - - - -Represents an N-dimensional arrangement of ranks with named axes, facilitating data sharding, replication, and collection based on these axes. - - - - - - - - - - - - -Returns the underlying NumPy array representing the layout. - - - -Returns the names of the axes. - - - -Returns the number of dimensions. - - - -Returns the shape of the rank layout. - - - -Returns the total number of ranks. - - - - - -```python -nemo_rl.distributed.named_sharding.NamedSharding.__eq__( - other: object -) -> bool -``` - - - - - - - - - - - - -```python -nemo_rl.distributed.named_sharding.NamedSharding.__repr__() -> str -``` - - - - - - - - - - - - -```python -nemo_rl.distributed.named_sharding.NamedSharding.get_axis_index( - name: str -) -> int -``` - - - - - - -Gets the numerical index of a named axis. - - - - - - - -```python -nemo_rl.distributed.named_sharding.NamedSharding.get_axis_size( - name: str -) -> int -``` - - - - - - -Gets the size of a named axis. - - - - - - - -```python -nemo_rl.distributed.named_sharding.NamedSharding.get_ranks( - kwargs: int = {} -) -> typing.Union[nemo_rl.distributed.named_sharding.NamedSharding, int] -``` - - - - - - -Gets the ranks corresponding to specific indices along named axes. - -**Parameters:** - - -Keyword arguments where the key is the axis name (e.g., "dp", "tp") - and the value is the index along that axis. - - -**Returns:** `Union[NamedSharding, int]` - -A new NamedSharding instance representing the subset of ranks. - -**Raises:** - -- `ValueError`: If an invalid axis name is provided or if an index is out of bounds. - - - - - - - -```python -nemo_rl.distributed.named_sharding.NamedSharding.get_ranks_by_coord( - coords: int = {} -) -> list[int] -``` - - - - - - -Gets all ranks that match the specified coordinates for named axes. - -**Parameters:** - - -Keyword arguments where the key is the axis name (e.g., "dp", "tp") - and the value is the integer coordinate along that axis. - Axes not specified will match all coordinates along that axis. - - -**Returns:** `list[int]` - -A sorted list of unique rank integers that match the given coordinate criteria. - -**Raises:** - -- `ValueError`: If an invalid axis name is provided. - - - - - - - -```python -nemo_rl.distributed.named_sharding.NamedSharding.get_worker_coords( - worker_id: int -) -> dict[str, int] -``` - - - - - - -Gets the coordinates of a specific worker ID in the sharding layout. - -**Parameters:** - - -The integer ID of the worker. - - -**Returns:** `dict[str, int]` - -A dictionary mapping axis names to their integer coordinates for the given worker_id. - -**Raises:** - -- `ValueError`: If the worker_id is not found in the layout. - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/ray_actor_environment_registry.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/ray_actor_environment_registry.mdx deleted file mode 100644 index 5396879..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/ray_actor_environment_registry.mdx +++ /dev/null @@ -1,105 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/distributed/ray_actor_environment_registry -title: nemo_rl.distributed.ray_actor_environment_registry ---- - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`get_actor_python_env`](#nemo_rl-distributed-ray_actor_environment_registry-get_actor_python_env) | - | - -### Data - -[`ACTOR_ENVIRONMENT_REGISTRY`](#nemo_rl-distributed-ray_actor_environment_registry-ACTOR_ENVIRONMENT_REGISTRY) - -[`MCORE_EXECUTABLE`](#nemo_rl-distributed-ray_actor_environment_registry-MCORE_EXECUTABLE) - -[`SGLANG_EXECUTABLE`](#nemo_rl-distributed-ray_actor_environment_registry-SGLANG_EXECUTABLE) - -[`USE_SYSTEM_EXECUTABLE`](#nemo_rl-distributed-ray_actor_environment_registry-USE_SYSTEM_EXECUTABLE) - -[`VLLM_EXECUTABLE`](#nemo_rl-distributed-ray_actor_environment_registry-VLLM_EXECUTABLE) - -### API - - - - - -```python -nemo_rl.distributed.ray_actor_environment_registry.get_actor_python_env( - actor_class_fqn: str -) -> str -``` - - - - - - - - - - - - - -```python -nemo_rl.distributed.ray_actor_environment_registry.ACTOR_ENVIRONMENT_REGISTRY: dict[str, str] = {'nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker': VLLM_EXECUTA... -``` - - - - - - - - - -```python -nemo_rl.distributed.ray_actor_environment_registry.MCORE_EXECUTABLE = PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.MCORE -``` - - - - - - - - - -```python -nemo_rl.distributed.ray_actor_environment_registry.SGLANG_EXECUTABLE = PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.SGLANG -``` - - - - - - - - - -```python -nemo_rl.distributed.ray_actor_environment_registry.USE_SYSTEM_EXECUTABLE = os.environ.get('NEMO_RL_PY_EXECUTABLES_SYSTEM', '0') == '1' -``` - - - - - - - - - -```python -nemo_rl.distributed.ray_actor_environment_registry.VLLM_EXECUTABLE = PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.VLLM -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/stateless_process_group.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/stateless_process_group.mdx deleted file mode 100644 index ebd4125..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/stateless_process_group.mdx +++ /dev/null @@ -1,73 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/distributed/stateless_process_group -title: nemo_rl.distributed.stateless_process_group ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`StatelessProcessGroup`](#nemo_rl-distributed-stateless_process_group-StatelessProcessGroup) | - | - -### API - - - - - -```python -class nemo_rl.distributed.stateless_process_group.StatelessProcessGroup( - master_address: str, - port: int, - rank: int, - world_size: int -) -``` - - - - - - - - - - - - -```python -nemo_rl.distributed.stateless_process_group.StatelessProcessGroup.broadcast( - tensor: torch.Tensor, - src: int, - stream: typing.Optional[torch.cuda.Stream] = None -) -``` - - - - - - - - - - - - -```python -nemo_rl.distributed.stateless_process_group.StatelessProcessGroup.init_nccl_communicator( - device: int -) -``` - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/virtual_cluster.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/virtual_cluster.mdx deleted file mode 100644 index 108df41..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/virtual_cluster.mdx +++ /dev/null @@ -1,514 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/distributed/virtual_cluster -title: nemo_rl.distributed.virtual_cluster ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`ClusterConfig`](#nemo_rl-distributed-virtual_cluster-ClusterConfig) | - | -| [`GetGPUIDActor`](#nemo_rl-distributed-virtual_cluster-GetGPUIDActor) | Util actor class to return GPU id of the current worker. | -| [`PY_EXECUTABLES`](#nemo_rl-distributed-virtual_cluster-PY_EXECUTABLES) | - | -| [`RayVirtualCluster`](#nemo_rl-distributed-virtual_cluster-RayVirtualCluster) | Creates a virtual distributed cluster using Ray placement groups. | -| [`ResourceInsufficientError`](#nemo_rl-distributed-virtual_cluster-ResourceInsufficientError) | Exception raised when the cluster does not have enough resources to satisfy the requested configuration. | - -### Functions - -| Name | Description | -|------|-------------| -| [`_get_free_port_local`](#nemo_rl-distributed-virtual_cluster-_get_free_port_local) | - | -| [`_get_node_ip_and_free_port`](#nemo_rl-distributed-virtual_cluster-_get_node_ip_and_free_port) | - | -| [`_get_node_ip_local`](#nemo_rl-distributed-virtual_cluster-_get_node_ip_local) | - | -| [`init_ray`](#nemo_rl-distributed-virtual_cluster-init_ray) | Initialise Ray. | - -### Data - -[`dir_path`](#nemo_rl-distributed-virtual_cluster-dir_path) - -[`git_root`](#nemo_rl-distributed-virtual_cluster-git_root) - -[`logger`](#nemo_rl-distributed-virtual_cluster-logger) - -### API - - - - - -```python -class nemo_rl.distributed.virtual_cluster.ClusterConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - -```python -class nemo_rl.distributed.virtual_cluster.GetGPUIDActor() -``` - - - - - - -Util actor class to return GPU id of the current worker. - - - - - - -```python -nemo_rl.distributed.virtual_cluster.GetGPUIDActor.get_gpu_id() -``` - - - - - - - - - - - - - - -```python -class nemo_rl.distributed.virtual_cluster.PY_EXECUTABLES() -``` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.distributed.virtual_cluster.RayVirtualCluster( - bundle_ct_per_node_list: list[int], - use_gpus: bool = True, - max_colocated_worker_groups: int = 1, - num_gpus_per_node: int = 8, - name: str = '', - placement_group_strategy: str = 'SPREAD' -) -``` - - - - - - -Creates a virtual distributed cluster using Ray placement groups. - -This class simplifies distributed training setup by: -- Creating placement groups that represent logical compute nodes -- Allocating GPU and CPU resources for distributed workers -- Managing communication between distributed processes - -- Bundle: A resource allocation unit (ex: 4 GPUs on a single node) -- Worker: A process that performs computation (model training/inference) -- Node: A physical or virtual machine containing multiple bundles - - - - - - - - - - - - - - -```python -nemo_rl.distributed.virtual_cluster.RayVirtualCluster.__del__() -> None -``` - - - - - - -Shutsdown the virtual cluster when the object is deleted or is garbage collected. - -This is an extra safety net in case the user forgets to call shutdown and the pointer to -the cluster is lost due to leaving a function scope. It's always recommended that the -user calls shutdown(). - - - - - - - -```python -nemo_rl.distributed.virtual_cluster.RayVirtualCluster._create_placement_groups_internal( - strategy: str, - use_unified_pg: bool = False -) -> list[ray.util.placement_group.PlacementGroup] -``` - - - - - - -Internal method to create placement groups without retry logic. - - - - - - - -```python -nemo_rl.distributed.virtual_cluster.RayVirtualCluster._get_sorted_bundle_indices() -> typing.Optional[list[int]] -``` - - - - - - -Gets the sorted bundle indices for the placement groups. - - - - - - - -```python -nemo_rl.distributed.virtual_cluster.RayVirtualCluster._init_placement_groups( - strategy: str | None = None, - use_unified_pg: bool = False -) -> list[ray.util.placement_group.PlacementGroup] -``` - - - - - - -Creates placement groups based on whether cross-node model parallelism is needed. - -**Parameters:** - - -Ray placement group strategy (defaults to self.placement_group_strategy) - - - -If True, create a single unified placement group. - If False, create per-node placement groups. - - -**Returns:** `list[PlacementGroup]` - -List of placement groups - - - - - - - -```python -nemo_rl.distributed.virtual_cluster.RayVirtualCluster.get_available_address_and_port( - pg_idx: int, - bundle_idx: int -) -> tuple[str, int] -``` - - - - - - -Gets an available address and port for the given placement group index and bundle index. - -**Returns:** `tuple[str, int]` - -Tuple of (address, port) - - - - - - - -```python -nemo_rl.distributed.virtual_cluster.RayVirtualCluster.get_master_address_and_port() -> tuple[str, int] -``` - - - - - - -Gets the master address and port for the distributed training setup. - -**Returns:** `tuple[str, int]` - -Tuple of (address, port) - - - - - - - -```python -nemo_rl.distributed.virtual_cluster.RayVirtualCluster.get_placement_groups() -> list[ray.util.placement_group.PlacementGroup] -``` - - - - - - - - - - - - -```python -nemo_rl.distributed.virtual_cluster.RayVirtualCluster.node_count() -> int -``` - - - - - - - - - - - - -```python -nemo_rl.distributed.virtual_cluster.RayVirtualCluster.shutdown() -> bool -``` - - - - - - -Cleans up and releases all resources associated with this virtual cluster. - -This includes removing all placement groups and resetting the internal state. - -This method is idempotent and can be safely called multiple times. - - - - - - - -```python -nemo_rl.distributed.virtual_cluster.RayVirtualCluster.world_size() -> int -``` - - - - - - - - - - - - - - -```python -class nemo_rl.distributed.virtual_cluster.ResourceInsufficientError() -``` - - - - - - -Exception - -**Bases:** `Exception` - -Exception raised when the cluster does not have enough resources to satisfy the requested configuration. - - - - - - - - -```python -nemo_rl.distributed.virtual_cluster._get_free_port_local() -> int -``` - - - - - - - - - - - - - -```python -nemo_rl.distributed.virtual_cluster._get_node_ip_and_free_port() -> tuple[str, int] -``` - - - - - - - - - - - - - -```python -nemo_rl.distributed.virtual_cluster._get_node_ip_local() -> str -``` - - - - - - - - - - - - - -```python -nemo_rl.distributed.virtual_cluster.init_ray( - log_dir: typing.Optional[str] = None -) -> None -``` - - - - - - -Initialise Ray. - -Try to attach to an existing local cluster. -If that cluster uses the same CUDA_VISIBLE_DEVICES or Slurm managed tag we will reuse it. -Otherwise, we will detach and start a fresh local cluster. - -**Parameters:** - - -Optional directory to store Ray logs and temp files. - - - - - - - - - -```python -nemo_rl.distributed.virtual_cluster.dir_path = os.path.dirname(os.path.abspath(__file__)) -``` - - - - - - - - - -```python -nemo_rl.distributed.virtual_cluster.git_root = os.path.abspath(os.path.join(dir_path, '../..')) -``` - - - - - - - - - -```python -nemo_rl.distributed.virtual_cluster.logger = logging.getLogger(__name__) -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_group_utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_group_utils.mdx deleted file mode 100644 index 8519d0f..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_group_utils.mdx +++ /dev/null @@ -1,81 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/distributed/worker_group_utils -title: nemo_rl.distributed.worker_group_utils ---- - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`get_nsight_config_if_pattern_matches`](#nemo_rl-distributed-worker_group_utils-get_nsight_config_if_pattern_matches) | Check if worker name matches patterns in NRL_NSYS_WORKER_PATTERNS and return nsight config. | -| [`recursive_merge_options`](#nemo_rl-distributed-worker_group_utils-recursive_merge_options) | Recursively merge extra options into default options using OmegaConf. | - -### API - - - - - -```python -nemo_rl.distributed.worker_group_utils.get_nsight_config_if_pattern_matches( - worker_name: str -) -> dict[str, typing.Any] -``` - - - - - - -Check if worker name matches patterns in NRL_NSYS_WORKER_PATTERNS and return nsight config. - -**Parameters:** - - -Name of the worker to check against patterns - - -**Returns:** `dict[str, Any]` - -Dictionary containing {"nsight": config} if pattern matches, empty dict otherwise - - - - - - - - -```python -nemo_rl.distributed.worker_group_utils.recursive_merge_options( - default_options: dict[str, typing.Any], - extra_options: dict[str, typing.Any] -) -> dict[str, typing.Any] -``` - - - - - - -Recursively merge extra options into default options using OmegaConf. - -**Parameters:** - - -Default options dictionary (lower precedence) - - - -Extra options provided by the caller (higher precedence) - - -**Returns:** `dict[str, Any]` - -Merged options dictionary with extra_options taking precedence over default_options - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_groups.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_groups.mdx deleted file mode 100644 index e0205d5..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_groups.mdx +++ /dev/null @@ -1,603 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/distributed/worker_groups -title: nemo_rl.distributed.worker_groups ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`MultiWorkerFuture`](#nemo_rl-distributed-worker_groups-MultiWorkerFuture) | Container for Ray futures with associated worker information. | -| [`RayWorkerBuilder`](#nemo_rl-distributed-worker_groups-RayWorkerBuilder) | - | -| [`RayWorkerGroup`](#nemo_rl-distributed-worker_groups-RayWorkerGroup) | Manages a group of distributed Ray worker/actor processes that execute tasks in parallel. | - -### API - - - - - -```python -class nemo_rl.distributed.worker_groups.MultiWorkerFuture( - futures: list[ray.ObjectRef], - return_from_workers: typing.Optional[list[int]] = None, - called_workers: typing.Optional[list[int]] = None -) -``` - - - - - - -Dataclass - -Container for Ray futures with associated worker information. - - - - - - - - - - - - - - -```python -nemo_rl.distributed.worker_groups.MultiWorkerFuture.get_results( - worker_group: nemo_rl.distributed.worker_groups.RayWorkerGroup, - return_generators_as_proxies: bool = False -) -> list[typing.Any] -``` - - - - - - -Get results from the futures, optionally respecting tied workers. - -The method uses worker_group.worker_to_tied_group_index to identify which tied -worker group each worker belongs to, then selects only the first result from each group. - -**Parameters:** - - -The RayWorkerGroup that spawned the futures. The -mapping contained in worker_group.worker_to_tied_group_index -is required for the deduplication path. - - - -If True, and a future is an ObjectRefGenerator, - return the ObjectRefGenerator itself instead of consuming it. - - -**Returns:** `list[Any]` - -List of results - - - - - - - - - -```python -class nemo_rl.distributed.worker_groups.RayWorkerBuilder( - ray_actor_class_fqn: str, - args = (), - kwargs = {} -) -``` - - - - - - - - - - -```python -nemo_rl.distributed.worker_groups.RayWorkerBuilder.__call__( - placement_group: ray.util.placement_group.PlacementGroup, - placement_group_bundle_index: int, - num_gpus: float | int, - bundle_indices: typing.Optional[tuple[int, list[int]]] = None, - extra_options: typing.Any = {} -) -> ray.actor.ActorHandle -``` - - - - - - -Create a Ray worker with the specified configuration. - -Order of precedence for worker options configuration (from lowest to highest): -1. Options passed by the user to __call__ (extra_options) -2. Options required by the worker via configure_worker (may override user options with warning) -3. Options set by the RayWorkerBuilder.__call__ (specifically scheduling strategy) - -If the worker needs to override user-provided options, it should log a warning -to inform the user about the change and the reason for it. - -**Parameters:** - - -Ray placement group for resource allocation - - - -Index of the bundle in the placement group - - - -Number of GPUs to allocate to this worker (can be fractional) - - - -Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable) - - - -Additional options to pass to the Ray actor (may be overridden by actor's configure_worker(...) method) - - -**Returns:** `ray.actor.ActorHandle` - -A Ray actor reference to the created worker - - - - - - - -```python -nemo_rl.distributed.worker_groups.RayWorkerBuilder.create_worker_async( - placement_group: ray.util.placement_group.PlacementGroup, - placement_group_bundle_index: int, - num_gpus: float | int, - bundle_indices: typing.Optional[tuple[int, list[int]]] = None, - extra_options: typing.Any = {} -) -> tuple[ray.ObjectRef, ray.actor.ActorHandle] -``` - - - - - - -Create a Ray worker asynchronously, returning futures. - -This method returns immediately with futures that can be awaited later. - -**Parameters:** - - -Ray placement group for resource allocation - - - -Index of the bundle in the placement group - - - -Number of GPUs to allocate to this worker (can be fractional) - - - -Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable) - - - -Additional options to pass to the Ray actor - - -**Returns:** `tuple[ray.ObjectRef, ray.actor.ActorHandle]` - -Tuple of (worker_future, initializer_actor): -- worker_future: A Ray ObjectRef that will resolve to the worker actor -- initializer_actor: The initializer actor (needed to prevent GC) - - - - - - - - - -```python -class nemo_rl.distributed.worker_groups.RayWorkerGroup( - cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster, - remote_worker_builder: nemo_rl.distributed.worker_groups.RayWorkerBuilder, - workers_per_node: typing.Optional[typing.Union[int, list[int]]] = None, - name_prefix: str = '', - bundle_indices_list: typing.Optional[list[tuple[int, list[int]]]] = None, - sharding_annotations: typing.Optional[nemo_rl.distributed.named_sharding.NamedSharding] = None, - env_vars: dict[str, str] = {} -) -``` - - - - - - -Manages a group of distributed Ray worker/actor processes that execute tasks in parallel. - -This class creates and manages Ray actor instances that run on resources -allocated by a RayVirtualCluster. It handles: -- Worker creation and placement on specific GPU resources -- Setting up distributed training environment variables (rank, world size, etc.) -- Executing methods across all workers in parallel -- Collecting and aggregating results -- Support for tied worker groups where multiple workers process the same data - - - - - - - - - - - - -Number of data parallel shards. - - - - - - - - - - - -```python -nemo_rl.distributed.worker_groups.RayWorkerGroup._create_workers_from_bundle_indices( - remote_worker_builder: nemo_rl.distributed.worker_groups.RayWorkerBuilder, - bundle_indices_list: list[tuple[int, list[int]]], - env_vars: dict[str, str] = {} -) -> None -``` - - - - - - -Create workers based on explicit bundle indices for tied worker groups. - -**Parameters:** - - -Builder function for Ray actors - - - -List of (node_idx, local_bundle_indices) tuples, where each tuple - specifies a tied group with its node and local bundle indices. If the local_bundle_indices - spans multiple nodes, the node_idx will be the first node's index in the tied group. - - - - - - - - -```python -nemo_rl.distributed.worker_groups.RayWorkerGroup.get_all_worker_results( - future_bundle: nemo_rl.distributed.worker_groups.MultiWorkerFuture, - return_generators_as_proxies: bool = False -) -> list[typing.Any] -``` - - - - - - -Get results from all workers, optionally filtering to get just one result per tied worker group. - -**Parameters:** - - -MultiWorkerFuture containing futures and worker information. - - - -If True, and a future in the bundle is an ObjectRefGenerator, - return the ObjectRefGenerator itself instead of consuming it. - - -**Returns:** `list[Any]` - -List of results, deduplicated as specified in the future_bundle - - - - - - - -```python -nemo_rl.distributed.worker_groups.RayWorkerGroup.get_dp_leader_worker_idx( - dp_shard_idx: int -) -> int -``` - - - - - - -Returns the index of the primary worker for a given data parallel shard. - - - - - - - -```python -nemo_rl.distributed.worker_groups.RayWorkerGroup.run_all_workers_multiple_data( - method_name: str, - args = (), - run_rank_0_only_axes: list[str] | None = None, - common_kwargs: typing.Optional[dict[str, typing.Any]] = None, - kwargs = {} -) -> list[ray.ObjectRef] -``` - - - - - - -Run a method on all workers in parallel with different data. - -**Parameters:** - - -Name of the method to call on each worker - - - -List of arguments to pass to workers/groups - e.g. [[arg1_for_worker_1, arg1_for_worker_2], [arg2_for_worker_1, arg2_for_worker_2]] - - - -List of named axes for which only rank 0 should run the method. - - - -Keyword arguments to pass to all workers - - - -Keyword arguments to pass to workers/groups - e.g. {"key1": [value_for_worker_1, value_for_worker_2], "key2": [value_for_worker_1, value_for_worker_2]} - - -**Returns:** `list[ray.ObjectRef]` - -list[ray.ObjectRef]: A list of ray futures - - - - - - - -```python -nemo_rl.distributed.worker_groups.RayWorkerGroup.run_all_workers_sharded_data( - method_name: str, - args = (), - in_sharded_axes: list[str] | None = None, - replicate_on_axes: list[str] | None = None, - output_is_replicated: list[str] | None = None, - make_dummy_calls_to_free_axes: bool = False, - common_kwargs: typing.Optional[dict[str, typing.Any]] = None, - kwargs = {} -) -> nemo_rl.distributed.worker_groups.MultiWorkerFuture -``` - - - - - - -Run a method on all workers in parallel with sharded data. - -Axes in in_sharded_axes: Data is already split across these axes, so we just send the appropriate slice to each worker (along this axis) -Axes in replicate_on_axes: Data is replicated to all workers along these dimensions -Free axes (axes not in either list): Data is only sent to workers at index 0 of these axes - -**Parameters:** - - -Name of the method to call on each worker - - - -List of arguments to pass to workers/groups - e.g. [[arg1_for_worker_1, arg1_for_worker_2], [arg2_for_worker_1, arg2_for_worker_2]] - - - -List of axes that are sharded - - - -List of axes that are to be replicated - - - -List of axes along which the output is replicated (and we should just return the first result). - We also just return from rank 0 of free axes. - - - -Whether to make dummy calls (with None) to workers that - aren't rank 0 on 'free axes' (axes not in in_sharded_axes or replicate_on_axes). - - - -Keyword arguments to pass to all workers - - - -Keyword arguments to pass to workers/groups - e.g. {"key1": [value_for_worker_1, value_for_worker_2], "key2": [value_for_worker_1, value_for_worker_2]} - - -**Returns:** `MultiWorkerFuture` - -Object containing futures and their associated worker information - - - - - - - -```python -nemo_rl.distributed.worker_groups.RayWorkerGroup.run_all_workers_single_data( - method_name: str, - args = (), - run_rank_0_only_axes: list[str] | None = None, - kwargs = {} -) -> list[ray.ObjectRef] -``` - - - - - - -Run a method on all workers in parallel with the same data. - -**Parameters:** - - -Name of the method to call on each worker - - - -Arguments to pass to the method - - - -List of named axes for which only rank 0 should run the method. - - -**Returns:** `list[ray.ObjectRef]` - -list[ray.ObjectRef]: A list of ray futures - - - - - - - -```python -nemo_rl.distributed.worker_groups.RayWorkerGroup.run_single_worker_single_data( - method_name: str, - worker_idx: int, - args = (), - kwargs = {} -) -> ray.ObjectRef -``` - - - - - - -Run a method on a single, specific worker. - -**Parameters:** - - -Name of the method to call on the worker. - - - -The index of the worker to run the method on. - - - -Arguments to pass to the method. - - -**Returns:** `ray.ObjectRef` - -ray.ObjectRef: A Ray future for the result. - - - - - - - -```python -nemo_rl.distributed.worker_groups.RayWorkerGroup.shutdown( - cleanup_method: typing.Optional[str] = None, - timeout: typing.Optional[float] = 30.0, - force: bool = False -) -> bool -``` - - - - - - -Shutdown all workers in the worker group. - -**Parameters:** - - -Optional method name to call on each worker before termination. - If provided, this method will be called on each worker to allow - for graceful cleanup. - - - -Timeout in seconds for graceful shutdown. Only applicable if cleanup_method is provided. - If None, wait indefinitely for workers to complete their cleanup. - - - -If True, forcefully terminate workers with ray.kill() even if cleanup_method is provided. - If cleanup_method is None, workers are always forcefully terminated. - - -**Returns:** `bool` - -True if all workers were successfully shut down - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments.mdx deleted file mode 100644 index de15010..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments.mdx +++ /dev/null @@ -1,19 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/environments -title: nemo_rl.environments ---- - -## Submodules - -- **[`nemo_rl.environments.code_environment`](/nemo-rl/nemo_rl/environments/code_environment)** -- **[`nemo_rl.environments.code_jaccard_environment`](/nemo-rl/nemo_rl/environments/code_jaccard_environment)** -- **[`nemo_rl.environments.dapo_math_verifier`](/nemo-rl/nemo_rl/environments/dapo_math_verifier)** -- **[`nemo_rl.environments.interfaces`](/nemo-rl/nemo_rl/environments/interfaces)** -- **[`nemo_rl.environments.math_environment`](/nemo-rl/nemo_rl/environments/math_environment)** -- **[`nemo_rl.environments.metrics`](/nemo-rl/nemo_rl/environments/metrics)** -- **[`nemo_rl.environments.nemo_gym`](/nemo-rl/nemo_rl/environments/nemo_gym)** -- **[`nemo_rl.environments.reward_model_environment`](/nemo-rl/nemo_rl/environments/reward_model_environment)** -- **[`nemo_rl.environments.rewards`](/nemo-rl/nemo_rl/environments/rewards)** -- **[`nemo_rl.environments.utils`](/nemo-rl/nemo_rl/environments/utils)** -- **[`nemo_rl.environments.vlm_environment`](/nemo-rl/nemo_rl/environments/vlm_environment)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_environment.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_environment.mdx deleted file mode 100644 index 5c46941..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_environment.mdx +++ /dev/null @@ -1,290 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/environments/code_environment -title: nemo_rl.environments.code_environment ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`CodeEnvConfig`](#nemo_rl-environments-code_environment-CodeEnvConfig) | - | -| [`CodeEnvMetadata`](#nemo_rl-environments-code_environment-CodeEnvMetadata) | - | -| [`CodeEnvironment`](#nemo_rl-environments-code_environment-CodeEnvironment) | Code execution environment that maintains state between steps. | -| [`CodeExecutionWorker`](#nemo_rl-environments-code_environment-CodeExecutionWorker) | Helper class to process individual code execution steps. | - -### API - - - - - -```python -class nemo_rl.environments.code_environment.CodeEnvConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - -```python -class nemo_rl.environments.code_environment.CodeEnvMetadata -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - -```python -class nemo_rl.environments.code_environment.CodeEnvironment( - cfg: nemo_rl.environments.code_environment.CodeEnvConfig -) -``` - - - - - - -**Bases:** [EnvironmentInterface](/nemo-rl/nemo_rl/environments/interfaces#nemo_rl-environments-interfaces-EnvironmentInterface) - -Code execution environment that maintains state between steps. - - - - - - - - - - - - - - -```python -nemo_rl.environments.code_environment.CodeEnvironment.global_post_process_and_metrics( - batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict -) -> typing.Tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict, dict] -``` - - - - - - -Compute metrics for the batch. - - - - - - - -```python -nemo_rl.environments.code_environment.CodeEnvironment.shutdown() -``` - - - - - - - - - - - - -```python -nemo_rl.environments.code_environment.CodeEnvironment.step( - message_log_batch: typing.List[nemo_rl.data.interfaces.LLMMessageLogType], - metadata_batch: typing.List[nemo_rl.environments.code_environment.CodeEnvMetadata], - return_extracted_answer: bool = False -) -> nemo_rl.environments.interfaces.EnvironmentReturn -``` - - - - - - -Process a batch of code execution steps. - - - - - - - - - -```python -class nemo_rl.environments.code_environment.CodeExecutionWorker() -``` - - - - - - -Helper class to process individual code execution steps. - - - - - - - - -```python -nemo_rl.environments.code_environment.CodeExecutionWorker.chdir( - dir: str -) -``` - - - - - - -Change to temporary directory for file operations. - - - - - - - -```python -nemo_rl.environments.code_environment.CodeExecutionWorker.execute( - message_batch: str, - metadata_batch: typing.List[nemo_rl.environments.code_environment.CodeEnvMetadata] -) -> typing.Tuple[typing.List[typing.Dict[str, str]], typing.List[bool], typing.List[typing.Any]] -``` - - - - - - -Execute code in a sandboxed environment. - - - - - - - -```python -nemo_rl.environments.code_environment.CodeExecutionWorker.format_result( - result: typing.Any, - code: typing.Optional[str] = None, - lookahead: typing.Optional[str] = None -) -> str -``` - - - - - - - - - - - - -```python -nemo_rl.environments.code_environment.CodeExecutionWorker.safe_import( - name: str, - args = (), - kwargs = {} -) -``` - - - - - - -Safe version of import that blocks risky modules. - - - - - - - -```python -nemo_rl.environments.code_environment.CodeExecutionWorker.safe_open( - file: str, - args = (), - kwargs = {} -) -``` - - - - - - -Safe version of open() that only allows access to temporary directory. - - - - - - - -```python -nemo_rl.environments.code_environment.CodeExecutionWorker.sanitize( - obj: typing.Any -) -> typing.Any -``` - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_jaccard_environment.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_jaccard_environment.mdx deleted file mode 100644 index 0bdc82a..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_jaccard_environment.mdx +++ /dev/null @@ -1,268 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/environments/code_jaccard_environment -title: nemo_rl.environments.code_jaccard_environment ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`CodeJaccardEnvConfig`](#nemo_rl-environments-code_jaccard_environment-CodeJaccardEnvConfig) | - | -| [`CodeJaccardEnvironment`](#nemo_rl-environments-code_jaccard_environment-CodeJaccardEnvironment) | Environment for evaluating code responses using Jaccard similarity. | -| [`CodeJaccardEnvironmentMetadata`](#nemo_rl-environments-code_jaccard_environment-CodeJaccardEnvironmentMetadata) | - | -| [`CodeJaccardVerifyWorker`](#nemo_rl-environments-code_jaccard_environment-CodeJaccardVerifyWorker) | Worker for evaluating code responses using Jaccard-based similarity. | - -### API - - - - - -```python -class nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - -```python -class nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironment( - cfg: nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvConfig -) -``` - - - - - - -**Bases:** [EnvironmentInterface[CodeJaccardEnvironmentMetadata]](/nemo-rl/nemo_rl/environments/interfaces#nemo_rl-environments-interfaces-EnvironmentInterface) - -Environment for evaluating code responses using Jaccard similarity. - - - - - - - - - - - -```python -nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironment.global_post_process_and_metrics( - batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] -) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], dict[str, float | int]] -``` - - - - - - -Post-process batch and compute metrics for CodeJaccard. - - - - - - - -```python -nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironment.shutdown() -> None -``` - - - - - - -Shutdown all workers. - - - - - - - -```python -nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironment.step( - message_log_batch: list[nemo_rl.data.interfaces.LLMMessageLogType], - metadata: list[nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironmentMetadata], - return_extracted_answer: bool = False -) -> nemo_rl.environments.interfaces.EnvironmentReturn[nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironmentMetadata] -``` - - - - - - -Runs a step in the CodeJaccard environment. - -**Parameters:** - - -Batch of OpenAI-API-like message logs. - - - -Batch of CodeJaccardEnvironmentMetadata with ground truth. - - - -Whether to return extracted answers. - - -**Returns:** `EnvironmentReturn[CodeJaccardEnvironmentMetadata]` - -Tuple containing observations, metadata, stop strings, rewards, and done flags. - - - - - - - - - -```python -class nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironmentMetadata -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - -```python -class nemo_rl.environments.code_jaccard_environment.CodeJaccardVerifyWorker() -``` - - - - - - -Worker for evaluating code responses using Jaccard-based similarity. - - - - - - -```python -nemo_rl.environments.code_jaccard_environment.CodeJaccardVerifyWorker._calculate_preference_score( - response: str, - ground_truth: str -) -> float -``` - - - - - - -Calculate a Jaccard-based alignment score between response and ground truth. - -This is a simplified scoring function. In practice, you might want to use: -- Semantic similarity models -- BLEU/ROUGE scores -- Tokenize both texts into sets A and B (here we use whitespace tokenization). -- Compute intersection size |A ∩ B| and union size |A ∪ B|. -- J(A, B) = |A ∩ B| / |A ∪ B|, with guards for union=0 -> 0.0. -- Optionally combine with a length-ratio penalty to discourage degenerate very short/long matches. - -Complexity: -- Tokenization: O(n + m) -- Set ops: O(n + m) average (hash sets) - -**Parameters:** - - -The model's response - - -**Returns:** `float` - -Score between 0.0 and 1.0 - - - - - - - -```python -nemo_rl.environments.code_jaccard_environment.CodeJaccardVerifyWorker.verify( - pred_responses: list[str], - ground_truths: list[str], - return_extracted_answer: bool = False -) -> typing.Union[list[float], tuple[list[float], list[str | None]]] -``` - - - - - - -Verify code responses against ground-truth solutions using Jaccard-based similarity. - -We use a simple text similarity approach (Jaccard over tokenized words) -to evaluate how well the model's response aligns with the ground truth. - -**Parameters:** - - -list[str]. The predicted responses from the LLM. - - - -list[str]. The ground-truth solutions. - - - -bool. Whether to return extracted answers (here, the full response). - - -**Returns:** `Union[list[float], tuple[list[float], list[str | None]]]` - -Union[list[float], tuple[list[float], list[str | None]]]. - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/dapo_math_verifier.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/dapo_math_verifier.mdx deleted file mode 100644 index ecea315..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/dapo_math_verifier.mdx +++ /dev/null @@ -1,316 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/environments/dapo_math_verifier -title: nemo_rl.environments.dapo_math_verifier ---- - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`compute_score`](#nemo_rl-environments-dapo_math_verifier-compute_score) | Compute the reward score for a solution. | -| [`is_correct_minerva`](#nemo_rl-environments-dapo_math_verifier-is_correct_minerva) | Check if the solution is correct according to Minerva criteria. | -| [`is_correct_strict_box`](#nemo_rl-environments-dapo_math_verifier-is_correct_strict_box) | Check if the prediction is correct using strict boxed answer criteria. | -| [`last_boxed_only_string`](#nemo_rl-environments-dapo_math_verifier-last_boxed_only_string) | Extract the last LaTeX boxed expression from a string. | -| [`normalize_final_answer`](#nemo_rl-environments-dapo_math_verifier-normalize_final_answer) | Normalize a final answer to a quantitative reasoning question. | -| [`remove_boxed`](#nemo_rl-environments-dapo_math_verifier-remove_boxed) | Remove the LaTeX boxed command from a string. | -| [`verify`](#nemo_rl-environments-dapo_math_verifier-verify) | Verify if the solution is correct. | - -### Data - -[`REMOVED_EXPRESSIONS`](#nemo_rl-environments-dapo_math_verifier-REMOVED_EXPRESSIONS) - -[`SUBSTITUTIONS`](#nemo_rl-environments-dapo_math_verifier-SUBSTITUTIONS) - -### API - - - - - -```python -nemo_rl.environments.dapo_math_verifier.compute_score( - solution_str: str, - ground_truth: str, - strict_box_verify: bool = False, - pause_tokens_index: typing.Optional[list[int]] = None -) -> float -``` - - - - - - -Compute the reward score for a solution. - -**Parameters:** - - -The solution string - - - -The ground truth answer - - - -Whether to use strict box verification - - - -Indices of pause tokens - - -**Returns:** `float` - -Reward score (1.0 for correct, 0.0 for incorrect) - - - - - - - - -```python -nemo_rl.environments.dapo_math_verifier.is_correct_minerva( - solution_str: str, - gt: str, - gt_need_extract: bool = False, - answer_pattern: str = '(?i)Answer\\s*:\\s*([^\\n]+)' -) -> tuple[bool, str] -``` - - - - - - -Check if the solution is correct according to Minerva criteria. - -**Parameters:** - - -The solution string to check - - - -The ground truth answer - - - -Whether the ground truth needs extraction - - - -Regex pattern to extract the answer - - -**Returns:** `tuple[bool, str]` - -Tuple of (is_correct, normalized_prediction) - - - - - - - - -```python -nemo_rl.environments.dapo_math_verifier.is_correct_strict_box( - pred: str, - gt: str, - pause_tokens_index: typing.Optional[list[int]] = None -) -> tuple[int, typing.Optional[str]] -``` - - - - - - -Check if the prediction is correct using strict boxed answer criteria. - -**Parameters:** - - -The prediction string - - - -The ground truth answer - - - -Indices of pause tokens - - -**Returns:** `tuple[int, Optional[str]]` - -Tuple of (score, extracted_prediction) - - - - - - - - -```python -nemo_rl.environments.dapo_math_verifier.last_boxed_only_string( - string: str -) -> typing.Optional[str] -``` - - - - - - -Extract the last LaTeX boxed expression from a string. - -**Parameters:** - - -Input string containing LaTeX code - - -**Returns:** `Optional[str]` - -The last boxed expression or None if not found - - - - - - - - -```python -nemo_rl.environments.dapo_math_verifier.normalize_final_answer( - final_answer: str -) -> str -``` - - - - - - -Normalize a final answer to a quantitative reasoning question. - -**Parameters:** - - -The answer string to normalize - - -**Returns:** `str` - -Normalized answer string - - - - - - - - -```python -nemo_rl.environments.dapo_math_verifier.remove_boxed( - s: str -) -> str -``` - - - - - - -Remove the LaTeX boxed command from a string. - -**Parameters:** - - -String with format "\\boxed{content}" - - -**Returns:** `str` - -The content inside the boxed command - - - - - - - - -```python -nemo_rl.environments.dapo_math_verifier.verify( - solution_str: str, - answer: str, - strict_box_verify: bool = False, - pause_tokens_index: typing.Optional[list[int]] = None -) -> bool -``` - - - - - - -Verify if the solution is correct. - -**Parameters:** - - -The solution string to verify - - - -The ground truth answer - - - -Whether to use strict box verification - - - -Indices of pause tokens - - -**Returns:** `bool` - -True if the solution is correct, False otherwise - - - - - - - - -```python -nemo_rl.environments.dapo_math_verifier.REMOVED_EXPRESSIONS = ['square', 'ways', 'integers', 'dollars', 'mph', 'inches', 'hours', 'km', 'units... -``` - - - - - - - - - -```python -nemo_rl.environments.dapo_math_verifier.SUBSTITUTIONS = [('an ', ''), ('a ', ''), ('.$', '$'), ('\\$', ''), ('\\ ', ''), (' ', ''), ('mb... -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/interfaces.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/interfaces.mdx deleted file mode 100644 index 487b356..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/interfaces.mdx +++ /dev/null @@ -1,151 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/environments/interfaces -title: nemo_rl.environments.interfaces ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`EnvironmentInterface`](#nemo_rl-environments-interfaces-EnvironmentInterface) | - | -| [`EnvironmentReturn`](#nemo_rl-environments-interfaces-EnvironmentReturn) | Standard batched return type for environment step methods. | - -### Data - -[`MetadataT`](#nemo_rl-environments-interfaces-MetadataT) - -### API - - - - - -```python -class nemo_rl.environments.interfaces.EnvironmentInterface() -``` - - - - - - -Abstract - -**Bases:** `ABC`, `Generic[MetadataT]` - - - - - -```python -nemo_rl.environments.interfaces.EnvironmentInterface.global_post_process_and_metrics( - batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict -) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict, dict] -``` - - - - - - -abstract - -Post processing function after all rollouts are done for the batch and returns metrics. - - - - - - - -```python -nemo_rl.environments.interfaces.EnvironmentInterface.step( - message_log_batch: list[nemo_rl.data.interfaces.LLMMessageLogType], - metadata: list[nemo_rl.environments.interfaces.MetadataT] -) -> nemo_rl.environments.interfaces.EnvironmentReturn[nemo_rl.environments.interfaces.MetadataT] -``` - - - - - - -abstract - -Runs a step in the environment. Allows for asynchrony with remote servers, but it's not required (this function is a ray remote). - -metadata: batch of whatever the environment needs to keep track of. I.e. - math solutions, code unit tests, or agent states. Can be None if episode terminated. - -Returns: -- EnvironmentReturn NamedTuple containing observations, metadata, next_stop_strings, rewards, and terminateds flags. - - - - - - - - - -```python -class nemo_rl.environments.interfaces.EnvironmentReturn() -``` - - - - - - -**Bases:** `NamedTuple`, `Generic[MetadataT]` - -Standard batched return type for environment step methods. - -**All elements are batched.** -observations: New observation from the environment. - It's a (batched) 'message' type, which is a dict - with keys 'role' and 'content'. -metadata: Updated metadata from the environment. -next_stop_strings: The stop strings for the next turn. - If your environment is a game or similar, - you may want to return a list of stop strings - that are valid actions for the next turn or - similar. This field lets you control this per turn. -rewards: the rewards for this turn. -terminateds: whether the episode ended this turn. -answers: the answers for this turn. - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.environments.interfaces.MetadataT = TypeVar('MetadataT') -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/math_environment.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/math_environment.mdx deleted file mode 100644 index eccdd26..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/math_environment.mdx +++ /dev/null @@ -1,356 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/environments/math_environment -title: nemo_rl.environments.math_environment ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`EnglishMultichoiceVerifyWorker`](#nemo_rl-environments-math_environment-EnglishMultichoiceVerifyWorker) | - | -| [`HFVerifyWorker`](#nemo_rl-environments-math_environment-HFVerifyWorker) | - | -| [`MathEnvConfig`](#nemo_rl-environments-math_environment-MathEnvConfig) | - | -| [`MathEnvironment`](#nemo_rl-environments-math_environment-MathEnvironment) | - | -| [`MathEnvironmentMetadata`](#nemo_rl-environments-math_environment-MathEnvironmentMetadata) | - | -| [`MultilingualMultichoiceVerifyWorker`](#nemo_rl-environments-math_environment-MultilingualMultichoiceVerifyWorker) | - | - -### Functions - -| Name | Description | -|------|-------------| -| [`_mute_output`](#nemo_rl-environments-math_environment-_mute_output) | - | - -### API - - - - - -```python -class nemo_rl.environments.math_environment.EnglishMultichoiceVerifyWorker() -``` - - - - - - - - - - -```python -nemo_rl.environments.math_environment.EnglishMultichoiceVerifyWorker.verify( - pred_responses: list[str], - ground_truths: list[str], - return_extracted_answer: bool = False, - kwargs = {} -) -> typing.Union[list[float], tuple[list[float], list[str | None]]] -``` - - - - - - -Verify the correctness of the predicted responses against the ground truth. - -**Parameters:** - - -list[str]. The predicted responses from the LLM. - - - -list[str]. The ground truth responses. - - -**Returns:** `Union[list[float], tuple[list[float], list[str | None]]]` - -Union[list[float], tuple[list[float], list[str | None]]]. - - - - - - - - - -```python -class nemo_rl.environments.math_environment.HFVerifyWorker() -``` - - - - - - - - - - - - -```python -nemo_rl.environments.math_environment.HFVerifyWorker.verify( - pred_responses: list[str], - ground_truths: list[str], - return_extracted_answer: bool = False, - kwargs = {} -) -> typing.Union[list[float], tuple[list[float], list[str | None]]] -``` - - - - - - -Verify the correctness of the predicted responses against the ground truth. - -**Parameters:** - - -list[str]. The predicted responses from the LLM. - - - -list[str]. The ground truth responses. - - -**Returns:** `Union[list[float], tuple[list[float], list[str | None]]]` - -Union[list[float], tuple[list[float], list[str | None]]]. - - - - - - - - - -```python -class nemo_rl.environments.math_environment.MathEnvConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.environments.math_environment.MathEnvironment( - cfg: nemo_rl.environments.math_environment.MathEnvConfig -) -``` - - - - - - -**Bases:** [EnvironmentInterface[MathEnvironmentMetadata]](/nemo-rl/nemo_rl/environments/interfaces#nemo_rl-environments-interfaces-EnvironmentInterface) - - - - - - - - - - -```python -nemo_rl.environments.math_environment.MathEnvironment.global_post_process_and_metrics( - batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] -) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], dict[str, float | int]] -``` - - - - - - -Computes metrics for this environment given a global rollout batch. - -Every rank will run this function, so you're free to use distributed -calculations if you'd prefer for heavy metrics. - - - - - - - -```python -nemo_rl.environments.math_environment.MathEnvironment.shutdown() -> None -``` - - - - - - - - - - - - -```python -nemo_rl.environments.math_environment.MathEnvironment.step( - message_log_batch: list[nemo_rl.data.interfaces.LLMMessageLogType], - metadata: list[nemo_rl.environments.math_environment.MathEnvironmentMetadata], - return_extracted_answer: bool = False -) -> nemo_rl.environments.interfaces.EnvironmentReturn[nemo_rl.environments.math_environment.MathEnvironmentMetadata] -``` - - - - - - -Runs a step in the math environment. - -**Parameters:** - - -list[list[dict[str, str]]]. A batch of OpenAI-API-like message logs that represent interactions with the LLM. - - - -list[MathEnvironmentMetadata]. The grader will use the 'ground_truth' key to evaluate correctness. The extracted answer will be stored to caculate cons@k. - - -**Returns:** `EnvironmentReturn[MathEnvironmentMetadata]` - -A tuple containing: -- list[dict[str, str]]: Observations/responses batch -- list[dict]: Updated metadata -- list[str]: Next stop strings for the next turn -- Tensor: Rewards tensor -- Tensor: Done flags tensor - - - - - - - - - -```python -class nemo_rl.environments.math_environment.MathEnvironmentMetadata -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - -```python -class nemo_rl.environments.math_environment.MultilingualMultichoiceVerifyWorker() -``` - - - - - - - - - - -```python -nemo_rl.environments.math_environment.MultilingualMultichoiceVerifyWorker.verify( - pred_responses: list[str], - ground_truths: list[str], - return_extracted_answer: bool = False, - kwargs = {} -) -> typing.Union[list[float], tuple[list[float], list[str | None]]] -``` - - - - - - -Verify the correctness of the predicted responses against the ground truth. - -**Parameters:** - - -list[str]. The predicted responses from the LLM. - - - -list[str]. The ground truth responses. - - -**Returns:** `Union[list[float], tuple[list[float], list[str | None]]]` - -Union[list[float], tuple[list[float], list[str | None]]]. - - - - - - - - - -```python -nemo_rl.environments.math_environment._mute_output() -``` - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/metrics.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/metrics.mdx deleted file mode 100644 index 12386b2..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/metrics.mdx +++ /dev/null @@ -1,42 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/environments/metrics -title: nemo_rl.environments.metrics ---- - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`calculate_pass_rate_per_prompt`](#nemo_rl-environments-metrics-calculate_pass_rate_per_prompt) | Function to compute fraction of prompts that have at least one correct answer (reward > 0). | - -### API - - - - - -```python -nemo_rl.environments.metrics.calculate_pass_rate_per_prompt( - prompts: torch.Tensor, - is_correct: torch.Tensor -) -> float -``` - - - - - - -Function to compute fraction of prompts that have at least one correct answer (reward > 0). - -prompts: tensor (b, s) Tensor of prompts the model used. May be on any device -is_correct: tensor (b,) bool-valued label. May be on any device - -Returns: -pass rate: float - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/nemo_gym.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/nemo_gym.mdx deleted file mode 100644 index b8b85fe..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/nemo_gym.mdx +++ /dev/null @@ -1,210 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/environments/nemo_gym -title: nemo_rl.environments.nemo_gym ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`NemoGym`](#nemo_rl-environments-nemo_gym-NemoGym) | This environment class isn't really used for training. It's really meant as an integration wrapper around NeMo-Gym that hooks into the existing NeMo RL resource management via ray. So there is still one source of truth for resource management in NeMo RL. | -| [`NemoGymConfig`](#nemo_rl-environments-nemo_gym-NemoGymConfig) | - | - -### Functions - -| Name | Description | -|------|-------------| -| [`setup_nemo_gym_config`](#nemo_rl-environments-nemo_gym-setup_nemo_gym_config) | - | - -### API - - - - - -```python -class nemo_rl.environments.nemo_gym.NemoGym( - cfg: nemo_rl.environments.nemo_gym.NemoGymConfig -) -``` - - - - - - -**Bases:** [EnvironmentInterface](/nemo-rl/nemo_rl/environments/interfaces#nemo_rl-environments-interfaces-EnvironmentInterface) - -This environment class isn't really used for training. It's really meant as an integration wrapper around NeMo-Gym that hooks into the existing NeMo RL resource management via ray. So there is still one source of truth for resource management in NeMo RL. - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.environments.nemo_gym.NemoGym._postprocess_nemo_gym_to_nemo_rl_result( - nemo_gym_result: dict, - tokenizer: transformers.PreTrainedTokenizerBase -) -> dict -``` - - - - - - - - - - - - -```python -nemo_rl.environments.nemo_gym.NemoGym.global_post_process_and_metrics( - batch -) -``` - - - - - - - - - - - - -```python -nemo_rl.environments.nemo_gym.NemoGym.health_check() -> bool -``` - - - - - - - - - - - - -```python -nemo_rl.environments.nemo_gym.NemoGym.run_rollouts( - nemo_gym_examples: list[dict], - tokenizer: transformers.PreTrainedTokenizerBase, - timer_prefix: str -) -> list[dict] -``` - - - - - - -async - - - - - - - -```python -nemo_rl.environments.nemo_gym.NemoGym.shutdown() -> None -``` - - - - - - - - - - - - -```python -nemo_rl.environments.nemo_gym.NemoGym.step( - message_log_batch, - metadata -) -``` - - - - - - - - - - - - - - -```python -class nemo_rl.environments.nemo_gym.NemoGymConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.environments.nemo_gym.setup_nemo_gym_config( - config, - tokenizer -) -> None -``` - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/reward_model_environment.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/reward_model_environment.mdx deleted file mode 100644 index 67f23d3..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/reward_model_environment.mdx +++ /dev/null @@ -1,276 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/environments/reward_model_environment -title: nemo_rl.environments.reward_model_environment ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`RewardModelEnvironment`](#nemo_rl-environments-reward_model_environment-RewardModelEnvironment) | Environment that uses a reward model to score conversations. | -| [`RewardModelEnvironmentConfig`](#nemo_rl-environments-reward_model_environment-RewardModelEnvironmentConfig) | Configuration for RewardModelEnvironment. | - -### API - - - - - -```python -class nemo_rl.environments.reward_model_environment.RewardModelEnvironment( - config: typing.Dict[str, typing.Any] -) -``` - - - - - - -**Bases:** [EnvironmentInterface](/nemo-rl/nemo_rl/environments/interfaces#nemo_rl-environments-interfaces-EnvironmentInterface) - -Environment that uses a reward model to score conversations. - -This environment implements a reward model-based scoring system for reinforcement -learning tasks. It takes conversation logs as input and returns rewards based on -the quality of the assistant's responses as judged by a pre-trained reward model. - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.environments.reward_model_environment.RewardModelEnvironment.__del__() -``` - - - - - - -Destructor that ensures proper cleanup when the object is garbage collected. - -This is an extra safety net in case the user forgets to call shutdown() and -the pointer to the object is lost due to leaving a function scope. It's always -recommended that the user calls shutdown() explicitly for better resource -management. - - - - - - - -```python -nemo_rl.environments.reward_model_environment.RewardModelEnvironment.global_post_process_and_metrics( - batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict -) -> typing.Tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict, dict] -``` - - - - - - -Post processing function after all rollouts are done for the batch and returns metrics. - -This method computes aggregate statistics and metrics from the processed batch. -It provides insights into reward distribution and processing statistics. - -**Parameters:** - - -The batch data dictionary containing processed conversations and rewards. - - -**Returns:** `BatchedDataDict` - -Tuple of (processed_batch, metrics_dict) where: - - - - - - - -```python -nemo_rl.environments.reward_model_environment.RewardModelEnvironment.preprocess_data( - message_logs: typing.List[nemo_rl.data.interfaces.LLMMessageLogType] -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec] -``` - - - - - - -Preprocess the message logs for the reward model. - -This method tokenizes and formats conversation logs into the format expected -by the reward model. It handles: -- Tokenization of user and assistant messages -- Formatting with proper special tokens -- Batching and padding for efficient processing -- Sequence length validation and truncation - -**Parameters:** - - -List of conversation message logs, where each log contains - a list of messages with 'role' and 'content' fields. - - -**Returns:** `BatchedDataDict[GenerationDatumSpec]` - -BatchedDataDict containing tokenized and formatted data ready for - - - - - - - -```python -nemo_rl.environments.reward_model_environment.RewardModelEnvironment.shutdown() -``` - - - - - - -Shutdown the reward model worker and virtual cluster. - -This method properly cleans up resources by shutting down the reward model -policy and virtual cluster. It should be called when the environment is -no longer needed to prevent resource leaks. - - - - - - - -```python -nemo_rl.environments.reward_model_environment.RewardModelEnvironment.step( - message_logs: typing.List[nemo_rl.data.interfaces.LLMMessageLogType], - env_infos: typing.List[typing.Dict[str, typing.Any]] -) -> nemo_rl.environments.interfaces.EnvironmentReturn -``` - - - - - - -Calculate rewards for the given message logs using the reward model. - -This method processes conversation logs through the reward model to compute -quality scores for each conversation. The rewards are based on the reward -model's assessment of how well the assistant's responses align with human -preferences. - -**Parameters:** - - -List of conversation message logs to be scored. - Each log should contain alternating user and assistant messages. - - - -List of environment info dictionaries (currently unused - but required by the interface). - - -**Returns:** `EnvironmentReturn` - -EnvironmentReturn containing: - - - - - - - - - -```python -class nemo_rl.environments.reward_model_environment.RewardModelEnvironmentConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - -Configuration for RewardModelEnvironment. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/rewards.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/rewards.mdx deleted file mode 100644 index aaa41ba..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/rewards.mdx +++ /dev/null @@ -1,180 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/environments/rewards -title: nemo_rl.environments.rewards ---- - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`bbox_giou_reward`](#nemo_rl-environments-rewards-bbox_giou_reward) | Given [x1, y1, x2, y2] normalized bounding box coordinates within the <{answer_tag}> tags, compute the GIoU between the ground truth and the response. | -| [`combine_reward_functions`](#nemo_rl-environments-rewards-combine_reward_functions) | Returns a callable function that takes (ground_truth, response) and collects multiple reward functions in sequence. | -| [`exact_answer_alphanumeric_reward`](#nemo_rl-environments-rewards-exact_answer_alphanumeric_reward) | Reward the agent when the answer within the <{answer_tag}> tags is the same as the ground truth (case-insensitive). | -| [`format_reward`](#nemo_rl-environments-rewards-format_reward) | Reward the agent when the response follows the format: (.*) <think> (.*) </think> <answer> (.*) </answer>. | -| [`math_expression_reward`](#nemo_rl-environments-rewards-math_expression_reward) | Reward the agent when the answer within the <{tag}> tags is the same expression as the ground truth. | - -### Data - -[`boxed`](#nemo_rl-environments-rewards-boxed) - -[`math_verify_func`](#nemo_rl-environments-rewards-math_verify_func) - -### API - - - - - -```python -nemo_rl.environments.rewards.bbox_giou_reward( - ground_truth: str, - response: str, - giou_penalty_thres: float = 10.0, - answer_tag: str = 'answer' -) -> tuple[float, bool] -``` - - - - - - -Given [x1, y1, x2, y2] normalized bounding box coordinates within the <{answer_tag}> tags, compute the GIoU between the ground truth and the response. - -The `answer_tag` is customizable and must be specified as part of the user COT prompt text file. - - - - - - - - -```python -nemo_rl.environments.rewards.combine_reward_functions( - reward_functions: list[tuple[typing.Callable[[str, str], tuple[float, bool]], float]] -) -> typing.Callable[[str, str], tuple[float, bool]] -``` - - - - - - -Returns a callable function that takes (ground_truth, response) and collects multiple reward functions in sequence. - -The reward functions are weighted by the second element of the tuple. -This information can be provided in the YAML config file and resolved in the VLMEnvironment class. - -**Parameters:** - - -list[tuple[Callable[[str, str], tuple[float, bool]], float]]. A list of reward functions and their weights. - - -**Returns:** `Callable[[str, str], tuple[float, bool]]` - -Callable[[str, str], tuple[float, bool]]: A callable function that takes (ground_truth, response) and collects multiple reward functions in sequence - - - - - - - - -```python -nemo_rl.environments.rewards.exact_answer_alphanumeric_reward( - ground_truth: str, - response: str, - answer_tag: str = 'answer' -) -> tuple[float, bool] -``` - - - - - - -Reward the agent when the answer within the <{answer_tag}> tags is the same as the ground truth (case-insensitive). - -The `answer_tag` is customizable and must be specified as part of the user COT prompt text file. - - - - - - - - -```python -nemo_rl.environments.rewards.format_reward( - ground_truth: str, - response: str, - think_tag: str = 'think', - answer_tag: str = 'answer' -) -> tuple[float, typing.Optional[bool]] -``` - - - - - - -Reward the agent when the response follows the format: (.*) <think> (.*) </think> <answer> (.*) </answer>. - -The `think_tag` and `answer_tag` are customizable and must be specified as part of the user COT prompt text file. - - - - - - - - -```python -nemo_rl.environments.rewards.math_expression_reward( - ground_truth: str, - response: str, - tag: str = 'answer' -) -> tuple[float, bool] -``` - - - - - - -Reward the agent when the answer within the <{tag}> tags is the same expression as the ground truth. - -The `tag` is customizable and must be specified as part of the user COT prompt text file. - - - - - - - - -```python -nemo_rl.environments.rewards.boxed = lambda x: '\\boxed{' + x + '}' if not x.startswith('\\boxed{') else x -``` - - - - - - - - - -```python -nemo_rl.environments.rewards.math_verify_func = math_metric(gold_extraction_target=(LatexExtractionConfig(),), pred_extraction_t... -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/utils.mdx deleted file mode 100644 index 67e63b2..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/utils.mdx +++ /dev/null @@ -1,152 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/environments/utils -title: nemo_rl.environments.utils ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`EnvRegistryEntry`](#nemo_rl-environments-utils-EnvRegistryEntry) | - | - -### Functions - -| Name | Description | -|------|-------------| -| [`chunk_list_to_workers`](#nemo_rl-environments-utils-chunk_list_to_workers) | Chunk a list into a list of lists, where each sublist is assigned to a worker. Keeps ordering of elements. | -| [`create_env`](#nemo_rl-environments-utils-create_env) | - | -| [`register_env`](#nemo_rl-environments-utils-register_env) | - | - -### Data - -[`ENV_REGISTRY`](#nemo_rl-environments-utils-ENV_REGISTRY) - -### API - - - - - -```python -class nemo_rl.environments.utils.EnvRegistryEntry -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - -```python -nemo_rl.environments.utils.chunk_list_to_workers( - to_chunk: list[typing.Any], - num_workers: int -) -> list[list[typing.Any]] -``` - - - - - - -Chunk a list into a list of lists, where each sublist is assigned to a worker. Keeps ordering of elements. - -If the list is not divisible by the number of workers, the last worker may have fewer elements. -If there are more workers than elements, the first len(list) workers will have a single element each, -and the remaining workers will have empty lists. - -Examples: - - -```python ->>> from nemo_rl.environments.utils import chunk_list_to_workers ->>> chunk_list_to_workers([1, 2, 3, 4, 5], 3) -[[1, 2], [3, 4], [5]] -``` - - - -**Parameters:** - - -The list to be chunked. - - - -The number of workers to distribute the list to. - - -**Returns:** `list[list[Any]]` - -A list of lists, where each sublist contains elements assigned to a worker. - - - - - - - - -```python -nemo_rl.environments.utils.create_env( - env_name: str, - env_config: dict -) -> nemo_rl.environments.interfaces.EnvironmentInterface -``` - - - - - - - - - - - - - -```python -nemo_rl.environments.utils.register_env( - env_name: str, - actor_class_fqn: str -) -> None -``` - - - - - - - - - - - - - -```python -nemo_rl.environments.utils.ENV_REGISTRY: Dict[str, EnvRegistryEntry] = {'math_default': {'actor_class_fqn': 'nemo_rl.environments.math_environment.Math... -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/vlm_environment.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/vlm_environment.mdx deleted file mode 100644 index 40b7a2a..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/environments/vlm_environment.mdx +++ /dev/null @@ -1,243 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/environments/vlm_environment -title: nemo_rl.environments.vlm_environment ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`VLMEnvConfig`](#nemo_rl-environments-vlm_environment-VLMEnvConfig) | - | -| [`VLMEnvironment`](#nemo_rl-environments-vlm_environment-VLMEnvironment) | - | -| [`VLMEnvironmentMetadata`](#nemo_rl-environments-vlm_environment-VLMEnvironmentMetadata) | - | -| [`VLMVerifyWorker`](#nemo_rl-environments-vlm_environment-VLMVerifyWorker) | - | - -### Functions - -| Name | Description | -|------|-------------| -| [`_mute_output`](#nemo_rl-environments-vlm_environment-_mute_output) | - | - -### API - - - - - -```python -class nemo_rl.environments.vlm_environment.VLMEnvConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.environments.vlm_environment.VLMEnvironment( - cfg: nemo_rl.environments.vlm_environment.VLMEnvConfig -) -``` - - - - - - -**Bases:** [EnvironmentInterface](/nemo-rl/nemo_rl/environments/interfaces#nemo_rl-environments-interfaces-EnvironmentInterface) - - - - - - - - - - -```python -nemo_rl.environments.vlm_environment.VLMEnvironment.global_post_process_and_metrics( - batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] -) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], dict[str, float | int]] -``` - - - - - - -Computes metrics for this environment given a global rollout batch. - -Every rank will run this function, so you're free to use distributed -calculations if you'd prefer for heavy metrics. - - - - - - - -```python -nemo_rl.environments.vlm_environment.VLMEnvironment.shutdown() -> None -``` - - - - - - - - - - - - -```python -nemo_rl.environments.vlm_environment.VLMEnvironment.step( - message_log_batch: list[list[dict[str, str]]], - metadata: list[nemo_rl.environments.vlm_environment.VLMEnvironmentMetadata] -) -> nemo_rl.environments.interfaces.EnvironmentReturn -``` - - - - - - -Runs a step in the vlm environment. - -**Parameters:** - - -list[list[dict[str, str]]]. A batch of OpenAI-API-like message logs that represent interactions with the VLM. - - - -list[VLMEnvironmentMetadata]. The grader will use the 'ground_truth' key to evaluate correctness. - - -**Returns:** `EnvironmentReturn` - -A tuple containing: -- list[dict[str, str]]: Observations/responses batch -- list[dict]: Updated metadata -- list[str]: Next stop strings for the next turn -- Tensor: Rewards tensor -- Tensor: Done flags tensor - - - - - - - - - -```python -class nemo_rl.environments.vlm_environment.VLMEnvironmentMetadata -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - -```python -class nemo_rl.environments.vlm_environment.VLMVerifyWorker( - cfg: nemo_rl.environments.vlm_environment.VLMEnvConfig -) -``` - - - - - - - - - - - - -```python -nemo_rl.environments.vlm_environment.VLMVerifyWorker.verify( - pred_responses: list[str], - ground_truths: list[str] -) -> list[float] -``` - - - - - - -Verify the correctness of the predicted responses against the ground truth. - -**Parameters:** - - -list[str]. The predicted responses from the LLM. - - - -list[str]. The ground truth responses. - - -**Returns:** `list[float]` - -list[float]. The rewards for each predicted response. - - - - - - - - - -```python -nemo_rl.environments.vlm_environment._mute_output() -``` - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals.mdx deleted file mode 100644 index e7d333c..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals.mdx +++ /dev/null @@ -1,10 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/evals -title: nemo_rl.evals ---- - -## Submodules - -- **[`nemo_rl.evals.answer_parsing`](/nemo-rl/nemo_rl/evals/answer_parsing)** -- **[`nemo_rl.evals.eval`](/nemo-rl/nemo_rl/evals/eval)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals/answer_parsing.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals/answer_parsing.mdx deleted file mode 100644 index 9126f64..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals/answer_parsing.mdx +++ /dev/null @@ -1,86 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/evals/answer_parsing -title: nemo_rl.evals.answer_parsing ---- - -Contains utility functions for answer parsing. - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`normalize_extracted_answer`](#nemo_rl-evals-answer_parsing-normalize_extracted_answer) | - | -| [`normalize_response`](#nemo_rl-evals-answer_parsing-normalize_response) | Normalize the response by removing markdown and LaTeX formatting that may prevent a match. | - -### Data - -[`MULTILINGUAL_ANSWER_PATTERN_TEMPLATE`](#nemo_rl-evals-answer_parsing-MULTILINGUAL_ANSWER_PATTERN_TEMPLATE) - -[`MULTILINGUAL_ANSWER_REGEXES`](#nemo_rl-evals-answer_parsing-MULTILINGUAL_ANSWER_REGEXES) - -### API - - - - - -```python -nemo_rl.evals.answer_parsing.normalize_extracted_answer( - extracted_answer: str -) -> str -``` - - - - - - - - - - - - - -```python -nemo_rl.evals.answer_parsing.normalize_response( - response: str -) -> str -``` - - - - - - -Normalize the response by removing markdown and LaTeX formatting that may prevent a match. - - - - - - - - -```python -nemo_rl.evals.answer_parsing.MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = '(?i){}[ \t]*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])' -``` - - - - - - - - - -```python -nemo_rl.evals.answer_parsing.MULTILINGUAL_ANSWER_REGEXES = ['Answer\\s*:', 'Answer\\s*:\u200b\u200b\u200b\u200b\u200b\u200b', 'উত্তর\\s*:',... -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals/eval.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals/eval.mdx deleted file mode 100644 index f2712a9..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/evals/eval.mdx +++ /dev/null @@ -1,399 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/evals/eval -title: nemo_rl.evals.eval ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`EvalConfig`](#nemo_rl-evals-eval-EvalConfig) | - | -| [`MasterConfig`](#nemo_rl-evals-eval-MasterConfig) | - | -| [`_PassThroughMathConfig`](#nemo_rl-evals-eval-_PassThroughMathConfig) | - | - -### Functions - -| Name | Description | -|------|-------------| -| [`_generate_texts`](#nemo_rl-evals-eval-_generate_texts) | Generate texts using either sync or async method. | -| [`_print_results`](#nemo_rl-evals-eval-_print_results) | Print evaluation results. | -| [`_run_env_eval_impl`](#nemo_rl-evals-eval-_run_env_eval_impl) | Unified implementation for both sync and async evaluation. | -| [`_save_evaluation_data_to_json`](#nemo_rl-evals-eval-_save_evaluation_data_to_json) | Save evaluation data to a JSON file. | -| [`eval_cons_k`](#nemo_rl-evals-eval-eval_cons_k) | Evaluate cons@k score using an unbiased estimator. | -| [`eval_pass_k`](#nemo_rl-evals-eval-eval_pass_k) | Evaluate pass@k score using an unbiased estimator. | -| [`run_env_eval`](#nemo_rl-evals-eval-run_env_eval) | Main entry point for running evaluation using environment. | -| [`setup`](#nemo_rl-evals-eval-setup) | Set up components for model evaluation. | - -### API - - - - - -```python -class nemo_rl.evals.eval.EvalConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.evals.eval.MasterConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.evals.eval._PassThroughMathConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - -```python -nemo_rl.evals.eval._generate_texts( - vllm_generation, - inputs, - use_async -) -``` - - - - - - -async - -Generate texts using either sync or async method. - - - - - - - - -```python -nemo_rl.evals.eval._print_results( - master_config, - generation_config, - score, - dataset_size, - metric, - k_value, - num_tests_per_prompt -) -``` - - - - - - -Print evaluation results. - - - - - - - - -```python -nemo_rl.evals.eval._run_env_eval_impl( - vllm_generation, - dataloader, - env, - master_config, - use_async = False -) -``` - - - - - - -async - -Unified implementation for both sync and async evaluation. - - - - - - - - -```python -nemo_rl.evals.eval._save_evaluation_data_to_json( - evaluation_data, - master_config, - save_path -) -``` - - - - - - -Save evaluation data to a JSON file. - -**Parameters:** - - -List of evaluation samples - - - -Configuration dictionary - - - -Path to save evaluation results. Set to null to disable saving. - Example: "results/eval_output" or "/path/to/evaluation_results" - - - - - - - - - -```python -nemo_rl.evals.eval.eval_cons_k( - rewards: torch.Tensor, - num_tests_per_prompt: int, - k: int, - extracted_answers: list[str | None] -) -> float -``` - - - - - - -Evaluate cons@k score using an unbiased estimator. - -**Parameters:** - - -Tensor of shape (batch_size * num_tests_per_prompt) - - - -int - - - -int - - - -list[str| None] - - -**Returns:** `float` - -float - - - - - - - - -```python -nemo_rl.evals.eval.eval_pass_k( - rewards: torch.Tensor, - num_tests_per_prompt: int, - k: int -) -> float -``` - - - - - - -Evaluate pass@k score using an unbiased estimator. - -Reference: https://github.com/huggingface/evaluate/blob/32546aafec25cdc2a5d7dd9f941fc5be56ba122f/metrics/code_eval/code_eval.py#L198-L213 -Args: - rewards: Tensor of shape (batch_size * num_tests_per_prompt) - k: int (pass@k value) - -**Returns:** `float` - -float - - - - - - - - -```python -nemo_rl.evals.eval.run_env_eval( - vllm_generation, - dataloader, - env, - master_config -) -``` - - - - - - -Main entry point for running evaluation using environment. - -Generates model responses and evaluates them by env. - -**Parameters:** - - -Model for generating responses. - - - -Data loader with evaluation samples. - - - -Environment that scores responses. - - - -Configuration settings. - - - - - - - - - -```python -nemo_rl.evals.eval.setup( - master_config: nemo_rl.evals.eval.MasterConfig, - tokenizer: transformers.AutoTokenizer, - dataset: nemo_rl.data.datasets.AllTaskProcessedDataset -) -> tuple[nemo_rl.models.generation.vllm.VllmGeneration, torch.utils.data.DataLoader, nemo_rl.evals.eval.MasterConfig] -``` - - - - - - -Set up components for model evaluation. - -Initializes the VLLM model and data loader. - -**Parameters:** - - -Configuration settings. - - - -Dataset to evaluate on. - - -**Returns:** `tuple[VllmGeneration, DataLoader, MasterConfig]` - -VLLM model, data loader, and config. - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/experience.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/experience.mdx deleted file mode 100644 index eacff25..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/experience.mdx +++ /dev/null @@ -1,9 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/experience -title: nemo_rl.experience ---- - -## Submodules - -- **[`nemo_rl.experience.rollouts`](/nemo-rl/nemo_rl/experience/rollouts)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/experience/rollouts.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/experience/rollouts.mdx deleted file mode 100644 index f00431e..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/experience/rollouts.mdx +++ /dev/null @@ -1,469 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/experience/rollouts -title: nemo_rl.experience.rollouts ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`AsyncNemoGymRolloutResult`](#nemo_rl-experience-rollouts-AsyncNemoGymRolloutResult) | - | - -### Functions - -| Name | Description | -|------|-------------| -| [`_calculate_single_metric`](#nemo_rl-experience-rollouts-_calculate_single_metric) | - | -| [`async_generate_response_for_sample_turn`](#nemo_rl-experience-rollouts-async_generate_response_for_sample_turn) | Generate a response for a single sample's turn using async generation. | -| [`calculate_rewards`](#nemo_rl-experience-rollouts-calculate_rewards) | Calculate rewards for generated responses and get environment feedback. | -| [`generate_responses`](#nemo_rl-experience-rollouts-generate_responses) | Generate responses from policy using synchronous generation. | -| [`generate_responses_async`](#nemo_rl-experience-rollouts-generate_responses_async) | Async version of generate_responses that properly calls generate_async. | -| [`run_async_multi_turn_rollout`](#nemo_rl-experience-rollouts-run_async_multi_turn_rollout) | Run multi-turn rollouts with sample-level processing. | -| [`run_async_nemo_gym_rollout`](#nemo_rl-experience-rollouts-run_async_nemo_gym_rollout) | Run multi-turn rollouts with NeMo-Gym. Please refer to the `run_async_multi_turn_rollout` docs for more information on the parameters. | -| [`run_multi_turn_rollout`](#nemo_rl-experience-rollouts-run_multi_turn_rollout) | Runs a multi-turn rollout loop, interacting with the environment. | -| [`run_sample_multi_turn_rollout`](#nemo_rl-experience-rollouts-run_sample_multi_turn_rollout) | Run a multi-turn rollout for a single sample. | - -### Data - -[`TokenizerType`](#nemo_rl-experience-rollouts-TokenizerType) - -### API - - - - - -```python -class nemo_rl.experience.rollouts.AsyncNemoGymRolloutResult( - input_ids: torch.Tensor, - final_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], - rollout_metrics: dict[str, typing.Any] -) -``` - - - - - - -Dataclass - - - - - - - - - - - - - - - -```python -nemo_rl.experience.rollouts._calculate_single_metric( - values: list[float], - batch_size: int, - key_name: str -) -> dict -``` - - - - - - - - - - - - - -```python -nemo_rl.experience.rollouts.async_generate_response_for_sample_turn( - policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, - sample_message_log: list[dict], - sample_stop_strings: list[str] | None, - tokenizer: nemo_rl.experience.rollouts.TokenizerType, - max_seq_len: int, - greedy: bool = False -) -> tuple[list[dict], torch.Tensor, torch.Tensor, dict[str, float]] -``` - - - - - - -async - -Generate a response for a single sample's turn using async generation. - -**Parameters:** - - -The generation interface to use - - - -Message log for a single sample - - - -Stop strings for this sample - - - -Tokenizer to use - - - -Maximum sequence length - - - -Whether to use greedy decoding - - -**Returns:** `tuple[list[dict], torch.Tensor, torch.Tensor, dict[str, float]]` - -Tuple of (updated_message_log, generated_tokens, input_lengths, generation_metrics) - - - - - - - - -```python -nemo_rl.experience.rollouts.calculate_rewards( - batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], - task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface] -) -> nemo_rl.environments.interfaces.EnvironmentReturn -``` - - - - - - -Calculate rewards for generated responses and get environment feedback. - -**Parameters:** - - -Batch containing message_log (LLMMessageLogType) with generated responses - - - -Dictionary mapping task names to their corresponding environments - - -**Returns:** `EnvironmentReturn` - -EnvironmentReturn namedtuple containing: -- observations: List of observations from the environment for the next turn. -- metadata: List of extracted metadata from the environment. -- next_stop_strings: List of stop strings for the next generation step. -- rewards: Tensor of rewards for the last turn. -- terminateds: Tensor of booleans indicating if an episode ended naturally. - - - - - - - - -```python -nemo_rl.experience.rollouts.generate_responses( - policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, - generation_input_data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], - tokenizer: nemo_rl.experience.rollouts.TokenizerType, - input_lengths: torch.Tensor, - include_logprobs: bool = True, - greedy: bool = False -) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], list[torch.Tensor], dict[str, float | int]] -``` - - - - - - -Generate responses from policy using synchronous generation. - - - - - - - - -```python -nemo_rl.experience.rollouts.generate_responses_async( - policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, - generation_input_data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], - tokenizer: nemo_rl.experience.rollouts.TokenizerType, - input_lengths: torch.Tensor, - include_logprobs: bool = True, - greedy: bool = False -) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], list[torch.Tensor], dict[str, float | int]] -``` - - - - - - -async - -Async version of generate_responses that properly calls generate_async. - - - - - - - - -```python -nemo_rl.experience.rollouts.run_async_multi_turn_rollout( - policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, - input_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], - tokenizer: nemo_rl.experience.rollouts.TokenizerType, - task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], - max_seq_len: int, - max_rollout_turns: int = 999999, - greedy: bool = False -) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], dict[str, typing.Any]] -``` - - - - - - -Run multi-turn rollouts with sample-level processing. - -Each sample in the batch proceeds through its interaction independently. -Async generation is used internally when available but the function is synchronous. - -**Parameters:** - - -The generation interface (policy) - - - -The starting batch containing initial message logs - - - -The tokenizer - - - -Dictionary mapping task names to environment instances - - - -Maximum sequence length allowed - - - -Maximum number of agent-environment interaction turns - - - -Whether to use greedy decoding - - -**Returns:** `tuple[BatchedDataDict[DatumSpec], dict[str, Any]]` - -Tuple containing: -- BatchedDataDict with the full interaction history and accumulated rewards -- Dictionary of rollout metrics - - - - - - - - -```python -nemo_rl.experience.rollouts.run_async_nemo_gym_rollout( - policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, - input_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], - tokenizer: nemo_rl.experience.rollouts.TokenizerType, - task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], - generation_config: nemo_rl.models.generation.interfaces.GenerationConfig, - max_seq_len: typing.Optional[int] = None, - max_rollout_turns: typing.Optional[int] = None, - greedy: bool = False -) -> nemo_rl.experience.rollouts.AsyncNemoGymRolloutResult -``` - - - - - - -Run multi-turn rollouts with NeMo-Gym. Please refer to the `run_async_multi_turn_rollout` docs for more information on the parameters. - - - - - - - - -```python -nemo_rl.experience.rollouts.run_multi_turn_rollout( - policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, - input_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], - tokenizer: nemo_rl.experience.rollouts.TokenizerType, - task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], - max_seq_len: int, - max_rollout_turns: int = 999999, - greedy: bool = False -) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], dict[str, typing.Any]] -``` - - - - - - -Runs a multi-turn rollout loop, interacting with the environment. - -**Parameters:** - - -The generation interface (policy). - - - -The starting batch containing initial message logs. - - - -The tokenizer. - - - -Dictionary mapping task names to environment instances. - - - -Maximum number of agent-environment interaction turns. - - - -Maximum sequence length allowed. - - - -Whether to use greedy decoding. - - -**Returns:** `tuple[BatchedDataDict[DatumSpec], dict[str, Any]]` - -Tuple containing: -- BatchedDataDict with the full interaction history and accumulated rewards -- Dictionary of rollout metrics - - - - - - - - -```python -nemo_rl.experience.rollouts.run_sample_multi_turn_rollout( - sample_idx: int, - initial_sample_state: dict, - policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, - tokenizer: nemo_rl.experience.rollouts.TokenizerType, - task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], - max_seq_len: int, - max_rollout_turns: int = 999999, - greedy: bool = False -) -> tuple[dict, dict[str, typing.Any]] -``` - - - - - - -async - -Run a multi-turn rollout for a single sample. - -This function manages the complete lifecycle of one sample's interaction. -Async generation is used internally when available. - -**Parameters:** - - -Index of this sample in the original batch - - - -Initial state containing message_log, extra_env_info, etc. - - - -The generation interface - - - -Tokenizer to use - - - -Environment mapping - - - -Maximum sequence length - - - -Maximum number of turns - - - -Whether to use greedy decoding - - -**Returns:** `tuple[dict, dict[str, Any]]` - -Tuple of (final_sample_state, sample_metrics) - - - - - - - - -```python -nemo_rl.experience.rollouts.TokenizerType = PreTrainedTokenizerBase -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models.mdx deleted file mode 100644 index 78aeea0..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models.mdx +++ /dev/null @@ -1,14 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models -title: nemo_rl.models ---- - -## Subpackages - -- **[`nemo_rl.models.automodel`](/nemo-rl/nemo_rl/models/automodel)** -- **[`nemo_rl.models.dtensor`](/nemo-rl/nemo_rl/models/dtensor)** -- **[`nemo_rl.models.generation`](/nemo-rl/nemo_rl/models/generation)** -- **[`nemo_rl.models.huggingface`](/nemo-rl/nemo_rl/models/huggingface)** -- **[`nemo_rl.models.megatron`](/nemo-rl/nemo_rl/models/megatron)** -- **[`nemo_rl.models.policy`](/nemo-rl/nemo_rl/models/policy)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel.mdx deleted file mode 100644 index a8d957d..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel.mdx +++ /dev/null @@ -1,12 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/automodel -title: nemo_rl.models.automodel ---- - -## Submodules - -- **[`nemo_rl.models.automodel.config`](/nemo-rl/nemo_rl/models/automodel/config)** -- **[`nemo_rl.models.automodel.data`](/nemo-rl/nemo_rl/models/automodel/data)** -- **[`nemo_rl.models.automodel.setup`](/nemo-rl/nemo_rl/models/automodel/setup)** -- **[`nemo_rl.models.automodel.train`](/nemo-rl/nemo_rl/models/automodel/train)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/config.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/config.mdx deleted file mode 100644 index c409b87..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/config.mdx +++ /dev/null @@ -1,125 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/automodel/config -title: nemo_rl.models.automodel.config ---- - -Configuration classes for automodel-based training in NeMo RL. - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`ModelAndOptimizerState`](#nemo_rl-models-automodel-config-ModelAndOptimizerState) | Container for model and optimizer state. | -| [`RuntimeConfig`](#nemo_rl-models-automodel-config-RuntimeConfig) | Runtime configuration for model training and inference. | - -### API - - - - - -```python -class nemo_rl.models.automodel.config.ModelAndOptimizerState() -``` - - - - - - -**Bases:** `NamedTuple` - -Container for model and optimizer state. - -This named tuple holds all model-related state including the model itself, -optimizer, scheduler, and metadata about the model type and configuration. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.models.automodel.config.RuntimeConfig() -``` - - - - - - -**Bases:** `NamedTuple` - -Runtime configuration for model training and inference. - -This contains all validated runtime settings needed for model initialization, -parallelization, and training. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/data.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/data.mdx deleted file mode 100644 index f777494..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/data.mdx +++ /dev/null @@ -1,374 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/automodel/data -title: nemo_rl.models.automodel.data ---- - -Data processing utilities for automodel training and inference. - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`ProcessedInputs`](#nemo_rl-models-automodel-data-ProcessedInputs) | Processed microbatch inputs ready for model forward pass. | -| [`ProcessedMicrobatch`](#nemo_rl-models-automodel-data-ProcessedMicrobatch) | Container for a processed microbatch ready for model forward pass. | - -### Functions - -| Name | Description | -|------|-------------| -| [`check_sequence_dim`](#nemo_rl-models-automodel-data-check_sequence_dim) | Check and validate sequence dimension across all tensors. | -| [`get_microbatch_iterator`](#nemo_rl-models-automodel-data-get_microbatch_iterator) | Create processed microbatch iterator based on batching strategy. | -| [`make_processed_microbatch_iterator`](#nemo_rl-models-automodel-data-make_processed_microbatch_iterator) | Wrap a raw microbatch iterator to yield processed microbatches. | -| [`process_global_batch`](#nemo_rl-models-automodel-data-process_global_batch) | Process a global batch and compute normalization factors. | -| [`process_microbatch`](#nemo_rl-models-automodel-data-process_microbatch) | Process a microbatch and prepare inputs for model forward. | - -### API - - - - - -```python -class nemo_rl.models.automodel.data.ProcessedInputs( - input_ids: torch.Tensor, - seq_len: int, - attention_mask: typing.Optional[torch.Tensor] = None, - position_ids: typing.Optional[torch.Tensor] = None, - flash_attn_kwargs: dict[str, typing.Any] = dict(), - vlm_kwargs: dict[str, typing.Any] = dict(), - cp_buffers: list[torch.Tensor] = list(), - seq_index: typing.Optional[torch.Tensor] = None -) -``` - - - - - - -Dataclass - -Processed microbatch inputs ready for model forward pass. - -This structure contains all necessary tensors and metadata for a forward pass, -including context parallel buffers and flash attention configuration. - - - - - - - - - - - - -Check if context parallel is enabled. - - - -Check if flash attention is configured. - -Works for both empty dict {} and dataclass objects like FlashAttnKwargs. - - - - - - -Check if this is a multimodal input. - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.models.automodel.data.ProcessedMicrobatch( - data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs, - original_batch_size: int, - original_seq_len: int -) -``` - - - - - - -Dataclass - -Container for a processed microbatch ready for model forward pass. - -This dataclass holds both the original data dictionary and the processed -tensors needed for the automodel forward pass. It follows the same pattern -as nemo_rl/models/megatron/data.py ProcessedMicrobatch. - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.models.automodel.data.check_sequence_dim( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] -) -> typing.Tuple[int, int] -``` - - - - - - -Check and validate sequence dimension across all tensors. - -Verifies that dimension 1 is the sequence dimension for all tensors -in the data dictionary that have more than one dimension. - -**Parameters:** - - -BatchedDataDict to validate - - -**Returns:** `Tuple[int, int]` - -Tuple of (sequence_dim, seq_dim_size) - -**Raises:** - -- `AssertionError`: If any tensor has inconsistent sequence dimension - - - - - - - - -```python -nemo_rl.models.automodel.data.get_microbatch_iterator( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - cfg: dict[str, typing.Any], - mbs: int, - dp_mesh: typing.Any, - tokenizer: transformers.AutoTokenizer, - cp_size: int = 1 -) -> tuple[typing.Iterator[nemo_rl.models.automodel.data.ProcessedMicrobatch], int] -``` - - - - - - -Create processed microbatch iterator based on batching strategy. - -**Parameters:** - - -Full dataset to iterate over - - - -Configuration dictionary (enable_seq_packing is inferred from cfg["sequence_packing"]["enabled"]) - - - -Microbatch size - - - -Data parallel mesh - - - -Tokenizer for processing - - - -Context parallel size - - -**Returns:** `tuple[Iterator[ProcessedMicrobatch], int]` - -Tuple of (processed_microbatch_iterator, iterator_length) - - - - - - - - -```python -nemo_rl.models.automodel.data.make_processed_microbatch_iterator( - raw_iterator: typing.Iterator[nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any]], - tokenizer: transformers.AutoTokenizer, - cfg: dict[str, typing.Any], - cp_size: int -) -> typing.Iterator[nemo_rl.models.automodel.data.ProcessedMicrobatch] -``` - - - - - - -Wrap a raw microbatch iterator to yield processed microbatches. - -This function takes a raw iterator that yields BatchedDataDict objects and -wraps it to yield ProcessedMicrobatch objects that contain both the original -data and the processed tensors ready for model forward pass. - -**Parameters:** - - -Iterator yielding raw BatchedDataDict microbatches - - - -Tokenizer for processing - - - -Configuration dictionary (enable_seq_packing is inferred from cfg["sequence_packing"]["enabled"]) - - - -Context parallel size - - - - - - - - - -```python -nemo_rl.models.automodel.data.process_global_batch( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - loss_fn: nemo_rl.algorithms.interfaces.LossFunction, - dp_group: torch.distributed.ProcessGroup, - batch_idx: int, - batch_size: int -) -> dict[str, typing.Any] -``` - - - - - - -Process a global batch and compute normalization factors. - -**Parameters:** - - -Full dataset - - - -Loss function (used to check loss type) - - - -Data parallel process group (for consistency with Megatron naming) - - - -Index of batch to extract - - - -Size of batch to extract - - -**Returns:** `dict[str, Any]` - -Dictionary containing: - - - - - - - - -```python -nemo_rl.models.automodel.data.process_microbatch( - mb: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - tokenizer: transformers.AutoTokenizer, - enable_seq_packing: bool, - cfg: dict[str, typing.Any], - cp_size: int -) -> nemo_rl.models.automodel.data.ProcessedInputs -``` - - - - - - -Process a microbatch and prepare inputs for model forward. - -**Parameters:** - - -Microbatch data - - - -Tokenizer for padding value - - - -Whether sequence packing is enabled - - - -Configuration dictionary - - - -Context parallel size - - -**Returns:** `ProcessedInputs` - -ProcessedInputs containing all tensors and metadata for forward pass - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/setup.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/setup.mdx deleted file mode 100644 index 2b2a115..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/setup.mdx +++ /dev/null @@ -1,229 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/automodel/setup -title: nemo_rl.models.automodel.setup ---- - -Setup utilities for automodel-based training in NeMo RL. - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`setup_distributed`](#nemo_rl-models-automodel-setup-setup_distributed) | Set up distributed training environment and create FSDP2Manager. | -| [`setup_model_and_optimizer`](#nemo_rl-models-automodel-setup-setup_model_and_optimizer) | Set up model, parallelization, and optimizer. | -| [`setup_reference_model_state`](#nemo_rl-models-automodel-setup-setup_reference_model_state) | Set up reference model state dict by creating a CPU copy of the model's state dict. | -| [`validate_and_prepare_config`](#nemo_rl-models-automodel-setup-validate_and_prepare_config) | Validate configuration and prepare runtime settings. | - -### Data - -[`STRING_TO_DTYPE`](#nemo_rl-models-automodel-setup-STRING_TO_DTYPE) - -### API - - - - - -```python -nemo_rl.models.automodel.setup.setup_distributed( - config: nemo_rl.models.policy.PolicyConfig, - runtime_config: nemo_rl.models.automodel.config.RuntimeConfig -) -> nemo_automodel.components.distributed.fsdp2.FSDP2Manager -``` - - - - - - -Set up distributed training environment and create FSDP2Manager. - -Initializes torch.distributed process group and creates an FSDP2Manager -with the appropriate parallelization and precision settings. - -**Parameters:** - - -Policy configuration dictionary - - - -RuntimeConfig named tuple from validate_and_prepare_config - - -**Returns:** `FSDP2Manager` - -FSDP2Manager instance with all distributed configuration - - - - - - - - -```python -nemo_rl.models.automodel.setup.setup_model_and_optimizer( - config: nemo_rl.models.policy.PolicyConfig, - tokenizer: transformers.AutoTokenizer, - runtime_config: nemo_rl.models.automodel.config.RuntimeConfig, - distributed_manager: nemo_automodel.components.distributed.fsdp2.FSDP2Manager, - checkpoint_manager: typing.Any, - is_vlm: bool = False, - init_optimizer: bool = True, - weights_path: typing.Optional[str] = None, - optimizer_path: typing.Optional[str] = None -) -> nemo_rl.models.automodel.config.ModelAndOptimizerState -``` - - - - - - -Set up model, parallelization, and optimizer. - -Creates the model from config, applies parallelization strategies (FSDP2, TP, CP), -loads base weights, and optionally initializes optimizer and scheduler. - -**Parameters:** - - -Policy configuration dictionary - - - -Tokenizer for the model - - - -RuntimeConfig named tuple from validate_and_prepare_config - - - -FSDP2Manager from setup_distributed - - - -Checkpoint manager for loading/saving weights - - - -Whether this is a vision-language model - - - -Whether to initialize optimizer - - - -Optional path to checkpoint weights to load - - - -Optional path to optimizer state to load - - -**Returns:** `ModelAndOptimizerState` - -ModelAndOptimizerState containing model, optimizer, scheduler, and metadata - - - - - - - - -```python -nemo_rl.models.automodel.setup.setup_reference_model_state( - model: torch.nn.Module -) -> dict[str, torch.Tensor] -``` - - - - - - -Set up reference model state dict by creating a CPU copy of the model's state dict. - -This creates a reference copy of the model weights on CPU with pinned memory -for efficient CPU-GPU transfers. The reference model is typically used to -compute reference log probabilities during RL training. - -**Parameters:** - - -The model to create a reference copy from - - -**Returns:** `dict[str, torch.Tensor]` - -Dictionary mapping parameter names to CPU tensors with pinned memory - - - - - - - - -```python -nemo_rl.models.automodel.setup.validate_and_prepare_config( - config: nemo_rl.models.policy.PolicyConfig, - processor: typing.Optional[transformers.AutoProcessor], - rank: int -) -> nemo_rl.models.automodel.config.RuntimeConfig -``` - - - - - - -Validate configuration and prepare runtime settings. - -This function validates the policy configuration, sets environment variables, -determines model configuration, and returns runtime settings as a named tuple. - -**Parameters:** - - -Policy configuration dictionary - - - -Optional processor for multimodal models - - - -Current process rank - - -**Returns:** `RuntimeConfig` - -RuntimeConfig named tuple containing validated configuration values - -**Raises:** - -- `ValueError`: If configuration is invalid -- `RuntimeError`: If incompatible settings are detected - - - - - - - - -```python -nemo_rl.models.automodel.setup.STRING_TO_DTYPE = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16} -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/train.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/train.mdx deleted file mode 100644 index 7126780..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/train.mdx +++ /dev/null @@ -1,841 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/automodel/train -title: nemo_rl.models.automodel.train ---- - -Training utilities for automodel (DTensor-based) policy workers. - -This module provides post-processor classes and forward/backward functions -that follow the same pattern as nemo_rl/models/megatron/train.py. - -Key differences from megatron approach: -- Post-processors compute results directly (no callable return pattern) -- forward_with_post_processing_fn calls post-processor directly -- automodel_forward_backward uses PyTorch autograd instead of Megatron's pipeline - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`LogprobsPostProcessor`](#nemo_rl-models-automodel-train-LogprobsPostProcessor) | Post-processor for computing log probabilities from model outputs. | -| [`LossPostProcessor`](#nemo_rl-models-automodel-train-LossPostProcessor) | Post-processor for computing training loss from model outputs. | -| [`ScorePostProcessor`](#nemo_rl-models-automodel-train-ScorePostProcessor) | Post-processor for computing reward model scores from model outputs. | -| [`TopkLogitsPostProcessor`](#nemo_rl-models-automodel-train-TopkLogitsPostProcessor) | Post-processor for computing top-k logits from model outputs. | - -### Functions - -| Name | Description | -|------|-------------| -| [`aggregate_training_statistics`](#nemo_rl-models-automodel-train-aggregate_training_statistics) | Aggregate training statistics across microbatches and ranks. | -| [`apply_temperature_scaling`](#nemo_rl-models-automodel-train-apply_temperature_scaling) | Apply temperature scaling to logits. | -| [`automodel_forward_backward`](#nemo_rl-models-automodel-train-automodel_forward_backward) | Execute forward and backward passes for automodel. | -| [`extract_logits`](#nemo_rl-models-automodel-train-extract_logits) | Extract logits from model outputs. | -| [`forward_with_post_processing_fn`](#nemo_rl-models-automodel-train-forward_with_post_processing_fn) | Perform forward pass with pre-processed microbatch and apply post-processing. | -| [`model_forward`](#nemo_rl-models-automodel-train-model_forward) | Perform a single forward pass through the model. | -| [`prepare_data_for_cp`](#nemo_rl-models-automodel-train-prepare_data_for_cp) | Prepare data for context parallel processing. | -| [`redistribute_logits_for_cp`](#nemo_rl-models-automodel-train-redistribute_logits_for_cp) | Redistribute logits for context parallel processing. | - -### Data - -[`PostProcessingFunction`](#nemo_rl-models-automodel-train-PostProcessingFunction) - -### API - - - - - -```python -class nemo_rl.models.automodel.train.LogprobsPostProcessor( - cfg: nemo_rl.models.policy.PolicyConfig, - device_mesh: typing.Any, - cp_mesh: typing.Any, - tp_mesh: typing.Any, - cp_size: int, - enable_seq_packing: bool = False -) -``` - - - - - - -Post-processor for computing log probabilities from model outputs. - - - - - - - - -```python -nemo_rl.models.automodel.train.LogprobsPostProcessor.__call__( - logits: torch.Tensor, - processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs, - input_lengths: torch.Tensor, - original_batch_size: int, - original_seq_len: int, - sequence_dim: int = 1 -) -> torch.Tensor -``` - - - - - - -Compute token log probabilities from logits. - -**Parameters:** - - -Model output logits - - - -Processed inputs - - - -Sequence lengths - - - -Original batch size before packing - - - -Original sequence length before packing - - - -Sequence dimension - - -**Returns:** `torch.Tensor` - -Token log probabilities tensor [batch_size, seq_length] - - - - - - - -```python -nemo_rl.models.automodel.train.LogprobsPostProcessor._compute_local_logprobs( - logits: torch.Tensor, - input_ids: torch.Tensor -) -> torch.Tensor -``` - - - - - - -Compute logprobs locally without distributed processing. - -**Parameters:** - - -Model output logits - - - -Input token IDs - - -**Returns:** `torch.Tensor` - -Token log probabilities - - - - - - - - - -```python -class nemo_rl.models.automodel.train.LossPostProcessor( - loss_fn: nemo_rl.algorithms.interfaces.LossFunction, - cfg: nemo_rl.models.policy.PolicyConfig, - device_mesh: typing.Any, - cp_mesh: typing.Any, - tp_mesh: typing.Any, - cp_size: int, - dp_size: int, - enable_seq_packing: bool = False -) -``` - - - - - - -Post-processor for computing training loss from model outputs. - - - - - - -```python -nemo_rl.models.automodel.train.LossPostProcessor.__call__( - logits: torch.Tensor, - mb: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs, - global_valid_seqs: torch.Tensor, - global_valid_toks: torch.Tensor, - sequence_dim: int = 1 -) -> tuple[torch.Tensor, dict[str, typing.Any]] -``` - - - - - - -Compute loss from logits. - -**Parameters:** - - -Model output logits - - - -Microbatch data - - - -Processed inputs - - - -Global valid sequence count - - - -Global valid token count - - - -Sequence dimension - - -**Returns:** `tuple[torch.Tensor, dict[str, Any]]` - -Tuple of (loss, metrics) - - - - - - - - - -```python -class nemo_rl.models.automodel.train.ScorePostProcessor( - cfg: nemo_rl.models.policy.PolicyConfig -) -``` - - - - - - -Post-processor for computing reward model scores from model outputs. - - - - - - -```python -nemo_rl.models.automodel.train.ScorePostProcessor.__call__( - logits: torch.Tensor -) -> torch.Tensor -``` - - - - - - -Extract scores from reward model outputs. - -**Parameters:** - - -Model output logits - - -**Returns:** `torch.Tensor` - -Scores tensor - - - - - - - - - -```python -class nemo_rl.models.automodel.train.TopkLogitsPostProcessor( - cfg: nemo_rl.models.policy.PolicyConfig, - device_mesh: typing.Any, - cp_mesh: typing.Any, - tp_mesh: typing.Any, - cp_size: int, - k: int, - enable_seq_packing: bool = False -) -``` - - - - - - -Post-processor for computing top-k logits from model outputs. - - - - - - -```python -nemo_rl.models.automodel.train.TopkLogitsPostProcessor.__call__( - logits: torch.Tensor, - processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs, - input_lengths: torch.Tensor, - original_batch_size: int, - original_seq_len: int, - sequence_dim: int = 1 -) -> tuple[torch.Tensor, torch.Tensor] -``` - - - - - - -Compute top-k logits and indices from model outputs. - -**Parameters:** - - -Model output logits - - - -Processed inputs - - - -Sequence lengths - - - -Original batch size before packing - - - -Original sequence length before packing - - - -Sequence dimension - - -**Returns:** `tuple[torch.Tensor, torch.Tensor]` - -Tuple of (top-k values, top-k indices) tensors - - - - - - - - - -```python -nemo_rl.models.automodel.train.aggregate_training_statistics( - losses: list[float], - all_mb_metrics: list[dict[str, typing.Any]], - grad_norm: typing.Optional[torch.Tensor], - dp_group: typing.Any, - dtype: torch.dtype -) -> dict[str, typing.Any] -``` - - - - - - -Aggregate training statistics across microbatches and ranks. - -**Parameters:** - - -List of loss values from each microbatch - - - -List of metrics dictionaries from each microbatch - - - -Gradient norm tensor (or None if eval mode) - - - -Data parallel process group for all-reduce - - - -Model dtype for metrics - - -**Returns:** `dict[str, Any]` - -Dictionary containing aggregated metrics including global_loss, grad_norm, etc. - - - - - - - - -```python -nemo_rl.models.automodel.train.apply_temperature_scaling( - logits: torch.Tensor, - cfg: nemo_rl.models.policy.PolicyConfig -) -> torch.Tensor -``` - - - - - - -Apply temperature scaling to logits. - -**Parameters:** - - -Logits tensor to scale - - - -Configuration dictionary containing generation settings - - -**Returns:** `torch.Tensor` - -torch.Tensor: Temperature-scaled logits - - - - - - - - -```python -nemo_rl.models.automodel.train.automodel_forward_backward( - model: torch.nn.Module, - cfg: nemo_rl.models.policy.PolicyConfig, - data_iterator: typing.Iterator[nemo_rl.models.automodel.data.ProcessedMicrobatch], - post_processing_fn: nemo_rl.models.automodel.train.PostProcessingFunction, - forward_only: bool = False, - is_reward_model: bool = False, - allow_flash_attn_args: bool = True, - global_valid_seqs: typing.Optional[torch.Tensor] = None, - global_valid_toks: typing.Optional[torch.Tensor] = None, - sequence_dim: int = 1, - dp_size: int = 1, - cp_size: int = 1, - num_global_batches: int = 1, - train_context_fn: typing.Optional[typing.Callable[[ProcessedInputs], typing.Any]] = None, - num_valid_microbatches: typing.Optional[int] = None, - on_microbatch_start: typing.Optional[typing.Callable[[int], None]] = None -) -> list[typing.Tuple[typing.Any, dict[str, typing.Any]]] -``` - - - - - - -Execute forward and backward passes for automodel. - -This is the main training loop function that coordinates forward and backward -passes across multiple microbatches using PyTorch autograd. - -Unlike megatron_forward_backward which uses Megatron's pipeline parallel -framework, this uses standard PyTorch operations. - -**Parameters:** - - -The model to train - - - -Configuration dictionary - - - -Iterator yielding ProcessedMicrobatch objects (already processed) - - - -Number of microbatches to process - - - -Post-processing function to apply to the logits - - - -If True, skip backward pass - - - -Whether this is a reward model - - - -Whether to pass flash_attn_kwargs to model - - - -Global valid sequence count for loss normalization - - - -Global valid token count for loss normalization - - - -Sequence dimension - - - -Data parallel size - - - -Context parallel size - - - -Number of global batches (for metric scaling) - - - -Optional callable that takes ProcessedInputs and returns -a context manager for the forward/backward pass. If None, no context is used. - - - -Number of valid (non-dummy) microbatches. If provided, -microbatches beyond this index are treated as dummy batches (loss *= 0). -If None, all microbatches are considered valid. - - - -Optional callback called at the start of each microbatch -with the microbatch index. Useful for cache clearing, etc. - - -**Returns:** `list[Tuple[Any, dict[str, Any]]]` - -List of (result, metrics) tuples from each microbatch - - - - - - - - -```python -nemo_rl.models.automodel.train.extract_logits( - model: torch.nn.Module, - outputs: typing.Any -) -> torch.Tensor -``` - - - - - - -Extract logits from model outputs. - -**Parameters:** - - -The model (used for lm_head if needed) - - - -Model outputs (can be tensor, DTensor, or object with logits attribute) - - -**Returns:** `torch.Tensor` - -torch.Tensor: Logits tensor - - - - - - - - -```python -nemo_rl.models.automodel.train.forward_with_post_processing_fn( - model: torch.nn.Module, - cfg: nemo_rl.models.policy.PolicyConfig, - post_processing_fn: nemo_rl.models.automodel.train.PostProcessingFunction, - processed_mb: nemo_rl.models.automodel.data.ProcessedMicrobatch, - is_reward_model: bool = False, - allow_flash_attn_args: bool = True, - global_valid_seqs: typing.Optional[torch.Tensor] = None, - global_valid_toks: typing.Optional[torch.Tensor] = None, - sequence_dim: int = 1 -) -> typing.Tuple[typing.Any, dict[str, typing.Any], nemo_rl.models.automodel.data.ProcessedMicrobatch] -``` - - - - - - -Perform forward pass with pre-processed microbatch and apply post-processing. - -This function takes a pre-processed microbatch (with sequence packing already handled), -runs the forward step through the model, and applies the post-processing function -to compute the result. - -Unlike the megatron approach which returns a callable, this directly computes -and returns the result since automodel uses PyTorch autograd. - -**Parameters:** - - -The model to run forward pass on - - - -Configuration dictionary - - - -Post-processing function to apply to the logits - - - -Pre-fetched ProcessedMicrobatch containing data and processed inputs - - - -Whether this is a reward model - - - -Whether to pass flash_attn_kwargs to model - - - -Global valid sequence count for loss normalization - - - -Global valid token count for loss normalization - - - -Sequence dimension - - -**Returns:** `Tuple[Any, dict[str, Any], ProcessedMicrobatch]` - -(result, metrics, processed_microbatch) -- result: Output from post-processing (loss, logprobs, topk, or scores) -- metrics: Dictionary of metrics from post-processing -- processed_microbatch: The ProcessedMicrobatch that was processed - - - - - - - - -```python -nemo_rl.models.automodel.train.model_forward( - model: torch.nn.Module, - processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs, - is_reward_model: bool = False, - allow_flash_attn_args: bool = True -) -> torch.Tensor -``` - - - - - - -Perform a single forward pass through the model. - -**Parameters:** - - -The model to run forward pass on - - - -ProcessedInputs containing all tensors for forward pass - - - -Whether this is a reward model - - - -Whether to pass flash_attn_kwargs to model - - -**Returns:** `torch.Tensor` - -torch.Tensor: Output tensor from the model (logits) - - - - - - - - -```python -nemo_rl.models.automodel.train.prepare_data_for_cp( - mb: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs, - cp_mesh: typing.Any, - sequence_dim: int = 1 -) -> tuple[torch.Tensor, nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any]] -``` - - - - - - -Prepare data for context parallel processing. - -Converts seq_index to full tensor and wraps CP-sharded tensors in DTensor. - -**Parameters:** - - -Microbatch data dictionary - - - -Processed inputs containing CP buffers - - - -Context parallel mesh - - - -Dimension for sequence sharding - - -**Returns:** `tuple[torch.Tensor, BatchedDataDict[Any]]` - -Tuple of (seq_index_dtensor, updated_mb) - - - - - - - - -```python -nemo_rl.models.automodel.train.redistribute_logits_for_cp( - logits: torch.Tensor, - device_mesh: typing.Any, - cp_mesh: typing.Any, - sequence_dim: int = 1 -) -> torch.distributed.tensor.DTensor -``` - - - - - - -Redistribute logits for context parallel processing. - -Handles the case where logits may be TP-sharded DTensor or regular tensor, -and converts them to CP+TP sharded DTensor. - -**Parameters:** - - -Logits tensor (may be DTensor or regular tensor) - - - -Full device mesh - - - -Context parallel mesh (kept for signature compatibility) - - - -Dimension for sequence sharding - - -**Returns:** `DTensor` - -DTensor sharded on both CP and TP dimensions - - - - - - - - -```python -nemo_rl.models.automodel.train.PostProcessingFunction = Union['LossPostProcessor', 'LogprobsPostProcessor', 'TopkLogitsPostProcessor', '... -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor.mdx deleted file mode 100644 index b29df46..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor.mdx +++ /dev/null @@ -1,9 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/dtensor -title: nemo_rl.models.dtensor ---- - -## Submodules - -- **[`nemo_rl.models.dtensor.parallelize`](/nemo-rl/nemo_rl/models/dtensor/parallelize)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor/parallelize.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor/parallelize.mdx deleted file mode 100644 index 72877d4..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor/parallelize.mdx +++ /dev/null @@ -1,454 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/dtensor/parallelize -title: nemo_rl.models.dtensor.parallelize ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`RotaryEmbedParallel`](#nemo_rl-models-dtensor-parallelize-RotaryEmbedParallel) | Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings because the input is a tuple. | - -### Functions - -| Name | Description | -|------|-------------| -| [`_parallelize_gemma3`](#nemo_rl-models-dtensor-parallelize-_parallelize_gemma3) | Parallelizes a Gemma3ForCausalLM model across data and tensor parallel dimensions. | -| [`_parallelize_llama`](#nemo_rl-models-dtensor-parallelize-_parallelize_llama) | Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions. | -| [`_parallelize_model`](#nemo_rl-models-dtensor-parallelize-_parallelize_model) | Parallelize a model using DTensor. | -| [`_parallelize_nm5_h`](#nemo_rl-models-dtensor-parallelize-_parallelize_nm5_h) | Parallelize a NemotronHForCausalLM model across data and tensor parallel dimensions. | -| [`_parallelize_qwen`](#nemo_rl-models-dtensor-parallelize-_parallelize_qwen) | Parallelizes a Qwen2ForCausalLM model across data and tensor parallel dimensions. | -| [`clip_grad_by_total_norm_`](#nemo_rl-models-dtensor-parallelize-clip_grad_by_total_norm_) | Clips gradient of an iterable of parameters by total norm. | -| [`get_grad_norm`](#nemo_rl-models-dtensor-parallelize-get_grad_norm) | Calculate the norm of gradients. | -| [`get_hf_tp_plan`](#nemo_rl-models-dtensor-parallelize-get_hf_tp_plan) | Get the Hugging Face tensor parallel plan from the model. | -| [`to_local_if_dtensor`](#nemo_rl-models-dtensor-parallelize-to_local_if_dtensor) | Returns the local shard of the given tensor if it is a DTensor. | -| [`translate_parallel_style`](#nemo_rl-models-dtensor-parallelize-translate_parallel_style) | Translate parallel style str to parallel type. | - -### Data - -[`PARALLIZE_FUNCTIONS`](#nemo_rl-models-dtensor-parallelize-PARALLIZE_FUNCTIONS) - -### API - - - - - -```python -class nemo_rl.models.dtensor.parallelize.RotaryEmbedParallel() -``` - - - - - - -**Bases:** `SequenceParallel` - -Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings because the input is a tuple. - - - - - - -```python -nemo_rl.models.dtensor.parallelize.RotaryEmbedParallel._prepare_input_fn( - sequence_sharding, - mod, - inputs, - device_mesh -) -``` - - - - - - -staticmethod - - - - - - - -```python -nemo_rl.models.dtensor.parallelize.RotaryEmbedParallel._prepare_output_fn( - use_local_output, - mod, - outputs, - device_mesh -) -``` - - - - - - -staticmethod - - - - - - - - - -```python -nemo_rl.models.dtensor.parallelize._parallelize_gemma3( - model: typing.Union[transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM, transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration], - sequence_parallel: bool = False -) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle] -``` - - - - - - -Parallelizes a Gemma3ForCausalLM model across data and tensor parallel dimensions. - - - - - - - - -```python -nemo_rl.models.dtensor.parallelize._parallelize_llama( - model: transformers.models.llama.modeling_llama.LlamaForCausalLM, - sequence_parallel: bool = False -) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle] -``` - - - - - - -Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions. - - - - - - - - -```python -nemo_rl.models.dtensor.parallelize._parallelize_model( - model: typing.Union[transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM, transformers.models.qwen3.modeling_qwen3.Qwen3ForCausalLM, transformers.models.llama.modeling_llama.LlamaForCausalLM, transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM, transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration], - dp_mesh: torch.distributed.device_mesh.DeviceMesh, - tp_mesh: torch.distributed.device_mesh.DeviceMesh, - param_dtype: torch.dtype, - sequence_parallel: bool = False, - activation_checkpointing: bool = False, - cpu_offload: bool = False, - custom_parallel_plan: typing.Optional[typing.Union[dict, str]] = None -) -``` - - - - - - -Parallelize a model using DTensor. - -**Parameters:** - - -The model to parallelize. - - - -Device mesh for data parallelism. - - - -Device mesh for tensor parallelism. - - - -Data type for model parameters. - - - -Whether to use sequence parallelism. Defaults to False. - - - -Whether to use activation checkpointing. Defaults to False. - - - -Whether to enable cpu offloading for FSDP. Defaults to False. - - - -Custom parallel plan for the model. Defaults to None. -If it's a dict, it will be used as the parallel plan directly. -If it's a string, it must be a path that points to a dict or a function that returns a dict. -The usage example can refer to `docs/design-docs/fsdp2-parallel-plan.md`. - - -**Returns:** - -The parallelized model. - -**Raises:** - -- `ValueError`: If the model type is not supported for parallelization. - - - - - - - - -```python -nemo_rl.models.dtensor.parallelize._parallelize_nm5_h( - model, - dp_mesh: torch.distributed.device_mesh.DeviceMesh, - tp_mesh: torch.distributed.device_mesh.DeviceMesh, - param_dtype: torch.dtype, - sequence_parallel: bool = False, - activation_checkpointing: bool = False, - cpu_offload: bool = False, - custom_parallel_plan: typing.Optional[typing.Union[dict, str]] = None -) -> torch.distributed.fsdp.FSDPModule -``` - - - - - - -Parallelize a NemotronHForCausalLM model across data and tensor parallel dimensions. - - - - - - - - -```python -nemo_rl.models.dtensor.parallelize._parallelize_qwen( - model: typing.Union[transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM, transformers.models.qwen3.modeling_qwen3.Qwen3ForCausalLM], - sequence_parallel: bool = False -) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle] -``` - - - - - - -Parallelizes a Qwen2ForCausalLM model across data and tensor parallel dimensions. - - - - - - - - -```python -nemo_rl.models.dtensor.parallelize.clip_grad_by_total_norm_( - parameters: typing.Union[list[typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]], typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]], - max_grad_norm: typing.Union[int, float], - total_norm: float -) -``` - - - - - - -Clips gradient of an iterable of parameters by total norm. - -Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/a695b2bd2a0ca9ca63385a48c41a1c5a033cdd1e/megatron/core/optimizer/clip_grads.py#L138 - -Note that the gradients are modified in place. - -**Parameters:** - - - -An iterable of Tensors or DTensors, or a single Tensor or DTensor -that will have gradients normalized. - - - -Maximum norm of the gradients. - - - -The pre-computed total norm of the gradients to use for scaling. - - - - - - - - - -```python -nemo_rl.models.dtensor.parallelize.get_grad_norm( - parameters: typing.Union[list[typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]], typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]], - dp_cp_group: torch.distributed.ProcessGroup, - tp_group: torch.distributed.ProcessGroup, - norm_type: typing.Union[int, float] = 2, - dtype: torch.dtype = torch.float32 -) -> float -``` - - - - - - -Calculate the norm of gradients. - -Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/a695b2bd2a0ca9ca63385a48c41a1c5a033cdd1e/megatron/core/optimizer/clip_grads.py#L51 - -**Parameters:** - - - -An iterable of Tensors or DTensors, or a single Tensor or DTensor -that will have gradient norm calculated. - - - -Process group for data parallel communication. - - - -Process group for context parallel communication. - - - -Process group for tensor parallel communication. - - - -Type of the used p-norm. Can be ``'inf'`` for -infinity norm. - - -**Returns:** `float` - -Total norm of the gradients (viewed as a single vector) - - - - - - - - -```python -nemo_rl.models.dtensor.parallelize.get_hf_tp_plan( - model: transformers.modeling_utils.PreTrainedModel -) -``` - - - - - - -Get the Hugging Face tensor parallel plan from the model. - -This function: -- Retrieves TP strategies from model class, instance, and inner model levels. -- Handles special cases for `embed_tokens` and `lm_head` for speed up. -- Converts string-based parallel styles to DTensor parallelization strategies. - -Taken and modified from: https://github.com/NVIDIA/NeMo/blob/6c6169db01bcca73ae8ad3ac35242fadbb9a78ba/nemo/lightning/pytorch/strategies/utils.py#L532 - -**Parameters:** - - -A Hugging Face model instance - - -**Returns:** - -A dictionary mapping model component paths to their parallelization strategies - -**Raises:** - -- `AssertionError`: If no TP plan is found - - - - - - - - -```python -nemo_rl.models.dtensor.parallelize.to_local_if_dtensor( - tensor: typing.Union[torch.Tensor, torch.distributed.tensor.DTensor] -) -> torch.Tensor -``` - - - - - - -Returns the local shard of the given tensor if it is a DTensor. - -Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/605f618f237cda8fa80132bc2ccff933512d5a0d/megatron/core/utils.py#L746 - - - - - - - - -```python -nemo_rl.models.dtensor.parallelize.translate_parallel_style( - style: str -) -``` - - - - - - -Translate parallel style str to parallel type. - -Taken and modified from: https://github.com/NVIDIA/NeMo/blob/6c6169db01bcca73ae8ad3ac35242fadbb9a78ba/nemo/lightning/pytorch/strategies/utils.py#L547 - - - - - - - - -```python -nemo_rl.models.dtensor.parallelize.PARALLIZE_FUNCTIONS: dict[type[Module], Callable[..., dict[str, ParallelStyle]]] = {Qwen2ForCausalLM: _parallelize_qwen, Qwen3ForCausalLM: _parallelize_qwen, Llama... -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation.mdx deleted file mode 100644 index ff3114c..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation.mdx +++ /dev/null @@ -1,62 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/generation -title: nemo_rl.models.generation ---- - -## Subpackages - -- **[`nemo_rl.models.generation.sglang`](/nemo-rl/nemo_rl/models/generation/sglang)** -- **[`nemo_rl.models.generation.vllm`](/nemo-rl/nemo_rl/models/generation/vllm)** - -## Submodules - -- **[`nemo_rl.models.generation.interfaces`](/nemo-rl/nemo_rl/models/generation/interfaces)** - -## Package Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`configure_generation_config`](#nemo_rl-models-generation-configure_generation_config) | Apply specific configurations to generation config. | - -### Data - -[`TokenizerType`](#nemo_rl-models-generation-TokenizerType) - -### API - - - - - -```python -nemo_rl.models.generation.configure_generation_config( - config: nemo_rl.models.generation.interfaces.GenerationConfig, - tokenizer: nemo_rl.models.generation.TokenizerType, - is_eval = False -) -> nemo_rl.models.generation.interfaces.GenerationConfig -``` - - - - - - -Apply specific configurations to generation config. - - - - - - - - -```python -nemo_rl.models.generation.TokenizerType = PreTrainedTokenizerBase -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/interfaces.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/interfaces.mdx deleted file mode 100644 index 886ccc8..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/interfaces.mdx +++ /dev/null @@ -1,569 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/generation/interfaces -title: nemo_rl.models.generation.interfaces ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`ColocationConfig`](#nemo_rl-models-generation-interfaces-ColocationConfig) | - | -| [`GenerationConfig`](#nemo_rl-models-generation-interfaces-GenerationConfig) | Configuration for generation. | -| [`GenerationDatumSpec`](#nemo_rl-models-generation-interfaces-GenerationDatumSpec) | Specification for input data required by generation models. | -| [`GenerationInterface`](#nemo_rl-models-generation-interfaces-GenerationInterface) | Abstract base class defining the interface for RL policies. | -| [`GenerationOutputSpec`](#nemo_rl-models-generation-interfaces-GenerationOutputSpec) | Specification for output data returned by generation models. | -| [`OptionalResourcesConfig`](#nemo_rl-models-generation-interfaces-OptionalResourcesConfig) | - | -| [`ResourcesConfig`](#nemo_rl-models-generation-interfaces-ResourcesConfig) | - | - -### Functions - -| Name | Description | -|------|-------------| -| [`verify_right_padding`](#nemo_rl-models-generation-interfaces-verify_right_padding) | Verify that a tensor is right-padded according to the provided lengths. | - -### API - - - - - -```python -class nemo_rl.models.generation.interfaces.ColocationConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - -```python -class nemo_rl.models.generation.interfaces.GenerationConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - -Configuration for generation. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.models.generation.interfaces.GenerationDatumSpec -``` - - - - - - -**Bases:** `typing.TypedDict` - -Specification for input data required by generation models. - -- input_ids: Tensor of token IDs representing the input sequences (right padded) -- input_lengths: Tensor containing the actual length of each sequence (without padding) -- stop_strings: Optional list of strings to stop generation (per sample) -- __extra__: Additional model-specific data fields - -Example of a batch with 4 entries with different sequence lengths: - - -```python -# Batch of 4 sequences with lengths [3, 5, 2, 4] - -input_ids (padded): -[ - [101, 2054, 2003, 0, 0], # Length 3 - [101, 2054, 2003, 2001, 1996], # Length 5 - [101, 2054, 0, 0, 0], # Length 2 - [101, 2054, 2003, 2001, 0], # Length 4 -] - -input_lengths: -[3, 5, 2, 4] -``` - - - -All functions receiving or returning GenerationDatumSpec should ensure -right padding is maintained. Use verify_right_padding() to check. - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.models.generation.interfaces.GenerationInterface() -``` - - - - - - -Abstract - -Abstract base class defining the interface for RL policies. - - - -Whether the generation backend requires KV cache scales synchronization. - - - - - -```python -nemo_rl.models.generation.interfaces.GenerationInterface.clear_logger_metrics() -> None -``` - - - - - - -Clear logger metrics for performance reporting. - -This is an optional method that backends can implement to clear -telemetry metrics. Default implementation does nothing. - - - - - - - -```python -nemo_rl.models.generation.interfaces.GenerationInterface.finish_generation( - args: typing.Any = (), - kwargs: typing.Any = {} -) -> bool -``` - - - - - - -abstract - - - - - - - -```python -nemo_rl.models.generation.interfaces.GenerationInterface.generate( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - greedy: bool -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] -``` - - - - - - -abstract - - - - - - - -```python -nemo_rl.models.generation.interfaces.GenerationInterface.get_logger_metrics() -> dict[str, typing.Any] -``` - - - - - - -Get logger metrics for performance reporting. - -This is an optional method that backends can implement to collect -telemetry metrics. Default implementation returns empty dict. - -**Returns:** `dict[str, Any]` - -Dictionary of metrics. Format may vary by backend. - - - - - - - -```python -nemo_rl.models.generation.interfaces.GenerationInterface.init_collective( - ip: str, - port: int, - world_size: int -) -> list[ray.ObjectRef] -``` - - - - - - -abstract - -Initialize the collective communication. - - - - - - - -```python -nemo_rl.models.generation.interfaces.GenerationInterface.invalidate_kv_cache() -> bool -``` - - - - - - - - - - - - -```python -nemo_rl.models.generation.interfaces.GenerationInterface.prepare_for_generation( - args: typing.Any = (), - kwargs: typing.Any = {} -) -> bool -``` - - - - - - -abstract - - - - - - - -```python -nemo_rl.models.generation.interfaces.GenerationInterface.prepare_refit_info( - state_dict_info: dict[str, typing.Any] -) -> None -``` - - - - - - -Prepare the info for refit. - - - - - - - -```python -nemo_rl.models.generation.interfaces.GenerationInterface.update_weights_from_collective() -> list[ray.ObjectRef] -``` - - - - - - -Update the model weights from collective communication. - - - - - - - -```python -nemo_rl.models.generation.interfaces.GenerationInterface.update_weights_via_ipc_zmq() -> list[ray.ObjectRef] -``` - - - - - - -Update the model weights from the given IPC handles. - - - - - - - - - -```python -class nemo_rl.models.generation.interfaces.GenerationOutputSpec -``` - - - - - - -**Bases:** `typing.TypedDict` - -Specification for output data returned by generation models. - -- output_ids: Tensor of token IDs representing the generated sequences (right padded) -- generation_lengths: Tensor containing the actual length of each generated sequence -- unpadded_sequence_lengths: Tensor containing the actual length of each input + generated sequence (without padding) -- logprobs: Tensor of log probabilities for each generated token (right padded with zeros) -- truncated: Boolean tensor indicating if each sequence was truncated (hit max_tokens limit) -- __extra__: Additional model-specific data fields - -Example of a batch with 2 sequences: - - -```python -# Sample batch with 2 examples -# - Example 1: Input length 3, generated response length 4 -# - Example 2: Input length 5, generated response length 2 - -output_ids (right-padded): -[ - [101, 2054, 2003, 2023, 2003, 1037, 2200, 0], # 7 valid tokens (3 input + 4 output) - [101, 2054, 2003, 2001, 1996, 3014, 2005, 0], # 7 valid tokens (5 input + 2 output) -] - -generation_lengths: -[4, 2] # Length of just the generated response part - -unpadded_sequence_lengths: -[7, 7] # Length of full valid sequence (input + generated response) - -logprobs (right-padded with zeros): -[ - [0.0, 0.0, 0.0, -1.2, -0.8, -2.1, -1.5, 0.0], # First 3 are 0 (input tokens), next 4 are actual logprobs - [0.0, 0.0, 0.0, 0.0, 0.0, -0.9, -1.7, 0.0], # First 5 are 0 (input tokens), next 2 are actual logprobs -] - -truncated: -[False, True] # Example 2 was truncated (hit max_tokens limit without EOS) -``` - - - -All functions receiving or returning GenerationOutputSpec should ensure -right padding is maintained. Use verify_right_padding() to check. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.models.generation.interfaces.OptionalResourcesConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - -```python -class nemo_rl.models.generation.interfaces.ResourcesConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - -```python -nemo_rl.models.generation.interfaces.verify_right_padding( - data: typing.Union[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], - pad_value: int = 0, - raise_error: bool = True -) -> tuple[bool, typing.Union[str, None]] -``` - - - - - - -Verify that a tensor is right-padded according to the provided lengths. - -**Parameters:** - - -The BatchedDataDict to check, containing either: -- For GenerationDatumSpec: input_ids and input_lengths -- For GenerationOutputSpec: output_ids and unpadded_sequence_lengths - - - -The expected padding value (default: 0) - - - -Whether to raise an error if wrong padding is detected - - -**Returns:** `bool` - -Tuple of (is_right_padded, error_message) - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang.mdx deleted file mode 100644 index 78b60ba..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang.mdx +++ /dev/null @@ -1,33 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/generation/sglang -title: nemo_rl.models.generation.sglang ---- - -## Submodules - -- **[`nemo_rl.models.generation.sglang.config`](/nemo-rl/nemo_rl/models/generation/sglang/config)** -- **[`nemo_rl.models.generation.sglang.sglang_copied_utils`](/nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils)** -- **[`nemo_rl.models.generation.sglang.sglang_generation`](/nemo-rl/nemo_rl/models/generation/sglang/sglang_generation)** -- **[`nemo_rl.models.generation.sglang.sglang_worker`](/nemo-rl/nemo_rl/models/generation/sglang/sglang_worker)** -- **[`nemo_rl.models.generation.sglang.utils`](/nemo-rl/nemo_rl/models/generation/sglang/utils)** - -## Package Contents - -### Data - -[`__all__`](#nemo_rl-models-generation-sglang-__all__) - -### API - - - - - -```python -nemo_rl.models.generation.sglang.__all__ = ['SGLangConfig', 'SGLangGeneration'] -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/config.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/config.mdx deleted file mode 100644 index f86bc80..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/config.mdx +++ /dev/null @@ -1,299 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/generation/sglang/config -title: nemo_rl.models.generation.sglang.config ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`SGLangConfig`](#nemo_rl-models-generation-sglang-config-SGLangConfig) | Configuration for SGLang runtime. | -| [`SglangSpecificArgs`](#nemo_rl-models-generation-sglang-config-SglangSpecificArgs) | SGLang-specific configuration arguments. | - -### API - - - - - -```python -class nemo_rl.models.generation.sglang.config.SGLangConfig() -``` - - - - - - -**Bases:** [GenerationConfig](/nemo-rl/nemo_rl/models/generation/interfaces#nemo_rl-models-generation-interfaces-GenerationConfig) - -Configuration for SGLang runtime. - - - - - - - - - - - - - -```python -class nemo_rl.models.generation.sglang.config.SglangSpecificArgs -``` - - - - - - -**Bases:** `typing.TypedDict` - -SGLang-specific configuration arguments. - -Most fields below map directly to SGLang's ServerArgs (see: -https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py). - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils.mdx deleted file mode 100644 index c940dea..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils.mdx +++ /dev/null @@ -1,307 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils -title: nemo_rl.models.generation.sglang.sglang_copied_utils ---- - -Standalone utility functions copied from the SGLang project. - -This module contains utility functions that were originally part of the SGLang -repository (https://github.com/sgl-project/sglang). They have been copied here -to avoid requiring sglang as a runtime dependency for weight refitting functionality. - -IMPORTANT: This module should NOT contain any imports from the sglang package. -All functions are standalone and self-contained. - -Each function includes a permalink to its original source in the SGLang repository. -These functions were copied from sglang version 0.5.2. - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`MultiprocessingSerializer`](#nemo_rl-models-generation-sglang-sglang_copied_utils-MultiprocessingSerializer) | Serialize/deserialize Python objects using ForkingPickler for IPC. | - -### Functions - -| Name | Description | -|------|-------------| -| [`_device_from_maybe_uuid`](#nemo_rl-models-generation-sglang-sglang_copied_utils-_device_from_maybe_uuid) | Convert a device UUID string or index to a device index. | -| [`_device_to_uuid`](#nemo_rl-models-generation-sglang-sglang_copied_utils-_device_to_uuid) | Convert a device index to its UUID string. | -| [`_modify_tuple`](#nemo_rl-models-generation-sglang-sglang_copied_utils-_modify_tuple) | Create a new tuple with one element modified by a function. | -| [`_rebuild_cuda_tensor_modified`](#nemo_rl-models-generation-sglang-sglang_copied_utils-_rebuild_cuda_tensor_modified) | Modified rebuild_cuda_tensor that accepts GPU UUID or device index. | -| [`_reduce_tensor_modified`](#nemo_rl-models-generation-sglang-sglang_copied_utils-_reduce_tensor_modified) | Modified reduce_tensor that stores GPU UUID instead of device index. | -| [`monkey_patch_torch_reductions`](#nemo_rl-models-generation-sglang-sglang_copied_utils-monkey_patch_torch_reductions) | Monkey patch torch multiprocessing reductions to use GPU UUIDs. | - -### Data - -[`_REDUCE_TENSOR_ARG_DEVICE_INDEX`](#nemo_rl-models-generation-sglang-sglang_copied_utils-_REDUCE_TENSOR_ARG_DEVICE_INDEX) - -### API - - - - - -```python -class nemo_rl.models.generation.sglang.sglang_copied_utils.MultiprocessingSerializer() -``` - - - - - - -Serialize/deserialize Python objects using ForkingPickler for IPC. - -This class enables serialization of objects (including CUDA tensors with IPC -handles) for transfer between processes via HTTP or other mechanisms. - -Original source (sglang v0.5.2): -https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/utils.py#L589-L623 - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_copied_utils.MultiprocessingSerializer.deserialize( - data -) -``` - - - - - - -staticmethod - -Deserialize a previously serialized object. - -**Parameters:** - - -The serialized data, optionally base64-encoded. - - -**Returns:** - -The deserialized Python object. - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_copied_utils.MultiprocessingSerializer.serialize( - obj, - output_str: bool = False -) -``` - - - - - - -staticmethod - -Serialize a Python object using ForkingPickler. - -**Parameters:** - - -The object to serialize. - - - -If True, return a base64-encoded string instead of raw bytes. - - -**Returns:** - -bytes or str: The serialized object. - - - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_copied_utils._device_from_maybe_uuid( - device_maybe_uuid: typing.Union[int, str] -) -> int -``` - - - - - - -Convert a device UUID string or index to a device index. - -Original source (sglang v0.5.2): -https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L55-L65 - -**Parameters:** - - -Either an integer device index or a UUID string. - - -**Returns:** `int` - -The integer device index. - -**Raises:** - -- `Exception`: If the UUID doesn't match any available device. - - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_copied_utils._device_to_uuid( - device: int -) -> str -``` - - - - - - -Convert a device index to its UUID string. - -Original source (sglang v0.5.2): -https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L51-L52 - - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_copied_utils._modify_tuple( - t, - index: int, - modifier: typing.Callable -) -``` - - - - - - -Create a new tuple with one element modified by a function. - -Original source (sglang v0.5.2): -https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L68-L69 - - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_copied_utils._rebuild_cuda_tensor_modified( - args = () -) -``` - - - - - - -Modified rebuild_cuda_tensor that accepts GPU UUID or device index. - -Original source (sglang v0.5.2): -https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L46-L48 - - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_copied_utils._reduce_tensor_modified( - args = (), - kwargs = {} -) -``` - - - - - - -Modified reduce_tensor that stores GPU UUID instead of device index. - -Original source (sglang v0.5.2): -https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L39-L43 - - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_copied_utils.monkey_patch_torch_reductions() -``` - - - - - - -Monkey patch torch multiprocessing reductions to use GPU UUIDs. - -This patch modifies PyTorch's CUDA tensor IPC mechanism to use GPU UUIDs -instead of device indices. This enables proper weight transfer between -processes that may have different CUDA_VISIBLE_DEVICES configurations. - -The patch is idempotent - calling it multiple times is safe. - -This is a workaround before PyTorch https://github.com/pytorch/pytorch/pull/149248 -is merged and released. - -Original source (sglang v0.5.2): -https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L20-L33 - - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_copied_utils._REDUCE_TENSOR_ARG_DEVICE_INDEX = 6 -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_generation.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_generation.mdx deleted file mode 100644 index c8393bd..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_generation.mdx +++ /dev/null @@ -1,369 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_generation -title: nemo_rl.models.generation.sglang.sglang_generation ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`SGLangGeneration`](#nemo_rl-models-generation-sglang-sglang_generation-SGLangGeneration) | - | - -### Data - -[`TOP_K_THRESHOLD`](#nemo_rl-models-generation-sglang-sglang_generation-TOP_K_THRESHOLD) - -[`TOP_P_THRESHOLD`](#nemo_rl-models-generation-sglang-sglang_generation-TOP_P_THRESHOLD) - -[`logger`](#nemo_rl-models-generation-sglang-sglang_generation-logger) - -### API - - - - - -```python -class nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration( - cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster, - config: nemo_rl.models.generation.sglang.config.SGLangConfig, - name_prefix: str = 'sglang_policy', - workers_per_node: typing.Optional[typing.Union[int, list[int]]] = None -) -``` - - - - - - -**Bases:** [GenerationInterface](/nemo-rl/nemo_rl/models/generation/interfaces#nemo_rl-models-generation-interfaces-GenerationInterface) - - - - - - - - - - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.__del__() -> None -``` - - - - - - -Shuts down the worker groups when the object is deleted or is garbage collected. - -This is an extra safety net in case the user forgets to call shutdown() and the pointer to -the object is lost due to leaving a function scope. It's always recommended that the -user calls shutdown(). - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration._allocate_bundles_for_servers( - cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster, - num_servers: int, - gpus_per_server: int -) -> list[tuple[int, list[int]]] -``` - - - - - - -Allocate GPU bundles to each SGLang server. - -Each server gets consecutive bundles within the same placement group (node). -Ray will automatically set CUDA_VISIBLE_DEVICES so each server sees logical GPUs 0, 1, 2, ..., gpus_per_server-1. - -**Parameters:** - - -The Ray virtual cluster - - - -Total number of SGLang servers to create - - - -Number of GPUs each server needs - - -**Returns:** `list[tuple[int, list[int]]]` - -List of (node_idx, [bundle_indices]) tuples for each server - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.finish_generation( - args: typing.Any = (), - kwargs: typing.Any = {} -) -> bool -``` - - - - - - -Sleep workers and reset prefix cache. - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.generate( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - greedy: bool = False -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] -``` - - - - - - -Generate a batch of data using SGLang. - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.get_sglang_server_urls() -> list[str] -``` - - - - - - -Get base URLs of all SGLang servers. - -**Returns:** `list[str]` - -List of base URLs (e.g., ["http://localhost:30000", "http://localhost:30001"]) - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.get_sglang_url_to_gpu_uuids() -> dict[str, list[str]] -``` - - - - - - -Get mapping from SGLang server URL to list of GPU UUIDs it uses. - -**Returns:** `dict[str, list[str]]` - -Dict mapping server URL to list of GPU UUIDs - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.init_collective( - ip: str, - port: int, - world_size: int, - train_world_size: int -) -> list[ray.ObjectRef] -``` - - - - - - -Initialize the collective communication. - -TODO: if weight updates via NCCL are needed in the future. - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.invalidate_kv_cache() -> bool -``` - - - - - - -Invalidate KV cache before weight updates (Megatron-style). - -This flushes the cache before weight updates to clear stale cache. -Only primary workers (TP rank 0, model owners) will flush their cache. - -**Returns:** `bool` - -True if all caches were flushed successfully, False otherwise - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.prepare_for_generation( - args: typing.Any = (), - kwargs: typing.Any = {} -) -> bool -``` - - - - - - -Wake workers up for colocated inference. - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.prepare_refit_info( - state_dict_info: dict[str, typing.Any] -) -> None -``` - - - - - - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.shutdown() -> bool -``` - - - - - - -Shut down all SGLang workers and clean up resources. - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.update_weights_from_collective() -> list[ray.ObjectRef] -``` - - - - - - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.update_weights_via_ipc_zmq() -> list[ray.ObjectRef] -``` - - - - - - - - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_generation.TOP_K_THRESHOLD = 8000 -``` - - - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_generation.TOP_P_THRESHOLD = 0.99 -``` - - - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_generation.logger = logging.getLogger(__name__) -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_worker.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_worker.mdx deleted file mode 100644 index a74da3c..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_worker.mdx +++ /dev/null @@ -1,529 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_worker -title: nemo_rl.models.generation.sglang.sglang_worker ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`SGLangGenerationWorker`](#nemo_rl-models-generation-sglang-sglang_worker-SGLangGenerationWorker) | - | - -### Functions - -| Name | Description | -|------|-------------| -| [`_require_sglang`](#nemo_rl-models-generation-sglang-sglang_worker-_require_sglang) | Import `sglang` lazily so test collection works without the optional extra. | - -### Data - -[`logger`](#nemo_rl-models-generation-sglang-sglang_worker-logger) - -### API - - - - - -```python -class nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker( - config: nemo_rl.models.generation.sglang.config.SGLangConfig, - bundle_indices: typing.Optional[list[int]] = None, - fraction_of_gpus: float = 1.0, - seed: typing.Optional[int] = None -) -``` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.__repr__() -> str -``` - - - - - - -Customizes the actor's prefix in the Ray logs. - -This makes it easier to identify which worker is producing specific log messages. - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._build_sampling_params( - greedy: bool, - stop_strings, - max_new_tokens: typing.Optional[int] = None, - input_len: typing.Optional[int] = None, - context_length: typing.Optional[int] = None, - sample_index: typing.Optional[int] = None -) -> dict[str, typing.Any] -``` - - - - - - -Build sampling parameters dictionary for SGLang API. - -**Parameters:** - - -Whether to use greedy decoding (temperature=0.0) - - - -Merged stop strings (not used here, handled per sample) - - - -Override max_new_tokens from config if provided - - - -Input length for this sample (used for context_length adjustment) - - - -Maximum context length (if provided, adjusts max_new_tokens) - - - -Sample index (used for warning messages, 0-indexed) - - -**Returns:** `dict[str, Any]` - -Dictionary of sampling parameters compatible with SGLang API - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._ensure_session() -``` - - - - - - -async - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._generate_async( - tasks -) -``` - - - - - - -async - -Execute generation tasks with concurrency control. - -TEMP: Uses a semaphore to limit the number of concurrent requests per server, preventing server overload. -A router based solution is preffered in the future. - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._generate_single_sample( - input_ids: list[int], - sampling_params: dict[str, typing.Any], - stop_string: typing.Optional[str] = None -) -> tuple[list[int], list[float]] -``` - - - - - - -async - -Generate a single sample using SGLang API (async function). - -**Parameters:** - - -List of input token IDs (without padding) - - - -Dictionary of sampling parameters (temperature, top_p, max_new_tokens, etc.) - - - -Optional stop string for this sample - - -**Returns:** `tuple[list[int], list[float]]` - -Tuple of (generated_tokens, logprobs): -- generated_tokens: List of generated token IDs -- logprobs: List of log probabilities for generated tokens - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._launch_server_process( - server_args: typing.Any -) -> multiprocessing.Process -``` - - - - - - -Launch the SGLang server process and wait for it to be ready. - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._make_request( - endpoint: str, - payload: typing.Optional[dict] = None -) -``` - - - - - - -Make a POST request to the specified endpoint with the given payload. - -**Parameters:** - - -The API endpoint to call - - - -The JSON payload to send (default: empty dict) - - -**Returns:** - -The JSON response from the server - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._merge_stop_strings( - batch_stop_strings -) -``` - - - - - - -Merge stop strings from config and batch. - -**Parameters:** - - -List of stop strings from batch (one per sample) - - -**Returns:** - -List of merged stop strings (one per sample) - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.configure_worker( - num_gpus: int | float, - bundle_indices: typing.Optional[tuple[int, list[int]]] = None -) -> tuple[dict[str, typing.Any], dict[str, str], dict[str, typing.Any]] -``` - - - - - - -staticmethod - -Provides complete worker configuration for SGLang server. - -This method configures the worker based on bundle_indices which tells us -how many GPUs this server should use. - -**Parameters:** - - -Original GPU allocation for this worker based on the placement group - - - -Tuple of (node_idx, local_bundle_indices) for this server - - -**Returns:** `tuple[dict[str, Any], dict[str, str], dict[str, Any]]` - -tuple with complete worker configuration: -- 'resources': Resource allocation (e.g., num_gpus) -- 'env_vars': Environment variables for this worker -- 'init_kwargs': Parameters to pass to __init__ of the worker - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.generate( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - greedy: bool = False -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] -``` - - - - - - -Generate a batch of data using SGLang generation. - -**Parameters:** - - -BatchedDataDict containing input_ids and input_lengths tensors - - - -Whether to use greedy decoding instead of sampling - - -**Returns:** `BatchedDataDict[GenerationOutputSpec]` - -BatchedDataDict conforming to GenerationOutputSpec: -- output_ids: input + generated token IDs with proper padding -- logprobs: Log probabilities for tokens -- generation_lengths: Lengths of each response -- unpadded_sequence_lengths: Lengths of each input + generated sequence - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.get_base_url() -> str -``` - - - - - - -Get the base URL of this SGLang server. - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.get_gpu_uuids() -> list[str] -``` - - - - - - -Get list of GPU UUIDs used by this SGLang server. - -**Returns:** `list[str]` - -List of GPU UUIDs (e.g., ["GPU-xxxxx", "GPU-yyyyy"]) - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.invalidate_kv_cache() -> bool -``` - - - - - - -Invalidate KV cache before weight updates (Megatron-style). - -This flushes the cache before weight updates to clear stale cache. -Uses retry logic to handle cases where there are pending requests. - -**Returns:** `bool` - -True if flush was successful, False otherwise - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.shutdown() -> bool -``` - - - - - - -Shutdown the SGLang server process and cleanup async resources. - -**Returns:** `bool` - -True if shutdown was successful, False otherwise - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.sleep() -``` - - - - - - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.wake_up( - kwargs = {} -) -``` - - - - - - - - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_worker._require_sglang() -``` - - - - - - -Import `sglang` lazily so test collection works without the optional extra. - - - - - - - - -```python -nemo_rl.models.generation.sglang.sglang_worker.logger = logging.getLogger(__name__) -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/utils.mdx deleted file mode 100644 index ace8dcd..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/utils.mdx +++ /dev/null @@ -1,109 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/generation/sglang/utils -title: nemo_rl.models.generation.sglang.utils ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`AsyncLoopThread`](#nemo_rl-models-generation-sglang-utils-AsyncLoopThread) | A background event loop thread for running async operations in Ray actors. | - -### API - - - - - -```python -class nemo_rl.models.generation.sglang.utils.AsyncLoopThread() -``` - - - - - - -A background event loop thread for running async operations in Ray actors. - -This class creates a dedicated thread with its own event loop, allowing -synchronous Ray actor methods to execute async coroutines without blocking -the main actor thread. This is necessary because run_coroutine_threadsafe -requires the event loop to be in a different thread. - - - - - - - - - - - - - - -```python -nemo_rl.models.generation.sglang.utils.AsyncLoopThread._start_loop() -``` - - - - - - -Run the event loop in the background thread. - - - - - - - -```python -nemo_rl.models.generation.sglang.utils.AsyncLoopThread.run( - coro -) -``` - - - - - - -Schedule a coroutine onto the loop and block until it's done. - -**Parameters:** - - -The coroutine to execute - - -**Returns:** - -The result of the coroutine - - - - - - - -```python -nemo_rl.models.generation.sglang.utils.AsyncLoopThread.shutdown() -``` - - - - - - -Shutdown the event loop and wait for the thread to finish. - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm.mdx deleted file mode 100644 index c5f278e..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm.mdx +++ /dev/null @@ -1,34 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/generation/vllm -title: nemo_rl.models.generation.vllm ---- - -## Submodules - -- **[`nemo_rl.models.generation.vllm.config`](/nemo-rl/nemo_rl/models/generation/vllm/config)** -- **[`nemo_rl.models.generation.vllm.utils`](/nemo-rl/nemo_rl/models/generation/vllm/utils)** -- **[`nemo_rl.models.generation.vllm.vllm_backend`](/nemo-rl/nemo_rl/models/generation/vllm/vllm_backend)** -- **[`nemo_rl.models.generation.vllm.vllm_generation`](/nemo-rl/nemo_rl/models/generation/vllm/vllm_generation)** -- **[`nemo_rl.models.generation.vllm.vllm_worker`](/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker)** -- **[`nemo_rl.models.generation.vllm.vllm_worker_async`](/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async)** - -## Package Contents - -### Data - -[`__all__`](#nemo_rl-models-generation-vllm-__all__) - -### API - - - - - -```python -nemo_rl.models.generation.vllm.__all__ = ['VllmConfig', 'VllmGeneration', 'VllmGenerationWorker', 'VllmAsyncGenerationWor... -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/config.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/config.mdx deleted file mode 100644 index 6ca4574..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/config.mdx +++ /dev/null @@ -1,111 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/generation/vllm/config -title: nemo_rl.models.generation.vllm.config ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`VllmConfig`](#nemo_rl-models-generation-vllm-config-VllmConfig) | - | -| [`VllmSpecificArgs`](#nemo_rl-models-generation-vllm-config-VllmSpecificArgs) | - | - -### API - - - - - -```python -class nemo_rl.models.generation.vllm.config.VllmConfig() -``` - - - - - - -**Bases:** [GenerationConfig](/nemo-rl/nemo_rl/models/generation/interfaces#nemo_rl-models-generation-interfaces-GenerationConfig) - - - - - - - - - - - - -```python -class nemo_rl.models.generation.vllm.config.VllmSpecificArgs -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/utils.mdx deleted file mode 100644 index 5bfcfad..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/utils.mdx +++ /dev/null @@ -1,113 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/generation/vllm/utils -title: nemo_rl.models.generation.vllm.utils ---- - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`aggregate_spec_decode_counters`](#nemo_rl-models-generation-vllm-utils-aggregate_spec_decode_counters) | Aggregate speculative decoding counters from multiple workers. | -| [`compute_spec_decode_metrics`](#nemo_rl-models-generation-vllm-utils-compute_spec_decode_metrics) | Compute delta and derived metrics for speculative decoding. | -| [`format_prompt_for_vllm_generation`](#nemo_rl-models-generation-vllm-utils-format_prompt_for_vllm_generation) | Format a list of prompts for vllm generation (which requires a specific format for its own `generate` method). | - -### API - - - - - -```python -nemo_rl.models.generation.vllm.utils.aggregate_spec_decode_counters( - worker_metrics: list[dict[str, float | list[float]]] -) -> dict[str | tuple[str, int], float] -``` - - - - - - -Aggregate speculative decoding counters from multiple workers. - -Combines spec decode metrics collected from DP leader workers into -a single aggregated counter dictionary. - -**Parameters:** - - -List of metric dictionaries from each worker. -Each dict maps metric names to float values or lists of floats -(for per-position metrics). - - -**Returns:** `dict[str | tuple[str, int], float]` - -Dictionary mapping metric names to their aggregated float values. - - - - - - - - -```python -nemo_rl.models.generation.vllm.utils.compute_spec_decode_metrics( - start_counters: dict[str | tuple[str, int], float], - end_counters: dict[str | tuple[str, int], float] -) -> dict[str, float] -``` - - - - - - -Compute delta and derived metrics for speculative decoding. - -Calculates the difference between two counter snapshots and derives -acceptance rate and acceptance length metrics for logging. - -**Parameters:** - - -Counter snapshot taken before generation. - - - -Counter snapshot taken after generation. - - -**Returns:** `dict[str, float]` - -Dictionary of metrics suitable for logging to wandb/tensorboard. - - - - - - - - -```python -nemo_rl.models.generation.vllm.utils.format_prompt_for_vllm_generation( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - sample_idx: typing.Optional[int] = None -) -> list[dict[str, typing.Any]] -``` - - - - - - -Format a list of prompts for vllm generation (which requires a specific format for its own `generate` method). - -See https://docs.vllm.ai/en/v0.9.1/features/multimodal_inputs.html for prompt format for multimodal inputs. - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_backend.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_backend.mdx deleted file mode 100644 index c06de54..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_backend.mdx +++ /dev/null @@ -1,236 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_backend -title: nemo_rl.models.generation.vllm.vllm_backend ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`VllmInternalWorkerExtension`](#nemo_rl-models-generation-vllm-vllm_backend-VllmInternalWorkerExtension) | - | - -### API - - - - - -```python -class nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension() -``` - - - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension._maybe_process_fp8_kv_cache() -> None -``` - - - - - - -Process weights after loading for FP8 KV cache (static scales). - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.cleanup() -> None -``` - - - - - - -Shutdown and cleanup resources. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.get_zmq_address() -``` - - - - - - -Get the ZMQ address for the current device. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.init_collective( - rank_prefix: int, - ip: str, - port: int, - world_size: int, - train_world_size: int -) -> None -``` - - - - - - -Initialize the collective communication. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.maybe_init_zmq() -``` - - - - - - -Initialize the ZMQ socket if it doesn't exist. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.prepare_refit_info( - state_dict_info: dict[str, typing.Any] -) -> None -``` - - - - - - -Prepare state dict metadata for weight refitting and IPC streaming. - -**Parameters:** - - -A dictionary containing the info for refit. -e.g. {tensor_name: (shape, dtype)} - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.report_device_id() -> str -``` - - - - - - -Retrieve the UUID of the current CUDA device. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.start_gpu_profiling() -> None -``` - - - - - - -Start GPU profiling. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.stop_gpu_profiling() -> None -``` - - - - - - -Stop GPU profiling. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.update_weights_from_collective() -> bool -``` - - - - - - -Update the model weights from collective communication. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.update_weights_via_ipc_zmq() -> bool -``` - - - - - - -Receive and update model weights via ZMQ IPC socket. - -**Returns:** `bool` - -True if weights were successfully updated. - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_generation.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_generation.mdx deleted file mode 100644 index e4cdee2..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_generation.mdx +++ /dev/null @@ -1,656 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_generation -title: nemo_rl.models.generation.vllm.vllm_generation ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`VllmGeneration`](#nemo_rl-models-generation-vllm-vllm_generation-VllmGeneration) | - | - -### Data - -[`TOP_K_THRESHOLD`](#nemo_rl-models-generation-vllm-vllm_generation-TOP_K_THRESHOLD) - -[`TOP_P_THRESHOLD`](#nemo_rl-models-generation-vllm-vllm_generation-TOP_P_THRESHOLD) - -### API - - - - - -```python -class nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration( - cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster, - config: nemo_rl.models.generation.vllm.config.VllmConfig, - name_prefix: str = 'vllm_policy', - workers_per_node: typing.Optional[typing.Union[int, list[int]]] = None -) -``` - - - - - - -**Bases:** [GenerationInterface](/nemo-rl/nemo_rl/models/generation/interfaces#nemo_rl-models-generation-interfaces-GenerationInterface) - - - - - - - - - - - - - - - - - - - - - - - - - - -Check if KV cache scales should be synchronized during refit. - -Returns True if kv_cache_dtype is fp8/fp8_e4m3. - - - - - - - - - - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.__del__() -> None -``` - - - - - - -Shuts down the worker groups when the object is deleted or is garbage collected. - -This is an extra safety net in case the user forgets to call shutdown() and the pointer to -the object is lost due to leaving a function scope. It's always recommended that the -user calls shutdown(). - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration._async_generate_base( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - method_name: str, - data_validation_fn, - greedy: bool = False -) -> typing.AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None] -``` - - - - - - -async - -Base async generation method that handles common worker management logic. - -**Parameters:** - - -Input data for generation - - - -Name of the worker method to call ('generate_async' or 'generate_text_async') - - - -Function to validate input data - - - -Whether to use greedy decoding - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration._get_raw_spec_counters() -> dict[str | tuple[str, int], float] -``` - - - - - - -Collect raw spec decode counters from workers. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration._get_tied_worker_bundle_indices( - cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster -) -> list[tuple[int, list[int]]] -``` - - - - - - -Calculate bundle indices for tensor and pipeline parallel workers. - -Handles both unified placement groups (for cross-node model parallelism) and -per-node placement groups (for node-local model parallelism). - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration._post_init() -``` - - - - - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration._report_device_id() -> list[list[str]] -``` - - - - - - -Report the device ID of vllm workers. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration._report_dp_openai_server_base_urls() -> list[typing.Optional[str]] -``` - - - - - - -Report the data parallel OpenAI server base URLs of vLLM workers, only populated if it is async vLLM engine and the HTTP server is active. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.clear_logger_metrics() -> None -``` - - - - - - -Clear logger metrics for performance reporting. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.clear_vllm_logger_metrics() -> None -``` - - - - - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.finish_generation( - args: typing.Any = (), - kwargs: typing.Any = {} -) -> bool -``` - - - - - - -Sleep workers and reset prefix cache. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.generate( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - greedy: bool = False -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] -``` - - - - - - -Generate a batch of data using vLLM. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.generate_async( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - greedy: bool = False -) -> typing.AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None] -``` - - - - - - -async - -Generate responses asynchronously, yielding individual samples as they complete. - -This method provides per-sample streaming across all workers, yielding each -sample result as soon as it's ready, regardless of which worker processed it. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.generate_text( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - greedy: bool = False -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] -``` - - - - - - -Generate text responses using vLLM. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.generate_text_async( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - greedy: bool = False -) -> typing.AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None] -``` - - - - - - -async - -Generate text responses asynchronously, yielding results as they are ready. - -**Parameters:** - - -BatchedDataDict containing prompts with text strings - - - -Whether to use greedy decoding instead of sampling - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.get_logger_metrics() -> dict[str, typing.Any] -``` - - - - - - -Get logger metrics for performance reporting. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.get_step_metrics() -> dict[str, float] -``` - - - - - - -Get speculative decoding metrics delta since snapshot_step_metrics(). - -**Returns:** `dict[str, float]` - -Dictionary of delta metrics with 'vllm/' prefix. - -**Raises:** - -- `RuntimeWarning`: If called without snapshot_step_metrics() first. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.get_vllm_logger_metrics() -> dict[str, typing.Any] -``` - - - - - - -Collect vLLM logger metrics from vLLM workers (model-owner actors only). - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.init_collective( - ip: str, - port: int, - world_size: int, - train_world_size: int -) -> list[ray.ObjectRef] -``` - - - - - - -Initialize the collective communication. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.invalidate_kv_cache() -> bool -``` - - - - - - -Invalidate reusable caches in vLLM (e.g., prefix/KV cache) after weight updates. - -For async_engine, calls reset_prefix_cache_async on workers. For sync, calls reset_prefix_cache. -Returns True if all workers report success. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.prepare_for_generation( - args: typing.Any = (), - kwargs: typing.Any = {} -) -> bool -``` - - - - - - -Wake workers up for colocated inference. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.prepare_refit_info( - state_dict_info: dict[str, typing.Any] -) -> None -``` - - - - - - -Prepare the info for refit. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.shutdown() -> bool -``` - - - - - - -Shut down all vLLM workers and clean up resources. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.snapshot_step_metrics() -> None -``` - - - - - - -Snapshot current spec decode counters to begin tracking a training step. - -Call this before generation to establish a baseline for metrics delta. - -**Raises:** - -- `RuntimeWarning`: If called twice without get_step_metrics() in between. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.start_gpu_profiling() -> None -``` - - - - - - -Start GPU profiling. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.stop_gpu_profiling() -> None -``` - - - - - - -Stop GPU profiling. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.update_weights_from_collective() -> list[ray.ObjectRef] -``` - - - - - - -Update weights of the policy using collective communication. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.update_weights_via_ipc_zmq() -> list[ray.ObjectRef] -``` - - - - - - -Update weights of the policy using IPC handles via ZMQ socket. - - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.TOP_K_THRESHOLD = 8000 -``` - - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_generation.TOP_P_THRESHOLD = 0.99 -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker.mdx deleted file mode 100644 index 080c9c1..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker.mdx +++ /dev/null @@ -1,545 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker -title: nemo_rl.models.generation.vllm.vllm_worker ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`BaseVllmGenerationWorker`](#nemo_rl-models-generation-vllm-vllm_worker-BaseVllmGenerationWorker) | - | -| [`VllmGenerationWorker`](#nemo_rl-models-generation-vllm-vllm_worker-VllmGenerationWorker) | - | - -### API - - - - - -```python -class nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker( - config: nemo_rl.models.generation.vllm.config.VllmConfig, - bundle_indices: typing.Optional[list[int]] = None, - fraction_of_gpus: float = 1.0, - seed: typing.Optional[int] = None -) -``` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker.__repr__() -> str -``` - - - - - - -Customizes the actor's prefix in the Ray logs. - -This makes it easier to identify which worker is producing specific log messages. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker._build_sampling_params( - greedy: bool, - stop_strings, - max_new_tokens: typing.Optional[int] = None -) -``` - - - - - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker._get_raw_spec_counters() -> dict[str, float | list[float]] -``` - - - - - - -Get speculative decoding metrics from the vLLM engine. - -Collects spec decode counters including number of drafts, -draft tokens, and accepted tokens for monitoring acceptance rates. - -**Returns:** `dict[str, float | list[float]]` - -Dictionary mapping metric names to their values. - -**Raises:** - -- `AssertionError`: If called before vLLM engine is initialized. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker._merge_stop_strings( - batch_stop_strings -) -``` - - - - - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker.configure_worker( - num_gpus: int | float, - bundle_indices: typing.Optional[tuple[int, list[int]]] = None -) -> tuple[dict[str, typing.Any], dict[str, str], dict[str, typing.Any]] -``` - - - - - - -staticmethod - -Provides complete worker configuration for vLLM tensor and pipeline parallelism. - -This method configures the worker based on its role in tensor and pipeline parallelism, -which is determined directly from the bundle_indices parameter. - -**Parameters:** - - -Original GPU allocation for this worker based on the placement group - - - -Tuple of (node_idx, local_bundle_indices) for parallelism (if applicable) - - -**Returns:** `tuple[dict[str, Any], dict[str, str], dict[str, Any]]` - -tuple with complete worker configuration: -- 'resources': Resource allocation (e.g., num_gpus) -- 'env_vars': Environment variables for this worker -- 'init_kwargs': Parameters to pass to __init__ of the worker - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker.is_alive() -``` - - - - - - -Check if the worker is alive. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker.llm() -``` - - - - - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker.start_gpu_profiling() -> None -``` - - - - - - -Start GPU profiling. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker.stop_gpu_profiling() -> None -``` - - - - - - -Stop GPU profiling. - - - - - - - - - -```python -class nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker() -``` - - - - - - -**Bases:** [BaseVllmGenerationWorker](#nemo_rl-models-generation-vllm-vllm_worker-BaseVllmGenerationWorker) - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker._create_engine( - llm_kwargs: dict[str, typing.Any] -) -> None -``` - - - - - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.generate( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - greedy: bool = False -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] -``` - - - - - - -Generate a batch of data using vLLM generation. - -**Parameters:** - - -BatchedDataDict containing input_ids and input_lengths tensors - - - -Whether to use greedy decoding instead of sampling - - -**Returns:** `BatchedDataDict[GenerationOutputSpec]` - -BatchedDataDict conforming to GenerationOutputSpec: -- output_ids: input + generated token IDs with proper padding -- logprobs: Log probabilities for tokens -- generation_lengths: Lengths of each response -- unpadded_sequence_lengths: Lengths of each input + generated sequence - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.generate_text( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - greedy: bool = False -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] -``` - - - - - - -Generate text responses using vLLM generation. - -**Parameters:** - - -BatchedDataDict containing prompts with text strings - - - -Whether to use greedy decoding instead of sampling - - -**Returns:** `BatchedDataDict[GenerationOutputSpec]` - -BatchedDataDict containing: -- texts: List of generated text responses - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.init_collective( - rank_prefix: int, - ip: str, - port: int, - world_size: int, - train_world_size: int -) -> None -``` - - - - - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.post_init() -``` - - - - - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.prepare_refit_info( - state_dict_info: dict[str, typing.Any] -) -> None -``` - - - - - - -Prepare the info for refit. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.report_device_id() -> list[str] -``` - - - - - - -Report device ID from the vLLM worker. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.reset_prefix_cache() -``` - - - - - - -Reset the prefix cache of vLLM engine. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.shutdown() -> bool -``` - - - - - - -Clean up vLLM resources. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.sleep() -``` - - - - - - -Put the vLLM engine to sleep. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.update_weights_from_collective() -> bool -``` - - - - - - -Update the model weights from collective communication. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.update_weights_via_ipc_zmq() -> bool -``` - - - - - - -Update weights from IPC handles via ZMQ socket. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.wake_up( - kwargs = {} -) -``` - - - - - - -Wake up the vLLM engine. - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async.mdx deleted file mode 100644 index b9b731a..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async.mdx +++ /dev/null @@ -1,485 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async -title: nemo_rl.models.generation.vllm.vllm_worker_async ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`VllmAsyncGenerationWorker`](#nemo_rl-models-generation-vllm-vllm_worker_async-VllmAsyncGenerationWorker) | - | - -### Functions - -| Name | Description | -|------|-------------| -| [`_replace_prefix_tokens`](#nemo_rl-models-generation-vllm-vllm_worker_async-_replace_prefix_tokens) | This is a subroutine used inside the vLLM Chat Completion server. | - -### API - - - - - -```python -class nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker() -``` - - - - - - -**Bases:** [BaseVllmGenerationWorker](/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker#nemo_rl-models-generation-vllm-vllm_worker-BaseVllmGenerationWorker) - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker._create_engine( - llm_kwargs: dict[str, typing.Any] -) -> None -``` - - - - - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker._setup_vllm_openai_api_server( - app: fastapi.FastAPI -) -> fastapi.FastAPI -``` - - - - - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker._setup_vllm_server() -> tuple[threading.Thread, str, uvicorn.Server] -``` - - - - - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker._start_vllm_metrics_logger() -> None -``` - - - - - - -Start a background thread that periodically collects vLLM logger metrics. - -Controlled by vllm_metrics_logger_interval (default: 0.5) in vllm_cfg. -Runs only on the model-owner actor. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.clear_vllm_logger_metrics() -> None -``` - - - - - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.generate_async( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - greedy: bool = False -) -> typing.AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None] -``` - - - - - - -async - -Generate a batch of data using vLLM's AsyncLLMEngine, yielding results as they are ready. - -**Parameters:** - - -BatchedDataDict with input_ids and input_lengths - - - -Whether to use greedy decoding instead of sampling - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.generate_text_async( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - greedy: bool = False -) -> typing.AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None] -``` - - - - - - -async - -Generate text responses asynchronously, yielding results as they are ready. - -**Parameters:** - - -BatchedDataDict containing prompts with text strings - - - -Whether to use greedy decoding instead of sampling - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.get_vllm_logger_metrics() -> dict[str, typing.Any] -``` - - - - - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.init_collective_async( - rank_prefix: int, - ip: str, - port: int, - world_size: int, - train_world_size: int -) -> None -``` - - - - - - -async - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.post_init_async() -``` - - - - - - -async - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.prepare_refit_info_async( - state_dict_info: dict[str, typing.Any] -) -> None -``` - - - - - - -async - -Async version of prepare_refit_info. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.report_device_id_async() -> list[str] -``` - - - - - - -async - -Async version of report_device_id. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.report_dp_openai_server_base_url() -> typing.Optional[str] -``` - - - - - - -async - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.reset_prefix_cache_async() -``` - - - - - - -async - -Async version of reset_prefix_cache. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.shutdown() -> bool -``` - - - - - - -async - -Clean up vLLM resources. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.sleep_async() -``` - - - - - - -async - -Async version of sleep. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.update_weights_from_collective_async() -> bool -``` - - - - - - -async - -Async version of update_weights_from_collective. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.update_weights_via_ipc_zmq_async() -> bool -``` - - - - - - -async - -Async version of update_weights_via_ipc_zmq. - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.wake_up_async( - kwargs = {} -) -``` - - - - - - -async - -Async version of wake_up. - - - - - - - - - -```python -nemo_rl.models.generation.vllm.vllm_worker_async._replace_prefix_tokens( - tokenizer, - model_prefix_token_ids: list[int], - template_prefix_token_ids: list[int], - template_token_ids: list[int] -) -> list[int] -``` - - - - - - -This is a subroutine used inside the vLLM Chat Completion server. - -This function is for fixing up the chat template-tokenized messages history -to match the model output tokenization up to the last assistant turn, -in order to preserve the monotonic tokens property for optimized multi-turn -training. - -Some environments (namely NeMo-Gym) require an OpenAI compatible server -endpoint rather than an inference engine handle. This is fine for the most -part, but it may cause issues when the environment is used as a part of -training. - -RL training frameworks train models on token IDs, but the OpenAI compatible -server communicates in what is basically de-tokenized text. When multiple -model calls are made to the OpenAI compatible server in a single trajectory, -model generations in previous model calls may be re-tokenized to something -that is different than what was generated. This is not too big of an issue -(that we know of) at inference time, but the log probs the model produces -are different enough for the differently re-tokenized generation result that -it causes the training to be off policy. Off policy isn't necessarily a bad -thing in isolation, but this source of off-policyness may cause unexpected -issues if not properly accounted for. It also mis-aligns the token ID -sequences across model calls, which feels very strange during training. - -There are real cases where the model output string _does not match_ the chat -template tokenization of the parsed model output. A concrete example is -inconsistent whitespace tokens around tool call special tokens. - -TODO When NeMo RL supports training image generation models, we want to -revisit and possibly update this function. This issue occurs when the model -generates tokens that are de-tokenized into text or images, and then -re-tokenized into tokens. So if there is a situation like that with images -and image tokenization is non-unique, then we will need to uppdate this -function. - -Example (turn-by-turn, concise; eos_token_id = 2): - Turn 1: - - prefill_T1 (template prefill) = [11,12,13,40,41] - - model output = [220,17,2] # decodes to " 4" + EOS - - model_prefix_token_ids = prefill_T1 + model output - => [11,12,13,40,41,220,17,2] - - Turn 2 (template retokenizes prior assistant text differently): - - template_prefix_token_ids = [11,12,13,40,41,1001,2] # 1001 decodes to " 4" - - template_token_ids = [11,12,13,40,41,1001,2,21,22,40,41] - - _replace_prefix_tokens keeps the exact prior model tokens up to EOS and - resumes from the template after that EOS: - output => [11,12,13,40,41,220,17,2,21,22,40,41] - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface.mdx deleted file mode 100644 index 6095398..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface.mdx +++ /dev/null @@ -1,9 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/huggingface -title: nemo_rl.models.huggingface ---- - -## Submodules - -- **[`nemo_rl.models.huggingface.common`](/nemo-rl/nemo_rl/models/huggingface/common)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface/common.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface/common.mdx deleted file mode 100644 index f626968..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface/common.mdx +++ /dev/null @@ -1,303 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/huggingface/common -title: nemo_rl.models.huggingface.common ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`FlashAttentionKwargs`](#nemo_rl-models-huggingface-common-FlashAttentionKwargs) | Dataclass to hold FlashAttention v2 kwargs. | -| [`ModelFlag`](#nemo_rl-models-huggingface-common-ModelFlag) | Enum that defines special flags for model-specific behaviors. | - -### Functions - -| Name | Description | -|------|-------------| -| [`get_flash_attention_kwargs`](#nemo_rl-models-huggingface-common-get_flash_attention_kwargs) | Returns kwargs required for FlashAttention v2 forward functions. | -| [`group_and_cat_tensors`](#nemo_rl-models-huggingface-common-group_and_cat_tensors) | Groups and concatenates tensors according to group_sizes, then pads them to form a 2D tensor. | -| [`is_gemma_model`](#nemo_rl-models-huggingface-common-is_gemma_model) | - | -| [`pack_sequences`](#nemo_rl-models-huggingface-common-pack_sequences) | Packs sequences into rows where each row concatenates multiple sequences. | -| [`unpack_tensor`](#nemo_rl-models-huggingface-common-unpack_tensor) | Unpacks a packed tensor into individual sequences padded to the same length. | - -### Data - -[`Tensor`](#nemo_rl-models-huggingface-common-Tensor) - -### API - - - - - -```python -class nemo_rl.models.huggingface.common.FlashAttentionKwargs( - cu_seqlens_q: nemo_rl.models.huggingface.common.Tensor, - cu_seqlens_k: nemo_rl.models.huggingface.common.Tensor, - max_seqlen_q: int, - max_seqlen_k: int -) -``` - - - - - - -Dataclass - -Dataclass to hold FlashAttention v2 kwargs. - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.models.huggingface.common.ModelFlag -``` - - - - - - -**Bases:** `enum.Enum` - -Enum that defines special flags for model-specific behaviors. - -This enum provides a way to identify models that require special handling or -configuration in different parts of the NeMo RL codebase. - -Each flag has a `matches` method that determines if the flag applies to a given model_name. - - - - - - - - - - -```python -nemo_rl.models.huggingface.common.get_flash_attention_kwargs( - input_lengths: torch.Tensor -) -> nemo_rl.models.huggingface.common.FlashAttentionKwargs -``` - - - - - - -Returns kwargs required for FlashAttention v2 forward functions. - -**Parameters:** - - -[batch_size] containing lengths of each sequence - - -**Returns:** `FlashAttentionKwargs` - -Dict[str, torch.Tensor | int]: -{ - "cu_seqlens_q": Tensor[int32], - "cu_seqlens_k": Tensor[int32], - "max_seqlen_q": int, - "max_seqlen_k": int -} - - - - - - - - -```python -nemo_rl.models.huggingface.common.group_and_cat_tensors( - tensors: list[torch.Tensor], - group_sizes: list[int], - padding_value: int = 0, - min_seq_len: int = 0 -) -> torch.Tensor -``` - - - - - - -Groups and concatenates tensors according to group_sizes, then pads them to form a 2D tensor. - -Each group of 1D tensors is concatenated into a single 1D tensor, and all resulting -group tensors are padded to the same length and stacked into a 2D tensor. - -**Parameters:** - - -List of 1D tensors of varying lengths. - - - -List of integers. Each integer specifies how many tensors to group. - - - -Integer used to pad shorter sequences. - - - -Minimum sequence length. - - -**Returns:** `torch.Tensor` - -A 2D tensor where each row is a padded concatenation of the grouped tensors. - - - - - - - - -```python -nemo_rl.models.huggingface.common.is_gemma_model( - model_name: str -) -> bool -``` - - - - - - - - - - - - - -```python -nemo_rl.models.huggingface.common.pack_sequences( - input_ids: torch.Tensor, - input_lengths: torch.Tensor, - packed_sequence_size: list[int], - padding_value: int = 0, - return_attention_mask: bool = True, - min_seq_len: int = 0 -) -> typing.Tuple[torch.Tensor, torch.Tensor, typing.Optional[torch.Tensor]] -``` - - - - - - -Packs sequences into rows where each row concatenates multiple sequences. - -Useful for sequence packing in transformer models (e.g. for SFT training). Returns: -packed input_ids, packed position_ids, and optional attention_mask. - -**Parameters:** - - -Tensor of shape [num_sequences, max_seq_len] - - - -Tensor of shape [num_sequences], containing true lengths - - - -How many sequences to pack per row - - - -Pad value for input_ids - - - -Whether to return per-row causal attention mask - - - -Minimum sequence length. - - -**Returns:** `Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]` - - -input_ids_packed (torch.Tensor): [batch_size, max_packed_seq_len] -position_ids_packed (torch.Tensor): [batch_size, max_packed_seq_len] -attention_mask (Optional[torch.Tensor]): [batch_size, max_len, max_len] if requested - - - - - - - - -```python -nemo_rl.models.huggingface.common.unpack_tensor( - tensor, - input_lengths -) -``` - - - - - - -Unpacks a packed tensor into individual sequences padded to the same length. - -**Parameters:** - - -Packed tensor of shape [batch_size, packed_seq_len]. - - - -Original sequence lengths in the order they were packed. - - -**Returns:** - -torch.Tensor: [num_sequences, max_seq_len], each row is one unpacked and padded sequence. - - - - - - - - -```python -nemo_rl.models.huggingface.common.Tensor = TypeVar('Tensor', bound=(torch.Tensor)) -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron.mdx deleted file mode 100644 index c37e90e..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron.mdx +++ /dev/null @@ -1,13 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/megatron -title: nemo_rl.models.megatron ---- - -## Submodules - -- **[`nemo_rl.models.megatron.common`](/nemo-rl/nemo_rl/models/megatron/common)** -- **[`nemo_rl.models.megatron.community_import`](/nemo-rl/nemo_rl/models/megatron/community_import)** -- **[`nemo_rl.models.megatron.config`](/nemo-rl/nemo_rl/models/megatron/config)** -- **[`nemo_rl.models.megatron.data`](/nemo-rl/nemo_rl/models/megatron/data)** -- **[`nemo_rl.models.megatron.setup`](/nemo-rl/nemo_rl/models/megatron/setup)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/common.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/common.mdx deleted file mode 100644 index de40812..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/common.mdx +++ /dev/null @@ -1,212 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/megatron/common -title: nemo_rl.models.megatron.common ---- - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`_round_up_to_multiple`](#nemo_rl-models-megatron-common-_round_up_to_multiple) | - | -| [`broadcast_tensor`](#nemo_rl-models-megatron-common-broadcast_tensor) | Broadcasts a tensor from src_rank to all ranks in the group using broadcast_object_list for metadata. | -| [`forward_step_arbitrary_loss`](#nemo_rl-models-megatron-common-forward_step_arbitrary_loss) | Forward training step with support for packed sequences and context parallelism. | -| [`get_moe_metrics`](#nemo_rl-models-megatron-common-get_moe_metrics) | Returns Mixture of Experts (MoE) auxiliary-loss metrics. | - -### API - - - - - -```python -nemo_rl.models.megatron.common._round_up_to_multiple( - value: int, - multiple: int -) -> int -``` - - - - - - - - - - - - - -```python -nemo_rl.models.megatron.common.broadcast_tensor( - tensor: torch.Tensor | None, - src_rank: int, - group: torch.distributed.ProcessGroup -) -> torch.Tensor -``` - - - - - - -Broadcasts a tensor from src_rank to all ranks in the group using broadcast_object_list for metadata. - -Handles the case where the input tensor might be None on non-source ranks. -If the input tensor is provided on non-source ranks, it must have the -correct shape and dtype matching the tensor on the source rank. - -**Parameters:** - - -The tensor to broadcast on the source rank. Can be None on - non-source ranks (will be created with correct shape/dtype). - If not None on non-source ranks, it's used as the buffer - for the broadcast and must match the source tensor's metadata. - - - -The global rank of the source process. - - - -The process group for communication. - - -**Returns:** `torch.Tensor` - -torch.Tensor: The broadcasted tensor. On non-source ranks, this will - be the tensor received from the source. - -**Raises:** - -- `ValueError`: If the tensor is None on the source rank, or if a tensor - provided on a non-source rank has mismatched shape/dtype/device. -- `TypeError`: If broadcasting metadata fails (e.g., due to pickling issues). - - - - - - - - -```python -nemo_rl.models.megatron.common.forward_step_arbitrary_loss( - state: megatron.bridge.training.state.GlobalState, - global_valid_seqs: torch.Tensor, - global_valid_toks: torch.Tensor, - data_iterator: typing.Iterator[nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any]], - model: megatron.core.models.gpt.GPTModel, - loss_fn: nemo_rl.algorithms.loss_functions.LossFunction, - pack_sequences: bool = False, - defer_fp32_logits: typing.Optional[bool] = None, - cp_normalize: bool = True, - policy_cfg: typing.Optional[dict] = None -) -``` - - - - - - -Forward training step with support for packed sequences and context parallelism. - -Notes on packed sequences with context parallelism (CP): - - When CP > 1, each sequence is padded to a multiple of (cp_size * 2) - - The factor of 2 ensures load balancing for causal attention - - cu_seqlens tracks actual sequence boundaries - - cu_seqlens_padded tracks padded sequence boundaries for CP - - Requires TransformerEngine >= 1.10 for CP support - -**Parameters:** - - -Global state for the run - - - -Global count of valid sequences - - - -Global count of valid tokens - - - -Input data iterator - - - -The GPT Model - - - -Loss function to apply - - - -Whether to pack sequences for efficiency - - - -Whether to skip the conversion of logits to fp32 - - - -Whether to normalize the loss by the cp_size - - - -Policy configuration containing generation parameters - - - - - - - - - -```python -nemo_rl.models.megatron.common.get_moe_metrics( - loss_scale: float, - total_loss_dict: typing.Optional[dict] = None, - per_layer_logging: bool = False -) -> dict[str, typing.Any] -``` - - - - - - -Returns Mixture of Experts (MoE) auxiliary-loss metrics. - -This function reduces MoE auxiliary losses across ranks, aggregates them, and -returns a dictionary of metrics. - -**Parameters:** - - -Scale factor to apply to each auxiliary loss (e.g., 1/num_microbatches). - - - -If provided, accumulate means into this dict (by name). - - - -If True, include per-layer values in the returned dict. - - -**Returns:** `dict[str, Any]` - -dict[str, Any]: A flat dict of aggregated metrics. For each aux loss name, - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/community_import.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/community_import.mdx deleted file mode 100644 index a0f53a4..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/community_import.mdx +++ /dev/null @@ -1,76 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/megatron/community_import -title: nemo_rl.models.megatron.community_import ---- - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`export_model_from_megatron`](#nemo_rl-models-megatron-community_import-export_model_from_megatron) | - | -| [`import_model_from_hf_name`](#nemo_rl-models-megatron-community_import-import_model_from_hf_name) | Import a Hugging Face model into Megatron checkpoint format and save the Megatron checkpoint to the output path. | - -### API - - - - - -```python -nemo_rl.models.megatron.community_import.export_model_from_megatron( - hf_model_name: str, - input_path: str, - output_path: str, - hf_tokenizer_path: str, - overwrite: bool = False, - hf_overrides: typing.Optional[dict[str, typing.Any]] = {} -) -``` - - - - - - - - - - - - - -```python -nemo_rl.models.megatron.community_import.import_model_from_hf_name( - hf_model_name: str, - output_path: str, - megatron_config: typing.Optional[nemo_rl.models.policy.MegatronConfig] = None, - config_overrides: typing.Any = {} -) -``` - - - - - - -Import a Hugging Face model into Megatron checkpoint format and save the Megatron checkpoint to the output path. - -**Parameters:** - - -Hugging Face model ID or local path (e.g., 'meta-llama/Llama-3.1-8B-Instruct'). - - - -Directory to write the Megatron checkpoint (e.g., /tmp/megatron_ckpt). - - - -Optional megatron config with paralellism settings for distributed megatron model import. - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/config.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/config.mdx deleted file mode 100644 index 7dda8b0..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/config.mdx +++ /dev/null @@ -1,146 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/megatron/config -title: nemo_rl.models.megatron.config ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`MegatronGenerationConfig`](#nemo_rl-models-megatron-config-MegatronGenerationConfig) | - | -| [`ModelAndOptimizerState`](#nemo_rl-models-megatron-config-ModelAndOptimizerState) | Container for model and optimizer state. | -| [`RuntimeConfig`](#nemo_rl-models-megatron-config-RuntimeConfig) | Runtime configuration for model training and inference. | - -### API - - - - - -```python -class nemo_rl.models.megatron.config.MegatronGenerationConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.models.megatron.config.ModelAndOptimizerState() -``` - - - - - - -**Bases:** `NamedTuple` - -Container for model and optimizer state. - -This named tuple holds all model-related state including the model itself, -optimizer, scheduler, and metadata about the model type and configuration. - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.models.megatron.config.RuntimeConfig() -``` - - - - - - -**Bases:** `NamedTuple` - -Runtime configuration for model training and inference. - -This contains all validated runtime settings needed for model initialization, -parallelization, and training. - - - - - - - - - - - - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/data.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/data.mdx deleted file mode 100644 index 30fa12b..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/data.mdx +++ /dev/null @@ -1,471 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/megatron/data -title: nemo_rl.models.megatron.data ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`ProcessedInputs`](#nemo_rl-models-megatron-data-ProcessedInputs) | Processed microbatch inputs used for model forward pass. | -| [`ProcessedMicrobatch`](#nemo_rl-models-megatron-data-ProcessedMicrobatch) | Container for a processed microbatch ready for model forward pass. | - -### Functions - -| Name | Description | -|------|-------------| -| [`_get_pack_sequence_parameters_for_megatron`](#nemo_rl-models-megatron-data-_get_pack_sequence_parameters_for_megatron) | Get pack sequence parameters for Megatron model processing with optional context parallelism. | -| [`_pack_sequences_for_megatron`](#nemo_rl-models-megatron-data-_pack_sequences_for_megatron) | Pack sequences for Megatron model processing with optional context parallelism. | -| [`_unpack_sequences_from_megatron`](#nemo_rl-models-megatron-data-_unpack_sequences_from_megatron) | Unpack sequences from Megatron output format. | -| [`get_and_validate_seqlen`](#nemo_rl-models-megatron-data-get_and_validate_seqlen) | - | -| [`get_microbatch_iterator`](#nemo_rl-models-megatron-data-get_microbatch_iterator) | Create a processed microbatch iterator from a batch of data. | -| [`make_processed_microbatch_iterator`](#nemo_rl-models-megatron-data-make_processed_microbatch_iterator) | Wrap a raw microbatch iterator to yield processed microbatches. | -| [`process_global_batch`](#nemo_rl-models-megatron-data-process_global_batch) | Process a global batch and compute normalization factors. | -| [`process_microbatch`](#nemo_rl-models-megatron-data-process_microbatch) | Process a microbatch for Megatron model forward pass. | - -### API - - - - - -```python -class nemo_rl.models.megatron.data.ProcessedInputs( - input_ids: torch.Tensor, - input_ids_cp_sharded: torch.Tensor, - attention_mask: typing.Optional[torch.Tensor], - position_ids: typing.Optional[torch.Tensor], - packed_seq_params: typing.Optional[megatron.core.packed_seq_params.PackedSeqParams], - cu_seqlens_padded: typing.Optional[torch.Tensor] -) -``` - - - - - - -Dataclass - -Processed microbatch inputs used for model forward pass. - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.models.megatron.data.ProcessedMicrobatch( - data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - input_ids: torch.Tensor, - input_ids_cp_sharded: torch.Tensor, - attention_mask: typing.Optional[torch.Tensor], - position_ids: typing.Optional[torch.Tensor], - packed_seq_params: typing.Optional[megatron.core.packed_seq_params.PackedSeqParams], - cu_seqlens_padded: typing.Optional[torch.Tensor] -) -``` - - - - - - -Dataclass - -Container for a processed microbatch ready for model forward pass. - -This dataclass holds both the original data dictionary and the processed -tensors needed for the Megatron model forward pass. - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.models.megatron.data._get_pack_sequence_parameters_for_megatron( - megatron_cfg: dict, - max_seq_len_in_batch: int -) -``` - - - - - - -Get pack sequence parameters for Megatron model processing with optional context parallelism. - -**Parameters:** - - -Megatron configuration - - - -Maximum sequence length in batch - - -**Returns:** - -Tuple of: - - - - - - - - -```python -nemo_rl.models.megatron.data._pack_sequences_for_megatron( - input_ids: torch.Tensor, - seq_lengths: torch.Tensor, - pad_individual_seqs_to_multiple_of: int = 1, - pad_packed_seq_to_multiple_of: int = 1, - pad_packed_seq_to: typing.Optional[int] = None, - cp_rank: int = 0, - cp_size: int = 1 -) -> tuple[torch.Tensor, megatron.core.packed_seq_params.PackedSeqParams, torch.Tensor, typing.Optional[torch.Tensor]] -``` - - - - - - -Pack sequences for Megatron model processing with optional context parallelism. - -**Parameters:** - - -Input token IDs [batch_size, seq_length] - - - -Actual sequence lengths for each sample [batch_size] - - - -Pad individual sequences to a multiple of this value - - - -Pad packed sequences to a multiple of this value - - - -Pad packed sequences to this value (before CP) -- The three parameters above can be calculated using _get_pack_sequence_parameters_for_megatron, we do not recommend users to set these parameters manually. - - - -Context parallelism size - - -**Returns:** `torch.Tensor` - -Tuple of: - - - - - - - - -```python -nemo_rl.models.megatron.data._unpack_sequences_from_megatron( - output_tensor: torch.Tensor, - seq_lengths: torch.Tensor, - cu_seqlens: torch.Tensor, - cu_seqlens_padded: typing.Optional[torch.Tensor], - original_batch_size: int, - original_seq_length: int -) -> torch.Tensor -``` - - - - - - -Unpack sequences from Megatron output format. - -**Parameters:** - - -Packed output tensor [1, T, vocab_size] - - - -Actual sequence lengths for each sample - - - -Cumulative sequence lengths - - - -Padded cumulative sequence lengths (if CP was used) - - - -Original batch size - - - -Original maximum sequence length - - -**Returns:** `torch.Tensor` - -Unpacked output tensor [batch_size, seq_length, vocab_size] - - - - - - - - -```python -nemo_rl.models.megatron.data.get_and_validate_seqlen( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] -) -``` - - - - - - - - - - - - - -```python -nemo_rl.models.megatron.data.get_microbatch_iterator( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - cfg: dict[str, typing.Any], - mbs: int, - straggler_timer: megatron.core.utils.StragglerDetector, - seq_length_key: typing.Optional[str] = None -) -> typing.Tuple[typing.Iterator[nemo_rl.models.megatron.data.ProcessedMicrobatch], int, int, int, int] -``` - - - - - - -Create a processed microbatch iterator from a batch of data. - -This function creates an iterator that yields ProcessedMicrobatch objects, -which contain both the original data dictionary and the processed tensors -ready for model forward pass. - -**Parameters:** - - -The batch data to create microbatches from - - - -Configuration dictionary - - - -Microbatch size - - - -Key for sequence lengths in data dict (auto-detected if None) - - -**Returns:** `Iterator[ProcessedMicrobatch]` - -Tuple containing the iterator and metadata - - - - - - - - -```python -nemo_rl.models.megatron.data.make_processed_microbatch_iterator( - raw_iterator: typing.Iterator[nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any]], - cfg: dict[str, typing.Any], - seq_length_key: typing.Optional[str], - pad_individual_seqs_to_multiple_of: int, - pad_packed_seq_to_multiple_of: int, - straggler_timer: megatron.core.utils.StragglerDetector, - pad_full_seq_to: typing.Optional[int] -) -> typing.Iterator[nemo_rl.models.megatron.data.ProcessedMicrobatch] -``` - - - - - - -Wrap a raw microbatch iterator to yield processed microbatches. - -This function takes a raw iterator that yields BatchedDataDict objects and -wraps it to yield ProcessedMicrobatch objects that contain both the original -data and the processed tensors ready for model forward pass. - -**Parameters:** - - -Iterator yielding raw BatchedDataDict microbatches - - - -Configuration dictionary containing sequence_packing settings - - - -Key for sequence length in data dict (required for packing) - - - -Padding multiple for individual sequences - - - -Padding multiple for packed sequences - - - -Target length for full sequence padding (optional) - - - - - - - - - -```python -nemo_rl.models.megatron.data.process_global_batch( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - loss_fn: nemo_rl.algorithms.interfaces.LossFunction, - dp_group: torch.distributed.ProcessGroup, - batch_idx: int, - batch_size: int -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] -``` - - - - - - -Process a global batch and compute normalization factors. - -**Parameters:** - - -Full dataset - - - -Index of batch to extract - - - -Size of batch to extract - - - -Loss function (used to check loss type) - - - -Data parallel mesh - - -**Returns:** `torch.Tensor` - -Dictionary containing: - - - - - - - - -```python -nemo_rl.models.megatron.data.process_microbatch( - data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - seq_length_key: typing.Optional[str] = None, - pad_individual_seqs_to_multiple_of: int = 1, - pad_packed_seq_to_multiple_of: int = 1, - pad_full_seq_to: typing.Optional[int] = None, - pack_sequences: bool = False, - straggler_timer: megatron.core.utils.StragglerDetector = None -) -> tuple[torch.Tensor, torch.Tensor, typing.Optional[torch.Tensor], typing.Optional[torch.Tensor], typing.Optional[megatron.core.packed_seq_params.PackedSeqParams], typing.Optional[torch.Tensor]] -``` - - - - - - -Process a microbatch for Megatron model forward pass. - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/setup.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/setup.mdx deleted file mode 100644 index 485915a..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/setup.mdx +++ /dev/null @@ -1,535 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/megatron/setup -title: nemo_rl.models.megatron.setup ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`MoEFloat16Module`](#nemo_rl-models-megatron-setup-MoEFloat16Module) | Float 16 Module with the ability to keep the expert bias in float32. | - -### Functions - -| Name | Description | -|------|-------------| -| [`_apply_moe_config`](#nemo_rl-models-megatron-setup-_apply_moe_config) | Apply Mixture of Experts configuration. | -| [`_apply_parallelism_config`](#nemo_rl-models-megatron-setup-_apply_parallelism_config) | Apply tensor/pipeline/context parallelism configuration. | -| [`_apply_performance_config`](#nemo_rl-models-megatron-setup-_apply_performance_config) | Apply performance optimization configuration. | -| [`_apply_precision_config`](#nemo_rl-models-megatron-setup-_apply_precision_config) | Apply precision and dtype configuration. | -| [`_create_checkpoint_config`](#nemo_rl-models-megatron-setup-_create_checkpoint_config) | Create checkpoint configurations. | -| [`_create_megatron_config`](#nemo_rl-models-megatron-setup-_create_megatron_config) | Create the final Megatron configuration container. | -| [`_validate_chunking_config`](#nemo_rl-models-megatron-setup-_validate_chunking_config) | Validate chunking configuration. | -| [`_validate_dtype_config`](#nemo_rl-models-megatron-setup-_validate_dtype_config) | - | -| [`_validate_optimizer_config`](#nemo_rl-models-megatron-setup-_validate_optimizer_config) | Validate optimizer configuration. | -| [`_validate_training_config`](#nemo_rl-models-megatron-setup-_validate_training_config) | Validate training configuration. | -| [`destroy_parallel_state`](#nemo_rl-models-megatron-setup-destroy_parallel_state) | Safely destroy parallel state and reset async call tracking. | -| [`finalize_megatron_setup`](#nemo_rl-models-megatron-setup-finalize_megatron_setup) | Finalize the setup with remaining configurations. | -| [`handle_model_import`](#nemo_rl-models-megatron-setup-handle_model_import) | Handle HF model import if checkpoint doesn't exist. | -| [`setup_distributed`](#nemo_rl-models-megatron-setup-setup_distributed) | Handle NCCL settings, dtype mapping, and basic config setup. | -| [`setup_model_and_optimizer`](#nemo_rl-models-megatron-setup-setup_model_and_optimizer) | - | -| [`setup_model_config`](#nemo_rl-models-megatron-setup-setup_model_config) | Handle all the model configuration logic. | -| [`setup_reference_model_state`](#nemo_rl-models-megatron-setup-setup_reference_model_state) | Setup the reference model for inference and return its state dict. | -| [`validate_and_set_config`](#nemo_rl-models-megatron-setup-validate_and_set_config) | - | -| [`validate_model_paths`](#nemo_rl-models-megatron-setup-validate_model_paths) | Validate and setup model paths. | - -### Data - -[`HAVE_FSDP2`](#nemo_rl-models-megatron-setup-HAVE_FSDP2) - -[`TokenizerType`](#nemo_rl-models-megatron-setup-TokenizerType) - -### API - - - - - -```python -class nemo_rl.models.megatron.setup.MoEFloat16Module( - config: megatron.core.transformer.transformer_config.TransformerConfig, - module: torch.nn.Module -) -``` - - - - - - -**Bases:** `Float16Module` - -Float 16 Module with the ability to keep the expert bias in float32. - -**Parameters:** - - -The transformer config used to initalize the model - - - - - - - -```python -nemo_rl.models.megatron.setup.MoEFloat16Module.re_enable_float32_expert_bias() -> None -``` - - - - - - -Ensure MoE router expert bias stays in float32 for numerical stability. - -Walks the wrapped module to find MoE routers and invokes the -`_maintain_float32_expert_bias()` helper which recreates or casts the -expert bias tensors to float32 as required by Megatron-LM. - - - - - - - - - -```python -nemo_rl.models.megatron.setup._apply_moe_config( - model_cfg: typing.Any, - config: nemo_rl.models.policy.PolicyConfig -) -> None -``` - - - - - - -Apply Mixture of Experts configuration. - - - - - - - - -```python -nemo_rl.models.megatron.setup._apply_parallelism_config( - model_cfg: typing.Any, - config: nemo_rl.models.policy.PolicyConfig -) -> None -``` - - - - - - -Apply tensor/pipeline/context parallelism configuration. - - - - - - - - -```python -nemo_rl.models.megatron.setup._apply_performance_config( - model_cfg: typing.Any, - config: nemo_rl.models.policy.PolicyConfig -) -> None -``` - - - - - - -Apply performance optimization configuration. - - - - - - - - -```python -nemo_rl.models.megatron.setup._apply_precision_config( - model_cfg: typing.Any, - config: nemo_rl.models.policy.PolicyConfig, - dtype: torch.dtype -) -> None -``` - - - - - - -Apply precision and dtype configuration. - - - - - - - - -```python -nemo_rl.models.megatron.setup._create_checkpoint_config( - pretrained_path: str, - weights_path: typing.Optional[str] -) -> megatron.bridge.training.config.CheckpointConfig -``` - - - - - - -Create checkpoint configurations. - - - - - - - - -```python -nemo_rl.models.megatron.setup._create_megatron_config( - model_cfg: typing.Any, - checkpoint_config: megatron.bridge.training.config.CheckpointConfig, - config: nemo_rl.models.policy.PolicyConfig, - hf_model_name: str, - dtype: torch.dtype -) -> megatron.bridge.training.config.ConfigContainer -``` - - - - - - -Create the final Megatron configuration container. - - - - - - - - -```python -nemo_rl.models.megatron.setup._validate_chunking_config( - config: nemo_rl.models.policy.PolicyConfig -) -> None -``` - - - - - - -Validate chunking configuration. - - - - - - - - -```python -nemo_rl.models.megatron.setup._validate_dtype_config( - dtype: torch.dtype, - model_cfg: typing.Any, - optimizer_cfg: typing.Any -) -> None -``` - - - - - - - - - - - - - -```python -nemo_rl.models.megatron.setup._validate_optimizer_config( - config: nemo_rl.models.policy.PolicyConfig -) -> None -``` - - - - - - -Validate optimizer configuration. - - - - - - - - -```python -nemo_rl.models.megatron.setup._validate_training_config( - config: nemo_rl.models.policy.PolicyConfig, - model_cfg: typing.Any -) -> None -``` - - - - - - -Validate training configuration. - - - - - - - - -```python -nemo_rl.models.megatron.setup.destroy_parallel_state() -``` - - - - - - -Safely destroy parallel state and reset async call tracking. - -This function is called during initialization to clean up temporary distributed -state from model import operations. Resetting async call tracking ensures that -when the main Megatron distributed context is created, all ranks start with -consistent call_idx values for async checkpointing. - - - - - - - - -```python -nemo_rl.models.megatron.setup.finalize_megatron_setup( - config: nemo_rl.models.policy.PolicyConfig, - megatron_cfg: megatron.bridge.training.config.ConfigContainer, - hf_model_name: str, - worker_sharding_annotations: nemo_rl.distributed.named_sharding.NamedSharding, - model, - optimizer -) -> tuple -``` - - - - - - -Finalize the setup with remaining configurations. - -**Returns:** `tuple` - -Tuple of (megatron_tokenizer, megatron_bridge, should_disable_forward_pre_hook, dp_size) - - - - - - - - -```python -nemo_rl.models.megatron.setup.handle_model_import( - config: nemo_rl.models.policy.PolicyConfig, - hf_model_name: str, - pretrained_path: str, - pt_checkpoint_exists: bool -) -> None -``` - - - - - - -Handle HF model import if checkpoint doesn't exist. - - - - - - - - -```python -nemo_rl.models.megatron.setup.setup_distributed() -> None -``` - - - - - - -Handle NCCL settings, dtype mapping, and basic config setup. - - - - - - - - -```python -nemo_rl.models.megatron.setup.setup_model_and_optimizer( - policy_cfg: nemo_rl.models.policy.PolicyConfig, - megatron_cfg: megatron.bridge.training.config.ConfigContainer, - load_optimizer: bool = True, - get_embedding_ranks = None, - get_position_embedding_ranks = None -) -``` - - - - - - - - - - - - - -```python -nemo_rl.models.megatron.setup.setup_model_config( - config: nemo_rl.models.policy.PolicyConfig, - rank, - dtype, - hf_model_name: str, - pretrained_path: str, - weights_path: typing.Optional[str] = None -) -> tuple[megatron.bridge.training.config.ConfigContainer, typing.Any] -``` - - - - - - -Handle all the model configuration logic. - - - - - - - - -```python -nemo_rl.models.megatron.setup.setup_reference_model_state( - config: nemo_rl.models.policy.PolicyConfig, - megatron_cfg: megatron.bridge.training.config.ConfigContainer, - pretrained_path: str -) -> dict -``` - - - - - - -Setup the reference model for inference and return its state dict. - - - - - - - - -```python -nemo_rl.models.megatron.setup.validate_and_set_config( - config, - rank, - hf_model_name, - pretrained_path, - weights_path, - tokenizer -) -``` - - - - - - - - - - - - - -```python -nemo_rl.models.megatron.setup.validate_model_paths( - config: nemo_rl.models.policy.PolicyConfig -) -> tuple[str, str, bool] -``` - - - - - - -Validate and setup model paths. - - - - - - - - -```python -nemo_rl.models.megatron.setup.HAVE_FSDP2 = True -``` - - - - - - - - - -```python -nemo_rl.models.megatron.setup.TokenizerType = TypeVar('TokenizerType', bound=PreTrainedTokenizerBase) -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy.mdx deleted file mode 100644 index 1b3ebde..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy.mdx +++ /dev/null @@ -1,948 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/policy -title: nemo_rl.models.policy ---- - -## Subpackages - -- **[`nemo_rl.models.policy.workers`](/nemo-rl/nemo_rl/models/policy/workers)** - -## Submodules - -- **[`nemo_rl.models.policy.interfaces`](/nemo-rl/nemo_rl/models/policy/interfaces)** -- **[`nemo_rl.models.policy.lm_policy`](/nemo-rl/nemo_rl/models/policy/lm_policy)** -- **[`nemo_rl.models.policy.utils`](/nemo-rl/nemo_rl/models/policy/utils)** - -## Package Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`AutomodelBackendConfig`](#nemo_rl-models-policy-AutomodelBackendConfig) | Configuration for custom MoE implementation backend in Automodel. | -| [`AutomodelKwargs`](#nemo_rl-models-policy-AutomodelKwargs) | - | -| [`DTensorConfig`](#nemo_rl-models-policy-DTensorConfig) | - | -| [`DTensorConfigDisabled`](#nemo_rl-models-policy-DTensorConfigDisabled) | - | -| [`DynamicBatchingConfig`](#nemo_rl-models-policy-DynamicBatchingConfig) | - | -| [`DynamicBatchingConfigDisabled`](#nemo_rl-models-policy-DynamicBatchingConfigDisabled) | - | -| [`LoRAConfig`](#nemo_rl-models-policy-LoRAConfig) | - | -| [`LoRAConfigDisabled`](#nemo_rl-models-policy-LoRAConfigDisabled) | - | -| [`MegatronConfig`](#nemo_rl-models-policy-MegatronConfig) | - | -| [`MegatronConfigDisabled`](#nemo_rl-models-policy-MegatronConfigDisabled) | - | -| [`MegatronDDPConfig`](#nemo_rl-models-policy-MegatronDDPConfig) | - | -| [`MegatronOptimizerConfig`](#nemo_rl-models-policy-MegatronOptimizerConfig) | - | -| [`MegatronSchedulerConfig`](#nemo_rl-models-policy-MegatronSchedulerConfig) | - | -| [`PolicyConfig`](#nemo_rl-models-policy-PolicyConfig) | - | -| [`PytorchOptimizerConfig`](#nemo_rl-models-policy-PytorchOptimizerConfig) | - | -| [`RewardModelConfig`](#nemo_rl-models-policy-RewardModelConfig) | - | -| [`SequencePackingConfig`](#nemo_rl-models-policy-SequencePackingConfig) | - | -| [`SequencePackingConfigDisabled`](#nemo_rl-models-policy-SequencePackingConfigDisabled) | - | -| [`SinglePytorchMilestonesConfig`](#nemo_rl-models-policy-SinglePytorchMilestonesConfig) | - | -| [`SinglePytorchSchedulerConfig`](#nemo_rl-models-policy-SinglePytorchSchedulerConfig) | - | -| [`TokenizerConfig`](#nemo_rl-models-policy-TokenizerConfig) | - | - -### Data - -[`SchedulerMilestones`](#nemo_rl-models-policy-SchedulerMilestones) - -### API - - - - - -```python -class nemo_rl.models.policy.AutomodelBackendConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - -Configuration for custom MoE implementation backend in Automodel. - -Used when setting the backend in automodel_kwargs in your config. -Alternatively, pass `force_hf: true` in automodel_kwargs to fall back -to the HuggingFace implementation. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.models.policy.AutomodelKwargs -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.models.policy.DTensorConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.models.policy.DTensorConfigDisabled -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - -```python -class nemo_rl.models.policy.DynamicBatchingConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.models.policy.DynamicBatchingConfigDisabled -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - -```python -class nemo_rl.models.policy.LoRAConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.models.policy.LoRAConfigDisabled -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - -```python -class nemo_rl.models.policy.MegatronConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.models.policy.MegatronConfigDisabled -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - -```python -class nemo_rl.models.policy.MegatronDDPConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.models.policy.MegatronOptimizerConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.models.policy.MegatronSchedulerConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.models.policy.PolicyConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.models.policy.PytorchOptimizerConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - -```python -class nemo_rl.models.policy.RewardModelConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - -```python -class nemo_rl.models.policy.SequencePackingConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.models.policy.SequencePackingConfigDisabled -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - -```python -class nemo_rl.models.policy.SinglePytorchMilestonesConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - -```python -class nemo_rl.models.policy.SinglePytorchSchedulerConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - -```python -class nemo_rl.models.policy.TokenizerConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.models.policy.SchedulerMilestones = dict[str, list[int]] -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/interfaces.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/interfaces.mdx deleted file mode 100644 index 8cbb649..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/interfaces.mdx +++ /dev/null @@ -1,574 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/policy/interfaces -title: nemo_rl.models.policy.interfaces ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`ColocatablePolicyInterface`](#nemo_rl-models-policy-interfaces-ColocatablePolicyInterface) | - | -| [`LogprobOutputSpec`](#nemo_rl-models-policy-interfaces-LogprobOutputSpec) | logprobs: Tensor of log probabilities. | -| [`PolicyInterface`](#nemo_rl-models-policy-interfaces-PolicyInterface) | Abstract base class defining the interface for RL policies. | -| [`ReferenceLogprobOutputSpec`](#nemo_rl-models-policy-interfaces-ReferenceLogprobOutputSpec) | logprobs: Tensor of log probabilities. | -| [`ScoreOutputSpec`](#nemo_rl-models-policy-interfaces-ScoreOutputSpec) | scores: Tensor of scores. | -| [`TopkLogitsOutputSpec`](#nemo_rl-models-policy-interfaces-TopkLogitsOutputSpec) | Per-position top-k logits and corresponding global token indices. | - -### API - - - - - -```python -class nemo_rl.models.policy.interfaces.ColocatablePolicyInterface() -``` - - - - - - -**Bases:** [PolicyInterface](#nemo_rl-models-policy-interfaces-PolicyInterface) - - - - - -```python -nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.broadcast_weights_for_collective( - kv_scales: typing.Optional[dict[str, float]] = None -) -> list[ray.ObjectRef] -``` - - - - - - -abstract - - - - - - - -```python -nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.init_collective( - ip: str, - port: int, - world_size: int, - train_world_size: int -) -> list[ray.ObjectRef] -``` - - - - - - -abstract - - - - - - - -```python -nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.offload_after_refit() -> None -``` - - - - - - -abstract - - - - - - - -```python -nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.offload_before_refit() -> None -``` - - - - - - -abstract - - - - - - - -```python -nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.prepare_for_lp_inference() -> None -``` - - - - - - -abstract - - - - - - - -```python -nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.prepare_refit_info() -> typing.Optional[dict[str, typing.Any]] -``` - - - - - - -abstract - - - - - - - -```python -nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.stream_weights_via_http( - sglang_url_to_gpu_uuids: dict[str, list[str]] -) -> list[ray.ObjectRef] -``` - - - - - - -Stream model weights to SGLang servers via HTTP API. - -**Parameters:** - - -Dict mapping SGLang server URL to list of GPU UUIDs it uses - - - - - - - - -```python -nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.stream_weights_via_ipc_zmq( - args: typing.Any = (), - kwargs: typing.Any = {} -) -> list[ray.ObjectRef] -``` - - - - - - -abstract - - - - - - - - - -```python -class nemo_rl.models.policy.interfaces.LogprobOutputSpec -``` - - - - - - -**Bases:** `typing.TypedDict` - -logprobs: Tensor of log probabilities. - - - - - - - - - - - -```python -class nemo_rl.models.policy.interfaces.PolicyInterface() -``` - - - - - - -Abstract - -Abstract base class defining the interface for RL policies. - - - - - - -```python -nemo_rl.models.policy.interfaces.PolicyInterface.calibrate_qkv_fp8_scales( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - micro_batch_size: typing.Optional[int] = None, - percentile: float = 99.9, - margin: float = 1.05, - include_q: bool = False -) -> dict[str, typing.Any] -``` - - - - - - -abstract - -Calibrate FP8 scales for Q/K/V activations used by KV cache. - -**Parameters:** - - -BatchedDataDict containing input_ids and input_lengths. - - - -Optional override for micro batch size during calibration. - - - -Percentile for per-tensor amax estimation. - - - -Safety margin multiplier applied to amax. - - - -Whether to also compute scale for Q in addition to K/V. - - -**Returns:** `dict[str, Any]` - -Dict with overall configuration and per-layer scales. - - - - - - - -```python -nemo_rl.models.policy.interfaces.PolicyInterface.finish_training( - args: typing.Any = (), - kwargs: typing.Any = {} -) -> None -``` - - - - - - -abstract - - - - - - - -```python -nemo_rl.models.policy.interfaces.PolicyInterface.get_logprobs( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - timer: typing.Optional[nemo_rl.utils.timer.Timer] = None -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec] -``` - - - - - - -abstract - -Get logprobs of actions from observations. - -**Parameters:** - - -BatchedDataDict containing rollouts (tokens) - - -**Returns:** `BatchedDataDict[LogprobOutputSpec]` - -BatchedDataDict containing: -- logprobs: Tensor of logprobs of actions - - - - - - - -```python -nemo_rl.models.policy.interfaces.PolicyInterface.get_reference_policy_logprobs( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - micro_batch_size: typing.Optional[int] = None, - timer: typing.Optional[nemo_rl.utils.timer.Timer] = None -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec] -``` - - - - - - -abstract - -Get logprobs of actions from observations. - -**Parameters:** - - -BatchedDataDict containing rollouts (tokens) - - -**Returns:** `BatchedDataDict[ReferenceLogprobOutputSpec]` - -BatchedDataDict containing: -- logprobs: Tensor of logprobs of actions - - - - - - - -```python -nemo_rl.models.policy.interfaces.PolicyInterface.get_topk_logits( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - k: int, - micro_batch_size: typing.Optional[int] = None, - timer: typing.Optional[nemo_rl.utils.timer.Timer] = None -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.TopkLogitsOutputSpec] -``` - - - - - - -abstract - -Get per-position top-k logits and global indices for a batch of inputs. - - - - - - - -```python -nemo_rl.models.policy.interfaces.PolicyInterface.prepare_for_training( - args: typing.Any = (), - kwargs: typing.Any = {} -) -> None -``` - - - - - - -abstract - - - - - - - -```python -nemo_rl.models.policy.interfaces.PolicyInterface.save_checkpoint( - args: typing.Any = (), - kwargs: typing.Any = {} -) -> None -``` - - - - - - -abstract - - - - - - - -```python -nemo_rl.models.policy.interfaces.PolicyInterface.shutdown() -> bool -``` - - - - - - -abstract - - - - - - - -```python -nemo_rl.models.policy.interfaces.PolicyInterface.train( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict, - loss_fn: nemo_rl.algorithms.interfaces.LossFunction, - eval_mode: bool = False, - gbs: typing.Optional[int] = None, - mbs: typing.Optional[int] = None, - timer: typing.Optional[nemo_rl.utils.timer.Timer] = None -) -> dict[str, typing.Any] -``` - - - - - - -abstract - -Train the policy on a global batch of data. - -**Parameters:** - - -BatchedDataDict containing rollouts (tokens) - - - -Loss function to use for training - - - -Whether to run in evaluation mode (no gradient updates) - - - -Global batch size override (if None, uses config default) - - - -Micro batch size override (if None, uses config default) - - - - - - - - - - -```python -class nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec -``` - - - - - - -**Bases:** `typing.TypedDict` - -logprobs: Tensor of log probabilities. - - - - - - - - - - - -```python -class nemo_rl.models.policy.interfaces.ScoreOutputSpec -``` - - - - - - -**Bases:** `typing.TypedDict` - -scores: Tensor of scores. - - - - - - - - - - - -```python -class nemo_rl.models.policy.interfaces.TopkLogitsOutputSpec -``` - - - - - - -**Bases:** `typing.TypedDict` - -Per-position top-k logits and corresponding global token indices. - - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/lm_policy.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/lm_policy.mdx deleted file mode 100644 index 7636f66..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/lm_policy.mdx +++ /dev/null @@ -1,609 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/policy/lm_policy -title: nemo_rl.models.policy.lm_policy ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`Policy`](#nemo_rl-models-policy-lm_policy-Policy) | - | - -### Data - -[`PathLike`](#nemo_rl-models-policy-lm_policy-PathLike) - -### API - - - - - -```python -class nemo_rl.models.policy.lm_policy.Policy( - cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster, - config: nemo_rl.models.policy.PolicyConfig, - tokenizer: transformers.PreTrainedTokenizerBase, - name_prefix: str = 'lm_policy', - workers_per_node: typing.Optional[typing.Union[int, list[int]]] = None, - init_optimizer: bool = True, - weights_path: typing.Optional[nemo_rl.models.policy.lm_policy.PathLike] = None, - optimizer_path: typing.Optional[nemo_rl.models.policy.lm_policy.PathLike] = None, - init_reference_model: bool = True, - processor: typing.Optional[transformers.AutoProcessor] = None -) -``` - - - - - - -**Bases:** [ColocatablePolicyInterface](/nemo-rl/nemo_rl/models/policy/interfaces#nemo_rl-models-policy-interfaces-ColocatablePolicyInterface), [GenerationInterface](/nemo-rl/nemo_rl/models/generation/interfaces#nemo_rl-models-generation-interfaces-GenerationInterface) - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.__del__() -> None -``` - - - - - - -Shuts down the worker groups when the object is deleted or is garbage collected. - -This is an extra safety net in case the user forgets to call worker_group.shutdown() and the pointer to -the object is lost due to leaving a function scope. It's always recommended that the -user calls worker_group.shutdown(). - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.broadcast_weights_for_collective( - kv_scales: typing.Optional[dict[str, float]] = None -) -> list[ray.ObjectRef] -``` - - - - - - -Broadcast the weights for collective communication. - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.calibrate_qkv_fp8_scales( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - micro_batch_size: typing.Optional[int] = None, - percentile: float = 99.9, - margin: float = 1.05, - include_q: bool = False -) -> dict[str, typing.Any] -``` - - - - - - -Trigger KV-cache FP8 scale calibration across Megatron workers and return results. - -Note: The backend `MegatronPolicyWorker.calibrate_qkv_fp8_scales` already implements -distributed reduction, returning results merged across ranks. Therefore, we shard the -input by DP and call in parallel, then take the result from the first worker. - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.finish_generation( - args: typing.Any = (), - kwargs: typing.Any = {} -) -> bool -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.finish_training( - args: typing.Any = (), - kwargs: typing.Any = {} -) -> None -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.generate( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - greedy: bool = False -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] -``` - - - - - - -Generate a batch of data using the policy. - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.get_free_memory_bytes() -> int -``` - - - - - - -Get the available free memory. - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.get_logprobs( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - timer: typing.Optional[nemo_rl.utils.timer.Timer] = None -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec] -``` - - - - - - -Get the logprobs of the model for a data dict. - -**Returns:** `BatchedDataDict[LogprobOutputSpec]` - -a BatchedDataDict with key "logprobs" and shape [batch_size, sequence_length]. - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.get_reference_policy_logprobs( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - micro_batch_size: typing.Optional[int] = None, - timer: typing.Optional[nemo_rl.utils.timer.Timer] = None -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec] -``` - - - - - - -Get the logprobs of the reference policy for a data dict. - -Returns: Identical to get_logprobs. - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.get_topk_logits( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - k: int, - micro_batch_size: typing.Optional[int] = None, - timer: typing.Optional[nemo_rl.utils.timer.Timer] = None -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.TopkLogitsOutputSpec] -``` - - - - - - -Dispatch get_topk_logits to workers (no CP/packed support initially). - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.init_collective( - ip: str, - port: int, - world_size: int, - train_world_size: int -) -> list[ray.ObjectRef] -``` - - - - - - -Initialize the collective communication. - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.invalidate_kv_cache( - args: typing.Any = (), - kwargs: typing.Any = {} -) -> bool -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.offload_after_refit() -> None -``` - - - - - - -Offload the optimizer and buffers to the CPU. - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.offload_before_refit() -> None -``` - - - - - - -Offload the optimizer and buffers to the CPU. - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.prepare_for_generation( - args: typing.Any = (), - kwargs: typing.Any = {} -) -> bool -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.prepare_for_lp_inference( - args: typing.Any = (), - kwargs: typing.Any = {} -) -> None -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.prepare_for_training( - args: typing.Any = (), - kwargs: typing.Any = {} -) -> None -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.prepare_refit_info() -> typing.Optional[dict[str, typing.Any]] -``` - - - - - - -Prepare the info for refit. - -**Returns:** `Optional[dict[str, Any]]` - -A dictionary containing the info for refit. - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.print_node_ip_and_gpu_id() -> list[tuple[str, int]] -``` - - - - - - -Print the node IP and GPU ID of the current worker. - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.save_checkpoint( - weights_path: str, - optimizer_path: typing.Optional[str] = None, - tokenizer_path: typing.Optional[str] = None, - checkpointing_cfg: typing.Optional[nemo_rl.utils.checkpoint.CheckpointingConfig] = None -) -> None -``` - - - - - - -Save a checkpoint of the model. - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.score( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec] -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ScoreOutputSpec] -``` - - - - - - -Score a batch of data using the policy. - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.shutdown() -> bool -``` - - - - - - -Shut down all HF workers and clean up resources. - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.start_gpu_profiling() -> None -``` - - - - - - -Start GPU profiling. - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.stop_gpu_profiling() -> None -``` - - - - - - -Stop GPU profiling. - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.stream_weights_via_http( - sglang_url_to_gpu_uuids: dict[str, list[str]] -) -> list[ray.ObjectRef] -``` - - - - - - -Send the weights to SGLang servers via HTTP API. - -**Parameters:** - - -Dict mapping SGLang server URL to list of GPU UUIDs it uses - - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.stream_weights_via_ipc_zmq( - buffer_size_bytes: int, - kv_scales: typing.Optional[dict[str, float]] = None -) -> list[ray.ObjectRef] -``` - - - - - - -Send the weights for IPC handles via ZMQ socket. - - - - - - - -```python -nemo_rl.models.policy.lm_policy.Policy.train( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - loss_fn: nemo_rl.algorithms.interfaces.LossFunction, - eval_mode: bool = False, - gbs: typing.Optional[int] = None, - mbs: typing.Optional[int] = None, - timer: typing.Optional[nemo_rl.utils.timer.Timer] = None -) -> dict[str, typing.Any] -``` - - - - - - -Train the policy on a batch of data with a given loss function. - - - - - - - - - -```python -nemo_rl.models.policy.lm_policy.PathLike = Union[str, 'os.PathLike[Any]'] -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/utils.mdx deleted file mode 100644 index 63b8675..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/utils.mdx +++ /dev/null @@ -1,624 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/policy/utils -title: nemo_rl.models.policy.utils ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`IPCProtocol`](#nemo_rl-models-policy-utils-IPCProtocol) | IPC protocol constants for ZMQ weight streaming. | - -### Functions - -| Name | Description | -|------|-------------| -| [`_gather_ipc_handlers`](#nemo_rl-models-policy-utils-_gather_ipc_handlers) | Gather IPC handlers from all ranks in the default FSDP group, then filter by server. | -| [`_send_tensor_to_sglang`](#nemo_rl-models-policy-utils-_send_tensor_to_sglang) | Send gathered IPC handlers to SGLang server via HTTP. | -| [`_setup_ipc_gather_group`](#nemo_rl-models-policy-utils-_setup_ipc_gather_group) | Setup gather configuration for IPC handlers. | -| [`apply_top_k_only`](#nemo_rl-models-policy-utils-apply_top_k_only) | Apply top-k mask to the logits. | -| [`apply_top_k_top_p`](#nemo_rl-models-policy-utils-apply_top_k_top_p) | Apply top-k and top-p masks to the logits. | -| [`calculate_aligned_size`](#nemo_rl-models-policy-utils-calculate_aligned_size) | Calculate aligned size for memory alignment. | -| [`configure_dynamo_cache`](#nemo_rl-models-policy-utils-configure_dynamo_cache) | Disable dynamo autotune_local_cache. | -| [`get_gpu_info`](#nemo_rl-models-policy-utils-get_gpu_info) | Return information about the GPU being used by this worker. | -| [`get_handle_from_tensor`](#nemo_rl-models-policy-utils-get_handle_from_tensor) | Get IPC handle from a tensor. | -| [`get_megatron_checkpoint_dir`](#nemo_rl-models-policy-utils-get_megatron_checkpoint_dir) | Gets the default megatron checkpoint directory for initial HF -> Mcore conversion. | -| [`get_runtime_env_for_policy_worker`](#nemo_rl-models-policy-utils-get_runtime_env_for_policy_worker) | Get runtime environment configuration for policy workers. | -| [`is_vllm_v1_engine_enabled`](#nemo_rl-models-policy-utils-is_vllm_v1_engine_enabled) | Check if vLLM V1 engine is enabled. | -| [`rebuild_cuda_tensor_from_ipc`](#nemo_rl-models-policy-utils-rebuild_cuda_tensor_from_ipc) | Rebuild a CUDA tensor from an IPC handle. | -| [`resolve_model_class`](#nemo_rl-models-policy-utils-resolve_model_class) | Resolve the appropriate model class for a given model name. | -| [`stream_weights_via_http_impl`](#nemo_rl-models-policy-utils-stream_weights_via_http_impl) | Stream weights to SGLang servers via HTTP API (update_weights_from_tensor). | -| [`stream_weights_via_ipc_zmq_impl`](#nemo_rl-models-policy-utils-stream_weights_via_ipc_zmq_impl) | Shared implementation for streaming weights via IPC ZMQ with improved memory management. | - -### Data - -[`AUTOMODEL_FACTORY`](#nemo_rl-models-policy-utils-AUTOMODEL_FACTORY) - -[`NEMO_AUTOMODEL_AVAILABLE`](#nemo_rl-models-policy-utils-NEMO_AUTOMODEL_AVAILABLE) - -### API - - - - - -```python -class nemo_rl.models.policy.utils.IPCProtocol -``` - - - - - - -**Bases:** `enum.Enum` - -IPC protocol constants for ZMQ weight streaming. - - - - - - - - - - - - - -```python -nemo_rl.models.policy.utils._gather_ipc_handlers( - serialized_handler: str, - gather_group: typing.Optional[torch.distributed.ProcessGroup], - gather_src: typing.Optional[int], - rank: int, - matching_ranks: typing.Optional[list[int]] = None -) -> typing.Optional[list[str]] -``` - - - - - - -Gather IPC handlers from all ranks in the default FSDP group, then filter by server. - -**Parameters:** - - -Serialized IPC handler from this rank - - - -Process group (None means use default FSDP group) - - - -Rank that will collect and filter handlers - - - -Current rank - - - -List of ranks that belong to the same SGLang server - - -**Returns:** `Optional[list[str]]` - -List of serialized handlers in rank order (only on gather_src rank), None otherwise - - - - - - - - -```python -nemo_rl.models.policy.utils._send_tensor_to_sglang( - url: str, - tensor_name: str, - gathered_handlers: list[str], - shape: torch.Size, - dtype: str, - flush_cache: bool = False -) -> None -``` - - - - - - -Send gathered IPC handlers to SGLang server via HTTP. - -Key: gathered_handlers are in rank order [rank0, rank1, ...] -SGLang will automatically match: handler = serialized_handlers[tp_rank] - -**Parameters:** - - -SGLang server URL - - - -Name of the tensor - - - -List of serialized IPC handlers in rank order - - - -Tensor shape - - - -Tensor dtype - - - -Whether to flush cache after this tensor (for last tensor) - - - - - - - - - -```python -nemo_rl.models.policy.utils._setup_ipc_gather_group( - rank: int, - current_device_uuid: str, - sglang_gpu_uuids: list[str], - sglang_url_to_gpu_uuids: dict[str, list[str]] -) -> tuple[typing.Optional[torch.distributed.ProcessGroup], typing.Optional[int], typing.Optional[list[int]]] -``` - - - - - - -Setup gather configuration for IPC handlers. - -**Returns:** `Optional[dist.ProcessGroup]` - -Tuple of (gather_group, gather_src_rank, matching_ranks) - - - - - - - - -```python -nemo_rl.models.policy.utils.apply_top_k_only( - logits: torch.Tensor, - top_k: int -) -> torch.Tensor -``` - - - - - - -Apply top-k mask to the logits. - -Simplified version of VLLM's implementation for scalar parameters. -This implementation doesn't involve sorting the entire vocab. - -Based on VLLM's implementation: -https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py -SPDX-License-Identifier: Apache-2.0 -Copyright contributors to the vLLM project - -**Parameters:** - - -Input logits tensor of shape [batch_size, seq_len, vocab_size] - - - -Top-k sampling parameter. - - -**Returns:** `torch.Tensor` - -Filtered logits with top-k applied - - - - - - - - -```python -nemo_rl.models.policy.utils.apply_top_k_top_p( - logits: torch.Tensor, - top_k: typing.Optional[int] = None, - top_p: typing.Optional[float] = None -) -> torch.Tensor -``` - - - - - - -Apply top-k and top-p masks to the logits. - -Simplified version of VLLM's implementation for scalar parameters. - -Based on VLLM's implementation: -https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py -SPDX-License-Identifier: Apache-2.0 -Copyright contributors to the vLLM project - -**Parameters:** - - -Input logits tensor of shape [batch_size, seq_len, vocab_size] - - - -Top-k sampling parameter. Set to -1 to consider all tokens. - - - -Top-p (nucleus) sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens. - - -**Returns:** `torch.Tensor` - -Filtered logits with sampling parameters applied - - - - - - - - -```python -nemo_rl.models.policy.utils.calculate_aligned_size( - size_bytes: int, - alignment: int = 512 -) -> int -``` - - - - - - -Calculate aligned size for memory alignment. - -**Parameters:** - - -Size in bytes to align - - - -Alignment boundary in bytes (default 512) - - -**Returns:** `int` - -Aligned size in bytes(int). - - - - - - - - -```python -nemo_rl.models.policy.utils.configure_dynamo_cache() -> None -``` - - - - - - -Disable dynamo autotune_local_cache. - -Dynamo may fail at cached_autotune when there's already a cache with different order of node_bundles. -Disable autotune_local_cache as a workaround. -See https://github.com/pytorch/pytorch/issues/153791 for more details. - - - - - - - - -```python -nemo_rl.models.policy.utils.get_gpu_info( - model: torch.nn.Module -) -> dict[str, typing.Any] -``` - - - - - - -Return information about the GPU being used by this worker. - - - - - - - - -```python -nemo_rl.models.policy.utils.get_handle_from_tensor( - tensor: torch.Tensor -) -> tuple[typing.Any] -``` - - - - - - -Get IPC handle from a tensor. - - - - - - - - -```python -nemo_rl.models.policy.utils.get_megatron_checkpoint_dir() -> str -``` - - - - - - -Gets the default megatron checkpoint directory for initial HF -> Mcore conversion. - -Megatron initial checkpoint should be saved to a path available on all nodes. The directory used will take this order of precendence: -1. $NRL_MEGATRON_CHECKPOINT_DIR (if set) -2. $HF_HOME/nemo_rl (if HF_HOME is set) -3. ~/.cache/huggingface/nemo_rl - -HF_HOME is preferred since many users will also have that path mounted and it means one less directory -to mount into your runtime environment. - - - - - - - - -```python -nemo_rl.models.policy.utils.get_runtime_env_for_policy_worker( - policy_worker_name: str -) -> dict[str, typing.Any] -``` - - - - - - -Get runtime environment configuration for policy workers. - -Note: expandable_segments configuration is handled directly in the worker init methods -to ensure proper GPU detection after CUDA initialization. - - - - - - - - -```python -nemo_rl.models.policy.utils.is_vllm_v1_engine_enabled() -> bool -``` - - - - - - -Check if vLLM V1 engine is enabled. - -**Returns:** `bool` - -True if V1 engine is enabled, False otherwise (defaults to True if not set) - - - - - - - - -```python -nemo_rl.models.policy.utils.rebuild_cuda_tensor_from_ipc( - cuda_ipc_handle: tuple, - device_id: int -) -> torch.Tensor -``` - - - - - - -Rebuild a CUDA tensor from an IPC handle. - - - - - - - - -```python -nemo_rl.models.policy.utils.resolve_model_class( - model_name: str -) -> typing.Any -``` - - - - - - -Resolve the appropriate model class for a given model name. - - - - - - - - -```python -nemo_rl.models.policy.utils.stream_weights_via_http_impl( - params_generator, - sglang_url_to_gpu_uuids: dict[str, list[str]], - rank: int, - worker_name: str, - current_device_uuid: str -) -> None -``` - - - - - - -Stream weights to SGLang servers via HTTP API (update_weights_from_tensor). - -Flow: Each rank creates IPC handler → gather handlers in rank order → send list → SGLang matches by tp_rank index - -Key points: -- Each rank creates handler on its own GPU -- Handlers are gathered in rank order: [rank0_handler, rank1_handler, ...] -- List index = rank = GPU ID -- SGLang automatically matches: handler = serialized_handlers[tp_rank] - -**Parameters:** - - -Generator yielding (name, tensor) pairs - - - -Dict mapping SGLang server URL to list of GPU UUIDs it uses - - - -Worker rank for logging - - - -Name of the worker for logging - - - -UUID of the current training worker's GPU - - - - - - - - - -```python -nemo_rl.models.policy.utils.stream_weights_via_ipc_zmq_impl( - params_generator, - buffer_size_bytes: int, - zmq_socket, - rank: int, - worker_name: str -) -> None -``` - - - - - - -Shared implementation for streaming weights via IPC ZMQ with improved memory management. - -Uses ping-pong double buffering to enable overlapping communication while reusing buffers -to reduce memory allocation overhead and improve stability. - -**Parameters:** - - -Generator yielding (name, tensor) pairs - - - -total size of buffer in bytes for batching parameters - - - -ZMQ socket for communication - - - -Worker rank for logging - - - -Name of the worker for logging - - - - - - - - - -```python -nemo_rl.models.policy.utils.AUTOMODEL_FACTORY: Dict[str, Any] = {'qwen2_5_vl': AutoModelForImageTextToText, 'qwen2_vl': AutoModelForImageTextToT... -``` - - - - - - - - - -```python -nemo_rl.models.policy.utils.NEMO_AUTOMODEL_AVAILABLE = True -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers.mdx deleted file mode 100644 index 3becc39..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers.mdx +++ /dev/null @@ -1,13 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/policy/workers -title: nemo_rl.models.policy.workers ---- - -## Submodules - -- **[`nemo_rl.models.policy.workers.base_policy_worker`](/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker)** -- **[`nemo_rl.models.policy.workers.dtensor_policy_worker`](/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker)** -- **[`nemo_rl.models.policy.workers.dtensor_policy_worker_v2`](/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2)** -- **[`nemo_rl.models.policy.workers.megatron_policy_worker`](/nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker)** -- **[`nemo_rl.models.policy.workers.patches`](/nemo-rl/nemo_rl/models/policy/workers/patches)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker.mdx deleted file mode 100644 index 0983a40..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker.mdx +++ /dev/null @@ -1,309 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/policy/workers/base_policy_worker -title: nemo_rl.models.policy.workers.base_policy_worker ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`AbstractPolicyWorker`](#nemo_rl-models-policy-workers-base_policy_worker-AbstractPolicyWorker) | Base class for policy workers with shared functionality. | - -### API - - - - - -```python -class nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker() -``` - - - - - - -Base class for policy workers with shared functionality. - - - - - - -```python -nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.finish_training( - args: typing.Any = (), - kwargs: typing.Any = {} -) -> None -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.get_free_memory_bytes() -> int -``` - - - - - - -Get the available free memory. - - - - - - - -```python -nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.get_gpu_info() -> dict[str, typing.Any] -``` - - - - - - -Return information about the GPU being used by this worker. - - - - - - - -```python -nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.get_reference_policy_logprobs( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - micro_batch_size: typing.Optional[int] = None -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec] -``` - - - - - - -Get the logprobs from the reference policy for a batch of data. - -If micro_batch_size is provided, it will be used instead of the configured -logprob_batch_size. - -**Returns:** `BatchedDataDict[ReferenceLogprobOutputSpec]` - -a BatchedDataDict with key "reference_logprobs" and shape [batch_size, sequence_length]. - - - - - - - -```python -nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.get_zmq_address() -> str -``` - - - - - - -Get the ZMQ address for the current device. - - - - - - - -```python -nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.init_collective( - ip: str, - port: int, - world_size: int, - train_world_size: int -) -> None -``` - - - - - - -Initialize the collective communication. - -**Parameters:** - - -IP address for the process group - - - -Port for the process group - - - -Total world size (train_world_size + inference_world_size) - - - -Number of training workers (used in inference cluster) - - - - - - - - -```python -nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.is_alive() -> bool -``` - - - - - - -Check if the worker is alive. - - - - - - - -```python -nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.maybe_init_zmq() -> None -``` - - - - - - -Initialize the ZMQ socket if it doesn't exist. - - - - - - - -```python -nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.report_device_id() -> str -``` - - - - - - -Report the UUID of the current CUDA device using NVML. - -**Returns:** `str` - -UUID of the device in the format "GPU-xxxxx" - - - - - - - -```python -nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.report_node_ip_and_gpu_id() -> tuple[str, int] -``` - - - - - - -Report the node IP and GPU ID of the current worker. - - - - - - - -```python -nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.reset_peak_memory_stats() -> None -``` - - - - - - -Reset peak memory statistics. - - - - - - - -```python -nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.shutdown() -> bool -``` - - - - - - -Shutdown the policy. - - - - - - - -```python -nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.start_gpu_profiling() -> None -``` - - - - - - -Start GPU profiling. - - - - - - - -```python -nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.stop_gpu_profiling() -> None -``` - - - - - - -Stop GPU profiling. - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker.mdx deleted file mode 100644 index 6fb84a0..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker.mdx +++ /dev/null @@ -1,693 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker -title: nemo_rl.models.policy.workers.dtensor_policy_worker ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`DTensorPolicyWorker`](#nemo_rl-models-policy-workers-dtensor_policy_worker-DTensorPolicyWorker) | - | - -### Functions - -| Name | Description | -|------|-------------| -| [`get_cpu_state_dict`](#nemo_rl-models-policy-workers-dtensor_policy_worker-get_cpu_state_dict) | Copy the state dict generator to CPU memory. | -| [`unshard_fsdp2_model`](#nemo_rl-models-policy-workers-dtensor_policy_worker-unshard_fsdp2_model) | Explicitly unshard and then reshard the FSDP2 modules. Useful for logprob inference. | - -### API - - - - - -```python -class nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker( - config: nemo_rl.models.policy.PolicyConfig, - tokenizer: transformers.AutoTokenizer, - processor: typing.Optional[transformers.AutoProcessor] = None, - weights_path: typing.Optional[str] = None, - optimizer_path: typing.Optional[str] = None, - init_optimizer: bool = True, - init_reference_model: bool = True, - kwargs: typing.Any = {} -) -``` - - - - - - -**Bases:** [AbstractPolicyWorker](/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker#nemo_rl-models-policy-workers-base_policy_worker-AbstractPolicyWorker), [ColocatablePolicyInterface](/nemo-rl/nemo_rl/models/policy/interfaces#nemo_rl-models-policy-interfaces-ColocatablePolicyInterface) - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.__repr__() -> str -``` - - - - - - -Customizes the actor's prefix in the Ray logs. - -This makes it easier to identify which worker is producing specific log messages. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker._add_noise_to_weights() -> None -``` - - - - - - -Add small Gaussian noise to the weights of the model. Note that this is used for testing purposes only. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker._apply_temperature_scaling( - logits: torch.Tensor -) -> torch.Tensor -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.broadcast_weights_for_collective( - kv_scales: typing.Optional[dict[str, float]] = None -) -> None -``` - - - - - - -Broadcast the weights for collective communication. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.calibrate_qkv_fp8_scales( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - micro_batch_size: typing.Optional[int] = None, - percentile: float = 99.9, - margin: float = 1.05, - include_q: bool = False -) -> dict[str, typing.Any] -``` - - - - - - -Placeholder for FP8 Q/K/V scale calibration, not implemented for DTensorPolicyWorker. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.create_context_parallel_ctx( - cp_mesh: torch.distributed.device_mesh.DeviceMesh, - cp_buffers: list[torch.Tensor], - cp_seq_dims: list[int], - cp_no_restore_buffers: typing.Set[torch.Tensor], - cp_rotate_method: typing.Optional[str] = None -) -``` - - - - - - -staticmethod - -Create a context parallel context. - -**Parameters:** - - -The device mesh for context parallel. - - - -The buffers for context parallel. - - - -The sequence dimensions for context parallel. - - - -The no restore buffers for context parallel. - - - -The rotation method for context parallel, such as "allgather" or "addtoall". - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.get_logprobs( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - micro_batch_size: typing.Optional[int] = None -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec] -``` - - - - - - -Get the logprobs of the model for a batch of data. - -Uses the configured logprob_batch_size to do microbatching. - -Input data is assumed to be right-padded. The method internally converts to -left-padded format for computation, and returns outputs in right-padded format. - -**Returns:** `BatchedDataDict[LogprobOutputSpec]` - -a BatchedDataDict with key "logprobs" and shape [batch_size, sequence_length]. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.get_topk_logits( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - k: int, - micro_batch_size: typing.Optional[int] = None -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] -``` - - - - - - -Return per-position top-k logits and corresponding global indices. - -Notes: -- Return shapes are [B, S, k]. -- Computes top-k over the full sequence (no trimming of the last position). -- If alignment with next-token targets is required, the caller should handle it. -- If logits are TP-sharded DTensor, performs distributed global top-k across TP. -- Supports context parallelism with proper CP gather. -- Otherwise, computes local top-k on full-vocab tensor. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.load_checkpoint( - weights_path: str, - optimizer_path: typing.Optional[str] = None -) -> None -``` - - - - - - -Load a checkpoint into the model. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.move_buffer_to_device( - model: torch.nn.Module, - device: str | torch.device -) -> torch.nn.Module -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.move_optimizer_to_device( - device: str | torch.device -) -> None -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.move_to_cpu( - model: torch.nn.Module -) -> torch.nn.Module -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.move_to_cuda( - model: torch.nn.Module -) -> torch.nn.Module -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.move_to_device( - model: torch.nn.Module, - device: str | torch.device -) -> torch.nn.Module -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.offload_after_refit() -> None -``` - - - - - - -Offload as much as possible on the CPU. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.offload_before_refit() -> None -``` - - - - - - -Offload the optimizer to the CPU. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.prepare_for_lp_inference() -> None -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.prepare_for_training( - args = (), - kwargs = {} -) -> None -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.prepare_refit_info() -> typing.Optional[dict[str, typing.Any]] -``` - - - - - - -Prepare state dict metadata for weight refitting and IPC streaming. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.return_model_config() -> dict[str, typing.Any] -``` - - - - - - -Return the model configuration as a dictionary. - -**Returns:** `dict[str, Any]` - -Model configuration dictionary - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.return_state_dict() -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.save_checkpoint( - weights_path: str, - optimizer_path: typing.Optional[str] = None, - tokenizer_path: typing.Optional[str] = None -) -> None -``` - - - - - - -Save a checkpoint of the model. - -the optimizer states are saved only if `optimizer` and `optimizer_path` are provided. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.score( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ScoreOutputSpec] -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.stream_weights_via_ipc_zmq( - buffer_size_bytes: int = 0, - kv_scales: typing.Optional[dict[str, float]] = None -) -> None -``` - - - - - - -Stream model weights to peer process via ZMQ IPC socket. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.train( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - loss_fn: nemo_rl.algorithms.interfaces.LossFunction, - eval_mode: bool = False, - gbs: typing.Optional[int] = None, - mbs: typing.Optional[int] = None -) -> dict[str, typing.Any] -``` - - - - - - -Train the policy on a batch of data with a given loss function. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.train_context( - cp_context: typing.Optional[typing.Generator[None, None, None]] = None -) -``` - - - - - - -staticmethod - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.use_reference_model() -> typing.Generator[None, None, None] -``` - - - - - - -Context manager that temporarily swaps the reference model and active model. - -On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references -On exit: Restores original references and re-flips cuda/cpu - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.get_cpu_state_dict( - state_generator: typing.Iterable[tuple[str, typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]]], - pin_memory: bool = False -) -> dict[str, torch.Tensor] -``` - - - - - - -Copy the state dict generator to CPU memory. - -**Parameters:** - - - -An iterable that yields (key, tensor) pairs from a model state. - - - - -Whether to allocate the CPU tensors in pinned memory for faster GPU transfer. -Defaults to False. - - -**Returns:** `dict[str, torch.Tensor]` - -dict[str, torch.Tensor]: A dictionary mapping parameter names to CPU tensors. - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker.unshard_fsdp2_model( - model: torch.nn.Module -) -> typing.Generator[None, None, None] -``` - - - - - - -Explicitly unshard and then reshard the FSDP2 modules. Useful for logprob inference. - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.mdx deleted file mode 100644 index b6ff0e4..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.mdx +++ /dev/null @@ -1,714 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2 -title: nemo_rl.models.policy.workers.dtensor_policy_worker_v2 ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`DTensorPolicyWorkerV2`](#nemo_rl-models-policy-workers-dtensor_policy_worker_v2-DTensorPolicyWorkerV2) | - | - -### Functions - -| Name | Description | -|------|-------------| -| [`_maybe_adapt_tensor_to_hf`](#nemo_rl-models-policy-workers-dtensor_policy_worker_v2-_maybe_adapt_tensor_to_hf) | - | -| [`_maybe_merge_lora_weight`](#nemo_rl-models-policy-workers-dtensor_policy_worker_v2-_maybe_merge_lora_weight) | - | -| [`dtensor_params_generator`](#nemo_rl-models-policy-workers-dtensor_policy_worker_v2-dtensor_params_generator) | Generator that yields (name, tensor) pairs, converting DTensors to local tensors and adapting to HF format. | -| [`get_train_context`](#nemo_rl-models-policy-workers-dtensor_policy_worker_v2-get_train_context) | Create combined context manager for training with context parallel and autocast. | - -### API - - - - - -```python -class nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2( - config: nemo_rl.models.policy.PolicyConfig, - tokenizer: transformers.AutoTokenizer, - processor: typing.Optional[transformers.AutoProcessor] = None, - weights_path: typing.Optional[str] = None, - optimizer_path: typing.Optional[str] = None, - init_optimizer: bool = True, - init_reference_model: bool = True, - kwargs: typing.Any = {} -) -``` - - - - - - -**Bases:** [AbstractPolicyWorker](/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker#nemo_rl-models-policy-workers-base_policy_worker-AbstractPolicyWorker), [ColocatablePolicyInterface](/nemo-rl/nemo_rl/models/policy/interfaces#nemo_rl-models-policy-interfaces-ColocatablePolicyInterface) - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.__repr__() -> str -``` - - - - - - -Customizes the actor's prefix in the Ray logs. - -This makes it easier to identify which worker is producing specific log messages. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2._add_noise_to_weights() -> None -``` - - - - - - -Add small Gaussian noise to the weights of the model. Note that this is used for testing purposes only. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2._init_checkpoint_manager( - config_updates: typing.Optional[dict[str, typing.Any]] = None, - checkpoint_root: typing.Optional[str] = None -) -> None -``` - - - - - - -Initialize the AutomodelCheckpointManager for this worker. - -This creates the checkpoint manager bound to this worker's device meshes -and initializes its underlying checkpointer. - -**Parameters:** - - -Dict of CheckpointingConfig fields to set during initialization. - - - -Optional root directory for checkpoints. - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.broadcast_weights_for_collective( - kv_scales: typing.Optional[dict[str, float]] = None -) -> None -``` - - - - - - -Broadcast the weights for collective communication. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.calibrate_qkv_fp8_scales( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - micro_batch_size: typing.Optional[int] = None, - percentile: float = 99.9, - margin: float = 1.05, - include_q: bool = False -) -> dict[str, typing.Any] -``` - - - - - - -Placeholder for FP8 Q/K/V scale calibration, not implemented for DTensorPolicyWorkerV2. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.get_logprobs( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - micro_batch_size: typing.Optional[int] = None -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec] -``` - - - - - - -Get the logprobs of the model for a batch of data. - -Uses the configured logprob_batch_size to do microbatching. - -Input data is assumed to be right-padded. The method internally converts to -left-padded format for computation, and returns outputs in right-padded format. - -**Returns:** `BatchedDataDict[LogprobOutputSpec]` - -a BatchedDataDict with key "logprobs" and shape [batch_size, sequence_length]. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.get_topk_logits( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - k: int, - micro_batch_size: typing.Optional[int] = None -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] -``` - - - - - - -Return per-position top-k logits and corresponding global indices. - -Notes: -- Return shapes are [B, S, k]. -- Computes top-k over the full sequence (no trimming of the last position). -- If alignment with next-token targets is required, the caller should handle it. -- If logits are TP-sharded DTensor, performs distributed global top-k across TP. -- Supports context parallelism with proper CP gather. -- Otherwise, computes local top-k on full-vocab tensor. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.load_checkpoint( - weights_path: str, - optimizer_path: typing.Optional[str] = None -) -> None -``` - - - - - - -Load a checkpoint into the model using Automodel Checkpointer. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.move_buffer_to_device( - model: torch.nn.Module, - device: str | torch.device -) -> torch.nn.Module -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.move_optimizer_to_device( - device: str | torch.device -) -> None -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.move_to_cpu( - model: torch.nn.Module -) -> torch.nn.Module -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.move_to_cuda( - model: torch.nn.Module -) -> torch.nn.Module -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.move_to_device( - model: torch.nn.Module, - device: str | torch.device -) -> torch.nn.Module -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.offload_after_refit() -> None -``` - - - - - - -Offload as much as possible on the CPU. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.offload_before_refit() -> None -``` - - - - - - -Offload the optimizer to the CPU. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.prepare_for_lp_inference() -> None -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.prepare_for_training( - args = (), - kwargs = {} -) -> None -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.prepare_refit_info() -> typing.Optional[dict[str, typing.Any]] -``` - - - - - - -Prepare state dict metadata for weight refitting and IPC streaming. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.return_model_config() -> dict[str, typing.Any] -``` - - - - - - -Return the model configuration as a dictionary. - -**Returns:** `dict[str, Any]` - -Model configuration dictionary - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.return_state_dict() -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.save_checkpoint( - weights_path: str, - optimizer_path: typing.Optional[str] = None, - tokenizer_path: typing.Optional[str] = None, - checkpointing_cfg: typing.Optional[nemo_rl.utils.checkpoint.CheckpointingConfig] = None -) -> None -``` - - - - - - -Save a checkpoint of the model. - -the optimizer states are saved only if `optimizer` and `optimizer_path` are provided. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.score( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ScoreOutputSpec] -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.stream_weights_via_http( - sglang_url_to_gpu_uuids: dict[str, list[str]] -) -> None -``` - - - - - - -Stream model weights to SGLang servers via HTTP API. - -**Parameters:** - - -Dict mapping SGLang server URL to list of GPU UUIDs it uses - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.stream_weights_via_ipc_zmq( - buffer_size_bytes: int = 0, - kv_scales: typing.Optional[dict[str, float]] = None -) -> None -``` - - - - - - -Stream model weights to peer process via ZMQ IPC socket. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.train( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - loss_fn: nemo_rl.algorithms.interfaces.LossFunction, - eval_mode: bool = False, - gbs: typing.Optional[int] = None, - mbs: typing.Optional[int] = None -) -> dict[str, typing.Any] -``` - - - - - - -Train the policy on a batch of data with a given loss function. - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.use_reference_model() -> typing.Generator[None, None, None] -``` - - - - - - -Context manager that temporarily swaps the reference model and active model. - -On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references -On exit: Restores original references and re-flips cuda/cpu - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2._maybe_adapt_tensor_to_hf( - model_part: torch.nn.Module, - fqn: str, - tensor: torch.Tensor, - quantization: bool = False -) -> list[tuple[str, torch.Tensor]] -``` - - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2._maybe_merge_lora_weight( - module_map: dict[str, torch.nn.Module], - fqn: str, - tensor: torch.Tensor -) -> torch.Tensor -``` - - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.dtensor_params_generator( - model: torch.nn.Module, - target_dtype: torch.dtype -) -> typing.Generator[tuple[str, torch.Tensor], None, None] -``` - - - - - - -Generator that yields (name, tensor) pairs, converting DTensors to local tensors and adapting to HF format. - -**Parameters:** - - -The model whose parameters to generate. - - - -The dtype to convert tensors to. - - - -Optional LoRA config for filtering which layers to merge. - - - - - - - - - -```python -nemo_rl.models.policy.workers.dtensor_policy_worker_v2.get_train_context( - cp_size: int, - cp_mesh: typing.Any, - cp_buffers: list, - sequence_dim: int, - dtype: torch.dtype, - autocast_enabled: bool = True -) -> typing.Generator[None, None, None] -``` - - - - - - -Create combined context manager for training with context parallel and autocast. - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker.mdx deleted file mode 100644 index c8803a8..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker.mdx +++ /dev/null @@ -1,682 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker -title: nemo_rl.models.policy.workers.megatron_policy_worker ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`MegatronPolicyWorker`](#nemo_rl-models-policy-workers-megatron_policy_worker-MegatronPolicyWorker) | - | - -### Functions - -| Name | Description | -|------|-------------| -| [`broadcast_object_across_pp_ranks`](#nemo_rl-models-policy-workers-megatron_policy_worker-broadcast_object_across_pp_ranks) | Broadcast an object across pipeline parallel ranks. | - -### Data - -[`TokenizerType`](#nemo_rl-models-policy-workers-megatron_policy_worker-TokenizerType) - -### API - - - - - -```python -class nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker( - config: nemo_rl.models.policy.PolicyConfig, - tokenizer: nemo_rl.models.policy.workers.megatron_policy_worker.TokenizerType, - weights_path: typing.Optional[str] = None, - optimizer_path: typing.Optional[str] = None, - init_optimizer: bool = True, - init_reference_model: bool = True, - worker_sharding_annotations: nemo_rl.distributed.named_sharding.NamedSharding, - kwargs: typing.Any = {} -) -``` - - - - - - -**Bases:** [AbstractPolicyWorker](/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker#nemo_rl-models-policy-workers-base_policy_worker-AbstractPolicyWorker), [ColocatablePolicyInterface](/nemo-rl/nemo_rl/models/policy/interfaces#nemo_rl-models-policy-interfaces-ColocatablePolicyInterface) - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.__repr__() -``` - - - - - - -Customizes the actor's prefix in the Ray logs. - -This makes it easier to identify which worker is producing specific log messages. - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker._calculate_refit_param_info() -> list[tuple[str, int]] -``` - - - - - - -Calculate parameter information for refit. - -Each task contains: -- param_name: Local parameter name without module prefixes -- mapping: MegatronParamMapping instance for weight transformation -- pp_rank: Pipeline-parallel rank owning the parameter -- vp_stage: Virtual-pipeline stage index -- megatron_module: Reference to Megatron model/submodule -- param_weight: Target parameter tensor for converted weight - -**Returns:** `list[tuple[str, int]]` - -List of (parameter_name, size_in_bytes) tuples. - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker._iter_params_with_optional_kv_scales( - kv_scales: typing.Optional[dict[str, float]] = None -) -> typing.Iterator[tuple[str, torch.Tensor]] -``` - - - - - - -Yield exported HF parameters and optionally append FP8 KV/Q scale tensors. - -This helper is used by both IPC-based streaming and collective broadcast -so that the logic for adding KV scales stays consistent in one place. - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.broadcast_weights_for_collective( - kv_scales: typing.Optional[dict[str, float]] = None -) -> None -``` - - - - - - -Broadcast the weights for collective communication. - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.calibrate_qkv_fp8_scales( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - micro_batch_size: typing.Optional[int] = None, - percentile: float = 99.9, - margin: float = 1.05, - include_q: bool = False -) -> dict[str, typing.Any] -``` - - - - - - -One-shot calibration of Q/K/V activation scales (for FP8 KV cache). - -- Captures each layer's `query_key_value` output through forward hooks, splits Q/K/V, and computes percentile amax. -- In parallel (DP/TP/PP) environments, first computes local percentiles, then takes max across all ranks for conservativeness. -- By default only returns and saves K/V scales, optionally returns Q. - -**Parameters:** - - -Representative sample batch for calibration, following get_logprobs input conventions. - - - -Micro batch size during calibration; if None, reuses logprob_batch_size. - - - -Percentile for amax (e.g. 99.9). - - - -Margin factor, e.g. 1.05. - - - -If provided, rank0 will save results as JSON. - - - -Whether to also return Q scale (usually only K/V needed). - - -**Returns:** `dict[str, Any]` - -{ "format": "fp8", "percentile": float, "margin": float, -"layers": { layer_name: {"k_scale": float, "v_scale": float[, "q_scale": float] } } } - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.check_tensor_parallel_attributes() -> dict[str, typing.Any] -``` - - - - - - -Check tensor parallel attributes on model parameters. - -**Returns:** `dict[str, Any]` - -Dictionary containing information about tensor parallel parameters: - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.disable_forward_pre_hook( - param_sync = True -) -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.enable_forward_pre_hook() -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.generate( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - greedy: bool = False -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] -``` - - - - - - -Generate a batch of data using huggingface framework generation. - -Returns: - BatchedDataDict conforming to GenerationOutputSpec: - - output_ids: input + generated token IDs - - logprobs: Log probabilities for each token - - generation_lengths: Lengths of each response - -**Parameters:** - - -BatchedDataDict containing input_ids and input_lengths tensors - - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.get_logprobs( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], - micro_batch_size: typing.Optional[int] = None -) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec] -``` - - - - - - -Get the logprobs of the model for a batch of data. - -Uses the configured logprob_batch_size to do microbatching. -Input data is assumed to be right-padded. The method internally converts to -left-padded format for computation, and returns outputs in right-padded format. -If micro_batch_size is provided, it will be used instead of the configured -logprob_batch_size. - -**Returns:** `BatchedDataDict[LogprobOutputSpec]` - -a BatchedDataDict with key "logprobs" and shape [batch_size, sequence_length]. - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.get_topk_logits( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], - k: int, - micro_batch_size: typing.Optional[int] = None -) -``` - - - - - - -Get the top-k logits and indices for a batch of data. - -The major difference from get_logprobs is that we compute top-k logits and indices for each position in the sequence. - -**Returns:** - -BatchedDataDict containing: -- topk_logits: Tensor of top-k logits for each position in the sequence -- topk_indices: Tensor of top-k indices for each position in the sequence - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.load_checkpoint( - weights_path: str, - optimizer_path: typing.Optional[str] = None -) -``` - - - - - - -Load a training checkpoint. - -**Parameters:** - - -The exact directory path from which to load the checkpoint. - - - -If not None, attempts to load optimizer and scheduler states - if self.optimizer and self.scheduler are initialized. - - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.move_model( - model: torch.nn.Module, - device: str, - move_params: bool = True, - move_grads: bool = True -) -> torch.nn.Module -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.move_optimizer( - device: str -) -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.offload_after_refit() -``` - - - - - - -Offload as much as possible on the CPU. - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.offload_before_refit() -``` - - - - - - -Offload the optimizer and buffers to the CPU. - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.prepare_for_lp_inference() -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.prepare_for_training( - args = (), - kwargs = {} -) -``` - - - - - - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.prepare_refit_info() -> None -``` - - - - - - -Prepare state dict metadata for weight refitting and IPC streaming. - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.save_checkpoint( - weights_path: str, - optimizer_path: typing.Optional[str] = None, - kwargs = {} -) -``` - - - - - - -Save a training checkpoint. - -**Parameters:** - - -The specific directory path where the checkpoint will be saved. - - - -If not None, optimizer and scheduler states are saved if they exist. - - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.stream_weights_via_ipc_zmq( - buffer_size_bytes: int = 0, - kv_scales: typing.Optional[dict[str, float]] = None -) -> None -``` - - - - - - -Stream model weights to peer process via ZMQ IPC socket. - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.train( - data: nemo_rl.distributed.batched_data_dict.BatchedDataDict, - loss_fn: nemo_rl.algorithms.interfaces.LossFunction, - eval_mode: bool = False, - gbs: typing.Optional[int] = None, - mbs: typing.Optional[int] = None -) -> dict[str, typing.Any] -``` - - - - - - -Train the policy on a batch of data with a given loss function. - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.use_reference_model() -``` - - - - - - -Context manager that temporarily swaps the reference model and active model. - -On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references -On exit: Restores original references and re-flips cuda/cpu - - - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.broadcast_object_across_pp_ranks( - obj -) -``` - - - - - - -Broadcast an object across pipeline parallel ranks. - -This utility function handles broadcasting an object from the rank that owns it -to all other pipeline parallel ranks. If only one rank has the object (non-None), -it will be broadcast to all other ranks. - -**Parameters:** - - -The object to broadcast. Can be None on ranks that don't own it. - - -**Returns:** - -The object on all ranks (either the original or the broadcast copy). - -**Raises:** - -- `ValueError`: If the object doesn't exist on any pipeline parallel rank. - - - - - - - - -```python -nemo_rl.models.policy.workers.megatron_policy_worker.TokenizerType = TypeVar('TokenizerType', bound=PreTrainedTokenizerBase) -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/patches.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/patches.mdx deleted file mode 100644 index 4250027..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/patches.mdx +++ /dev/null @@ -1,85 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/models/policy/workers/patches -title: nemo_rl.models.policy.workers.patches ---- - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`_get_transformer_engine_file`](#nemo_rl-models-policy-workers-patches-_get_transformer_engine_file) | Return absolute path to a Transformer Engine file or raise if it cannot be found. | -| [`apply_torch_aten_alias_tensor_patch`](#nemo_rl-models-policy-workers-patches-apply_torch_aten_alias_tensor_patch) | Register a sharding rule for `torch.ops.aten.alias.default`. | -| [`apply_transformer_engine_patch`](#nemo_rl-models-policy-workers-patches-apply_transformer_engine_patch) | Apply patch from https://github.com/NVIDIA/TransformerEngine/pull/2286/files. | - -### API - - - - - -```python -nemo_rl.models.policy.workers.patches._get_transformer_engine_file( - relative_path: str -) -> str -``` - - - - - - -Return absolute path to a Transformer Engine file or raise if it cannot be found. - -The relative_path should be a POSIX-style path under the transformer_engine -package root, e.g. "pytorch/triton/permutation.py". - - - - - - - - -```python -nemo_rl.models.policy.workers.patches.apply_torch_aten_alias_tensor_patch() -``` - - - - - - -Register a sharding rule for `torch.ops.aten.alias.default`. - -Work around 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered' -in PyTorch 2.9. See https://github.com/pytorch/pytorch/pull/166867 for the upstream fix. -We can remove this patch when we upgrade torch to include this fix. - - - - - - - - -```python -nemo_rl.models.policy.workers.patches.apply_transformer_engine_patch() -``` - - - - - - -Apply patch from https://github.com/NVIDIA/TransformerEngine/pull/2286/files. - -This locates the target file via importlib metadata instead of importing -`transformer_engine`, to avoid side effects during initialization. If the -permutation module has already been imported, it will be reloaded so that -the patched source takes effect. - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/package_info.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/package_info.mdx deleted file mode 100644 index 2dc77ed..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/package_info.mdx +++ /dev/null @@ -1,235 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/package_info -title: nemo_rl.package_info ---- - -## Module Contents - -### Data - -[`MAJOR`](#nemo_rl-package_info-MAJOR) - -[`MINOR`](#nemo_rl-package_info-MINOR) - -[`PATCH`](#nemo_rl-package_info-PATCH) - -[`PRE_RELEASE`](#nemo_rl-package_info-PRE_RELEASE) - -[`VERSION`](#nemo_rl-package_info-VERSION) - -[`__contact_emails__`](#nemo_rl-package_info-__contact_emails__) - -[`__contact_names__`](#nemo_rl-package_info-__contact_names__) - -[`__description__`](#nemo_rl-package_info-__description__) - -[`__download_url__`](#nemo_rl-package_info-__download_url__) - -[`__homepage__`](#nemo_rl-package_info-__homepage__) - -[`__keywords__`](#nemo_rl-package_info-__keywords__) - -[`__license__`](#nemo_rl-package_info-__license__) - -[`__package_name__`](#nemo_rl-package_info-__package_name__) - -[`__repository_url__`](#nemo_rl-package_info-__repository_url__) - -[`__shortversion__`](#nemo_rl-package_info-__shortversion__) - -[`__version__`](#nemo_rl-package_info-__version__) - -### API - - - - - -```python -nemo_rl.package_info.MAJOR = 0 -``` - - - - - - - - - -```python -nemo_rl.package_info.MINOR = 5 -``` - - - - - - - - - -```python -nemo_rl.package_info.PATCH = 0 -``` - - - - - - - - - -```python -nemo_rl.package_info.PRE_RELEASE = 'rc0' -``` - - - - - - - - - -```python -nemo_rl.package_info.VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE) -``` - - - - - - - - - -```python -nemo_rl.package_info.__contact_emails__ = 'nemo-tookit@nvidia.com' -``` - - - - - - - - - -```python -nemo_rl.package_info.__contact_names__ = 'NVIDIA' -``` - - - - - - - - - -```python -nemo_rl.package_info.__description__ = 'NeMo-RL - a toolkit for model alignment' -``` - - - - - - - - - -```python -nemo_rl.package_info.__download_url__ = 'https://github.com/NVIDIA-NeMo/RL/releases' -``` - - - - - - - - - -```python -nemo_rl.package_info.__homepage__ = 'https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/' -``` - - - - - - - - - -```python -nemo_rl.package_info.__keywords__ = 'deep learning, machine learning, gpu, NLP, NeMo, nvidia, pytorch, torch, langua... -``` - - - - - - - - - -```python -nemo_rl.package_info.__license__ = 'Apache2' -``` - - - - - - - - - -```python -nemo_rl.package_info.__package_name__ = 'nemo_rl' -``` - - - - - - - - - -```python -nemo_rl.package_info.__repository_url__ = 'https://github.com/NVIDIA-NeMo/RL' -``` - - - - - - - - - -```python -nemo_rl.package_info.__shortversion__ = '.'.join(map(str, VERSION[:3])) -``` - - - - - - - - - -```python -nemo_rl.package_info.__version__ = '.'.join(map(str, VERSION[:3])) + ''.join(VERSION[3:]) -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils.mdx deleted file mode 100644 index b7dfc66..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils.mdx +++ /dev/null @@ -1,22 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/utils -title: nemo_rl.utils ---- - -## Submodules - -- **[`nemo_rl.utils.automodel_checkpoint`](/nemo-rl/nemo_rl/utils/automodel_checkpoint)** -- **[`nemo_rl.utils.checkpoint`](/nemo-rl/nemo_rl/utils/checkpoint)** -- **[`nemo_rl.utils.config`](/nemo-rl/nemo_rl/utils/config)** -- **[`nemo_rl.utils.flops_formulas`](/nemo-rl/nemo_rl/utils/flops_formulas)** -- **[`nemo_rl.utils.flops_tracker`](/nemo-rl/nemo_rl/utils/flops_tracker)** -- **[`nemo_rl.utils.logger`](/nemo-rl/nemo_rl/utils/logger)** -- **[`nemo_rl.utils.memory_tracker`](/nemo-rl/nemo_rl/utils/memory_tracker)** -- **[`nemo_rl.utils.native_checkpoint`](/nemo-rl/nemo_rl/utils/native_checkpoint)** -- **[`nemo_rl.utils.nsys`](/nemo-rl/nemo_rl/utils/nsys)** -- **[`nemo_rl.utils.nvml`](/nemo-rl/nemo_rl/utils/nvml)** -- **[`nemo_rl.utils.packed_tensor`](/nemo-rl/nemo_rl/utils/packed_tensor)** -- **[`nemo_rl.utils.prefetch_venvs`](/nemo-rl/nemo_rl/utils/prefetch_venvs)** -- **[`nemo_rl.utils.timer`](/nemo-rl/nemo_rl/utils/timer)** -- **[`nemo_rl.utils.venvs`](/nemo-rl/nemo_rl/utils/venvs)** diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/automodel_checkpoint.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/automodel_checkpoint.mdx deleted file mode 100644 index 2afdec6..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/automodel_checkpoint.mdx +++ /dev/null @@ -1,436 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/utils/automodel_checkpoint -title: nemo_rl.utils.automodel_checkpoint ---- - -Automodel checkpoint utilities for DTensor policy workers. - -This module provides a wrapper class around the nemo_automodel Checkpointer -for saving and loading model checkpoints in DTensor-based policy workers. - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`AutomodelCheckpointManager`](#nemo_rl-utils-automodel_checkpoint-AutomodelCheckpointManager) | Manages checkpointing for DTensor-based models using nemo_automodel's Checkpointer. | - -### Functions - -| Name | Description | -|------|-------------| -| [`_infer_checkpoint_root`](#nemo_rl-utils-automodel_checkpoint-_infer_checkpoint_root) | Infer checkpoint root directory from weights path. | -| [`detect_checkpoint_format`](#nemo_rl-utils-automodel_checkpoint-detect_checkpoint_format) | Detect model save format and PEFT status from checkpoint directory. | - -### API - - - - - -```python -class nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager( - dp_mesh: torch.distributed.device_mesh.DeviceMesh, - tp_mesh: torch.distributed.device_mesh.DeviceMesh, - model_state_dict_keys: typing.Optional[list[str]] = None, - moe_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None -) -``` - - - - - - -Manages checkpointing for DTensor-based models using nemo_automodel's Checkpointer. - -This class provides a clean interface for saving and loading model checkpoints, -wrapping the nemo_automodel Checkpointer with configuration management. - - - - - - - - - - - -```python -nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager._get_dp_rank() -> int -``` - - - - - - -Get the data parallel rank. - - - - - - - -```python -nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager._get_tp_rank() -> int -``` - - - - - - -Get the tensor parallel rank. - - - - - - - -```python -nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager._rebuild_checkpointer_addons() -> None -``` - - - - - - -Rebuild the checkpointer's _addons list based on current config. - -The Checkpointer's _addons list is populated during __init__ based on config. -When config changes (e.g., model_save_format or is_peft), we need to rebuild -the addons list to match the new config. - - - - - - - -```python -nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager.init_checkpointer( - config_updates: typing.Optional[dict[str, typing.Any]] = None, - checkpoint_root: typing.Optional[str] = None -) -> None -``` - - - - - - -Initialize the Automodel Checkpointer if not already created. - -This method creates a new Checkpointer instance with the provided configuration. -If a checkpointer already exists, this method does nothing. - -**Parameters:** - - -Dict of CheckpointingConfig fields to set during initialization. - - - -Optional root directory for checkpoints. - - - - - - - - -```python -nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager.load_base_model( - model: torch.nn.Module, - model_name: str, - hf_cache_dir: typing.Optional[str] = None, - dequantize_base_checkpoint: bool = False, - peft_init_method: typing.Optional[str] = None -) -> None -``` - - - - - - -Load base model weights using the Automodel Checkpointer. - -This method loads the initial HuggingFace model weights into the parallelized model. - -**Parameters:** - - -The model to load weights into. - - - -Name or path of the model. - - - -Optional HuggingFace cache directory. - - - -Whether to dequantize the base checkpoint. - - -**Raises:** - -- `AssertionError`: If checkpointer has not been initialized. - - - - - - - -```python -nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager.load_checkpoint( - model: torch.nn.Module, - weights_path: str, - optimizer: typing.Optional[torch.optim.Optimizer] = None, - optimizer_path: typing.Optional[str] = None, - scheduler: typing.Optional[torch.optim.lr_scheduler.LRScheduler] = None -) -> None -``` - - - - - - -Load a checkpoint into the model using Automodel Checkpointer. - -**Parameters:** - - -The model to load weights into. - - - -Path to the checkpoint weights. - - - -Optional optimizer to load state into. - - - -Optional path to optimizer checkpoint. - - - -Optional learning rate scheduler. - - - - - - - - -```python -nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager.save_checkpoint( - model: torch.nn.Module, - weights_path: str, - optimizer: typing.Optional[torch.optim.Optimizer] = None, - optimizer_path: typing.Optional[str] = None, - scheduler: typing.Optional[torch.optim.lr_scheduler.LRScheduler] = None, - tokenizer: typing.Optional[transformers.AutoTokenizer] = None, - tokenizer_path: typing.Optional[str] = None, - checkpointing_cfg: typing.Optional[nemo_rl.utils.checkpoint.CheckpointingConfig] = None, - lora_enabled: bool = False, - peft_config: typing.Optional[nemo_automodel.components._peft.lora.PeftConfig] = None -) -> None -``` - - - - - - -Save a checkpoint of the model. - -The optimizer states are saved only if `optimizer` and `optimizer_path` are provided. - -**Parameters:** - - -The model to save. - - - -Path to save model weights. - - - -Optional optimizer to save. - - - -Optional path to save optimizer state. - - - -Optional learning rate scheduler. - - - -Optional tokenizer to save with the checkpoint. - - - -Optional path to save tokenizer separately. - - - -Checkpointing configuration. - - - -Whether LoRA is enabled. - - - -Optional PEFT configuration. - - - - - - - - -```python -nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager.set_model_state_dict_keys( - keys: list[str] -) -> None -``` - - - - - - -Set the model state dict keys for checkpoint validation. - -**Parameters:** - - -List of model state dict keys. - - - - - - - - -```python -nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager.update_checkpointer_config( - config_updates: typing.Optional[dict[str, typing.Any]] = None, - checkpoint_root: typing.Optional[str] = None -) -> None -``` - - - - - - -Update the configuration of an existing Checkpointer. - -This method updates the mutable config fields on the existing Checkpointer instance. -If no checkpointer exists, this method does nothing. - -Note: Some config changes (like model_save_format) require rebuilding the -checkpointer's internal addons list. This method handles that automatically. - -**Parameters:** - - -Dict of CheckpointingConfig fields to update. - - - -Optional root directory for checkpoints. - - - - - - - - - - -```python -nemo_rl.utils.automodel_checkpoint._infer_checkpoint_root( - weights_path: str -) -> str -``` - - - - - - -Infer checkpoint root directory from weights path. - -When weights_path ends with "…/weights/model", we need the parent of -the weights directory (the checkpoint root), not the weights directory itself. - -**Parameters:** - - -Path to model weights (e.g., "/path/to/policy/weights/model") - - -**Returns:** `str` - -Checkpoint root directory (e.g., "/path/to/policy") - - - - - - - - -```python -nemo_rl.utils.automodel_checkpoint.detect_checkpoint_format( - weights_path: str -) -> tuple[str, bool] -``` - - - - - - -Detect model save format and PEFT status from checkpoint directory. - -**Parameters:** - - -Path to the checkpoint directory (e.g., weights/model) - - -**Returns:** `tuple[str, bool]` - -(model_save_format, is_peft) where: - model_save_format is "torch_save" for DCP or "safetensors" for safetensors - is_peft is True if PEFT/adapter patterns are detected - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/checkpoint.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/checkpoint.mdx deleted file mode 100644 index 8c380d8..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/checkpoint.mdx +++ /dev/null @@ -1,411 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/utils/checkpoint -title: nemo_rl.utils.checkpoint ---- - -Checkpoint management utilities for the rl algorithm loop. - -It handles logic at the algorithm level. Each RL Actor is expected to have its -own checkpoint saving function (called by the algorithm loop). - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`CheckpointManager`](#nemo_rl-utils-checkpoint-CheckpointManager) | Manages model checkpoints during training. | -| [`CheckpointingConfig`](#nemo_rl-utils-checkpoint-CheckpointingConfig) | Configuration for checkpoint management. | - -### Functions - -| Name | Description | -|------|-------------| -| [`_load_checkpoint_history`](#nemo_rl-utils-checkpoint-_load_checkpoint_history) | Load the history of checkpoints and their metrics. | - -### Data - -[`PathLike`](#nemo_rl-utils-checkpoint-PathLike) - -### API - - - - - -```python -class nemo_rl.utils.checkpoint.CheckpointManager( - config: nemo_rl.utils.checkpoint.CheckpointingConfig -) -``` - - - - - - -Manages model checkpoints during training. - -This class handles creating checkpoint dirs, saving training info, and -configurations. It also provides utilities for keeping just the top-k checkpoints. -The checkpointing structure looks like this: - - -```python -checkpoint_dir/ - step_0/ - training_info.json - config.yaml - policy.py (up to the algorithm loop to save here) - policy_optimizer.py (up to the algorithm loop to save here) - ... - step_1/ - ... -``` - - - -Attributes: Derived from the CheckpointingConfig. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.utils.checkpoint.CheckpointManager.finalize_checkpoint( - checkpoint_path: nemo_rl.utils.checkpoint.PathLike -) -> None -``` - - - - - - -Complete a checkpoint by moving it from temporary to permanent location. - -If a checkpoint at the target location already exists (i.e when resuming training), -we override the old one. -Also triggers cleanup of old checkpoints based on the keep_top_k setting. - -**Parameters:** - - -Path to the temporary checkpoint directory. - - - - - - - - -```python -nemo_rl.utils.checkpoint.CheckpointManager.get_best_checkpoint_path() -> typing.Optional[str] -``` - - - - - - -Get the path to the best checkpoint based on the metric. - -Returns the path to the checkpoint with the best metric value. If no checkpoints -exist, returns None. If some checkpoints are missing the metric, they are filtered -out with a warning. If no checkpoints have the metric, returns the latest checkpoint. - -**Returns:** `Optional[str]` - -Optional[str]: Path to the best checkpoint, or None if no checkpoints exist. - - - - - - - -```python -nemo_rl.utils.checkpoint.CheckpointManager.get_latest_checkpoint_path() -> typing.Optional[str] -``` - - - - - - -Get the path to the latest checkpoint. - -Returns the path to the checkpoint with the highest step number. - -**Returns:** `Optional[str]` - -Optional[str]: Path to the latest checkpoint, or None if no checkpoints exist. - - - - - - - -```python -nemo_rl.utils.checkpoint.CheckpointManager.init_tmp_checkpoint( - step: int, - training_info: typing.Mapping[str, typing.Any], - run_config: typing.Optional[typing.Mapping[str, typing.Any]] = None -) -> nemo_rl.utils.checkpoint.PathLike -``` - - - - - - -Initialize a temporary checkpoint directory. - -Creates a temporary directory for a new checkpoint and saves training info -and configuration. The directory is named 'tmp_step_{step}' and will be renamed -to 'step_{step}' when the checkpoint is completed. -We do it this way to allow the algorithm loop to save any files it wants to save -in a safe, temporary directory. - -**Parameters:** - - -The training step number. - - - -Dictionary containing training metrics and info. - - - -Optional configuration for the training run. - - -**Returns:** `PathLike` - -Path to the temporary checkpoint directory. - - - - - - - -```python -nemo_rl.utils.checkpoint.CheckpointManager.load_training_info( - checkpoint_path: typing.Optional[nemo_rl.utils.checkpoint.PathLike] = None -) -> typing.Optional[dict[str, typing.Any]] -``` - - - - - - -Load the training info from a checkpoint. - -**Parameters:** - - -Path to the checkpoint. If None, -returns None. - - -**Returns:** `Optional[dict[str, Any]]` - -Optional[dict[str, Any]]: Dictionary containing the training info, or None if -checkpoint_path is None. - - - - - - - -```python -nemo_rl.utils.checkpoint.CheckpointManager.remove_old_checkpoints( - exclude_latest: bool = True -) -> None -``` - - - - - - -Remove checkpoints that are not in the top-k or latest based on the (optional) metric. - -If keep_top_k is set, this method removes all checkpoints except the top-k -best ones. The "best" checkpoints are determined by: -- If a metric is provided: the given metric value and the higher_is_better setting. - When multiple checkpoints have the same metric value, more recent checkpoints - (higher step numbers) are preferred. -- If no metric is provided: the step number. The most recent k checkpoints are kept. - -**Parameters:** - - -Whether to exclude the latest checkpoint from deletion. (may result in K+1 checkpoints) - - - - - - - - - - -```python -class nemo_rl.utils.checkpoint.CheckpointingConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - -Configuration for checkpoint management. - -Attributes: -enabled (bool): Whether checkpointing is enabled. -checkpoint_dir (PathLike): Directory where checkpoints will be saved. -metric_name (str | None): Name of the metric to use for determining best checkpoints. - Must be of the form "val:<metric_name>" or "train:<metric_name>" to indicate whether - the metric should be taken from the validation or training metrics. -higher_is_better (bool): Whether higher values of the metric indicate better performance. -keep_top_k (Optional[int]): Number of best checkpoints to keep. If None, all checkpoints are kept. -model_save_format (str | None): Format for saving model (v2 allowed values: "torch_save" or "safetensors", v1 allowed values: None). -save_consolidated (bool): Whether to save consolidated checkpoints (for HF compatibility). -model_cache_dir (str): Directory for model cache (for safetensors format). -model_repo_id (str): Repository ID for the model (for safetensors format). -is_peft (bool): Whether the model uses PEFT. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.utils.checkpoint._load_checkpoint_history( - checkpoint_dir: pathlib.Path -) -> list[tuple[int, nemo_rl.utils.checkpoint.PathLike, dict[str, typing.Any]]] -``` - - - - - - -Load the history of checkpoints and their metrics. - -**Parameters:** - - -Directory containing the checkpoints. - - -**Returns:** `list[tuple[int, PathLike, dict[str, Any]]]` - -list[tuple[int, PathLike, dict[str, Any]]]: List of tuples containing -(step_number, checkpoint_path, info) for each checkpoint. - - - - - - - - -```python -nemo_rl.utils.checkpoint.PathLike = Union[str, 'os.PathLike[Any]'] -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/config.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/config.mdx deleted file mode 100644 index ba6ab36..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/config.mdx +++ /dev/null @@ -1,266 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/utils/config -title: nemo_rl.utils.config ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`OverridesError`](#nemo_rl-utils-config-OverridesError) | Custom exception for Hydra override parsing errors. | - -### Functions - -| Name | Description | -|------|-------------| -| [`load_config`](#nemo_rl-utils-config-load_config) | Load a config file with inheritance support and convert it to an OmegaConf object. | -| [`load_config_with_inheritance`](#nemo_rl-utils-config-load_config_with_inheritance) | Load a config file with inheritance support. | -| [`merge_with_override`](#nemo_rl-utils-config-merge_with_override) | Merge configs with support for _override_ marker to completely override sections. | -| [`parse_hydra_overrides`](#nemo_rl-utils-config-parse_hydra_overrides) | Parse and apply Hydra overrides to an OmegaConf config. | -| [`register_omegaconf_resolvers`](#nemo_rl-utils-config-register_omegaconf_resolvers) | Register shared OmegaConf resolvers used in configs. | -| [`resolve_path`](#nemo_rl-utils-config-resolve_path) | Resolve a path relative to the base path. | - -### API - - - - - -```python -class nemo_rl.utils.config.OverridesError() -``` - - - - - - -Exception - -**Bases:** `Exception` - -Custom exception for Hydra override parsing errors. - - - - - - - - -```python -nemo_rl.utils.config.load_config( - config_path: typing.Union[str, pathlib.Path] -) -> omegaconf.DictConfig -``` - - - - - - -Load a config file with inheritance support and convert it to an OmegaConf object. - -The config inheritance system supports: - -1. Single inheritance: - ```python - # child.yaml - defaults: parent.yaml - common: - value: 43 - ``` - -2. Multiple inheritance: - ```python - # child.yaml - defaults: - - parent1.yaml - - parent2.yaml - common: - value: 44 - ``` - -3. Nested inheritance: - ```python - # parent.yaml - defaults: grandparent.yaml - common: - value: 43 - - # child.yaml - defaults: parent.yaml - common: - value: 44 - ``` - -4. Variable interpolation: - ```python - # parent.yaml - base_value: 42 - derived: - value: ${base_value} - - # child.yaml - defaults: parent.yaml - base_value: 43 # This will update both base_value and derived.value - ``` - -The system handles: -- Relative and absolute paths -- Multiple inheritance -- Nested inheritance -- Variable interpolation - -The inheritance is resolved depth-first, with later configs overriding earlier ones. -This means in multiple inheritance, the last config in the list takes precedence. - -**Parameters:** - - -Path to the config file - - -**Returns:** `DictConfig` - -Merged config dictionary - - - - - - - - -```python -nemo_rl.utils.config.load_config_with_inheritance( - config_path: typing.Union[str, pathlib.Path], - base_dir: typing.Optional[typing.Union[str, pathlib.Path]] = None -) -> omegaconf.DictConfig -``` - - - - - - -Load a config file with inheritance support. - -**Parameters:** - - -Path to the config file - - - -Base directory for resolving relative paths. If None, uses config_path's directory - - -**Returns:** `DictConfig` - -Merged config dictionary - - - - - - - - -```python -nemo_rl.utils.config.merge_with_override( - base_config: omegaconf.DictConfig, - override_config: omegaconf.DictConfig -) -> omegaconf.DictConfig -``` - - - - - - -Merge configs with support for _override_ marker to completely override sections. - - - - - - - - -```python -nemo_rl.utils.config.parse_hydra_overrides( - cfg: omegaconf.DictConfig, - overrides: list[str] -) -> omegaconf.DictConfig -``` - - - - - - -Parse and apply Hydra overrides to an OmegaConf config. - -**Parameters:** - - -OmegaConf config to apply overrides to - - - -List of Hydra override strings - - -**Returns:** `DictConfig` - -Updated config with overrides applied - -**Raises:** - -- `OverridesError`: If there's an error parsing or applying overrides - - - - - - - - -```python -nemo_rl.utils.config.register_omegaconf_resolvers() -> None -``` - - - - - - -Register shared OmegaConf resolvers used in configs. - - - - - - - - -```python -nemo_rl.utils.config.resolve_path( - base_path: pathlib.Path, - path: str -) -> pathlib.Path -``` - - - - - - -Resolve a path relative to the base path. - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_formulas.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_formulas.mdx deleted file mode 100644 index a0ee8f2..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_formulas.mdx +++ /dev/null @@ -1,501 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/utils/flops_formulas -title: nemo_rl.utils.flops_formulas ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`FLOPSConfig`](#nemo_rl-utils-flops_formulas-FLOPSConfig) | Contains the model hparams needed for FLOPS computations. | - -### Functions - -| Name | Description | -|------|-------------| -| [`_hybrid_model_flops`](#nemo_rl-utils-flops_formulas-_hybrid_model_flops) | Model FLOPs for hybrid model. | -| [`_mamba_layer_flops`](#nemo_rl-utils-flops_formulas-_mamba_layer_flops) | Model FLOPs for Mamba layer. We ignore part of the flops of scan because the chunk size is not known from model config. | -| [`_mlp_layer_flops`](#nemo_rl-utils-flops_formulas-_mlp_layer_flops) | Model FLOPs for MLP layer. | -| [`_non_mla_attn_layer_flops`](#nemo_rl-utils-flops_formulas-_non_mla_attn_layer_flops) | Model FLOPs for attention layer. | -| [`bert`](#nemo_rl-utils-flops_formulas-bert) | Model FLOPs for BERT family. | -| [`deepseekv3`](#nemo_rl-utils-flops_formulas-deepseekv3) | Model FLOPs for DeepSeek V3. | -| [`flux`](#nemo_rl-utils-flops_formulas-flux) | Model FLOPs for FLUX. | -| [`gpt3`](#nemo_rl-utils-flops_formulas-gpt3) | Model FLOPs for GPT3 family. | -| [`llama`](#nemo_rl-utils-flops_formulas-llama) | Model FLOPs for llama3 family. | -| [`mixtral`](#nemo_rl-utils-flops_formulas-mixtral) | Model FLOPs for mixtral family. | -| [`nemotron`](#nemo_rl-utils-flops_formulas-nemotron) | Model FLOPs for nemotron family. | -| [`nemotronh`](#nemo_rl-utils-flops_formulas-nemotronh) | Model FLOPs for NemotronH. | -| [`qwen2`](#nemo_rl-utils-flops_formulas-qwen2) | Model FLOPs for Qwen2 family. | -| [`qwen3`](#nemo_rl-utils-flops_formulas-qwen3) | Model FLOPs for Qwen3 family. | -| [`transformer`](#nemo_rl-utils-flops_formulas-transformer) | Calculate FLOPs for a standard Transformer model. | - -### API - - - - - -```python -class nemo_rl.utils.flops_formulas.FLOPSConfig( - gbs: int, - enc_seq_len: typing.Optional[int] = None, - hs: typing.Optional[int] = None, - layers: typing.Optional[int] = None, - ffn_hs: typing.Optional[int] = None, - attention_heads: typing.Optional[int] = None, - moe_router_topk: typing.Optional[int] = None, - query_groups: typing.Optional[int] = None, - img_seq_len: typing.Optional[int] = None, - img_h: typing.Optional[int] = None, - img_w: typing.Optional[int] = None, - in_channels: typing.Optional[int] = None, - patch_dim: typing.Optional[int] = None, - class_token_len: typing.Optional[int] = None, - projector_type: typing.Optional[str] = None, - inp_s: typing.Optional[int] = None, - model_pattern: typing.Optional[str] = None, - vocab_size: typing.Optional[int] = None, - model_channels: typing.Optional[int] = None, - vec_in_dim: typing.Optional[int] = None, - q_lora_rank: typing.Optional[int] = None, - kv_lora_rank: typing.Optional[int] = None, - qk_head_dim: typing.Optional[int] = None, - qk_pos_emb_head_dim: typing.Optional[int] = None, - v_head_dim: typing.Optional[int] = None, - moe_layer_freq: typing.Optional[typing.Union[int, typing.List[int]]] = None, - moe_shared_expert_intermediate_size: typing.Optional[int] = None, - moe_ffn_hidden_size: typing.Optional[int] = None, - mtp_num_layers: typing.Optional[int] = None, - causal_self_attn: typing.Optional[bool] = None, - is_hybrid_model: bool = False, - hybrid_override_pattern: typing.Optional[str] = None, - mamba_state_dim: typing.Optional[int] = None, - mamba_head_dim: typing.Optional[int] = None, - mamba_num_groups: typing.Optional[int] = None, - mamba_num_heads: typing.Optional[int] = None -) -``` - - - - - - -Dataclass - -Contains the model hparams needed for FLOPS computations. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.utils.flops_formulas._hybrid_model_flops( - config: nemo_rl.utils.flops_formulas.FLOPSConfig -) -``` - - - - - - -Model FLOPs for hybrid model. - - - - - - - - -```python -nemo_rl.utils.flops_formulas._mamba_layer_flops( - config: nemo_rl.utils.flops_formulas.FLOPSConfig -) -``` - - - - - - -Model FLOPs for Mamba layer. We ignore part of the flops of scan because the chunk size is not known from model config. - - - - - - - - -```python -nemo_rl.utils.flops_formulas._mlp_layer_flops( - config: nemo_rl.utils.flops_formulas.FLOPSConfig -) -``` - - - - - - -Model FLOPs for MLP layer. - - - - - - - - -```python -nemo_rl.utils.flops_formulas._non_mla_attn_layer_flops( - config: nemo_rl.utils.flops_formulas.FLOPSConfig -) -``` - - - - - - -Model FLOPs for attention layer. - - - - - - - - -```python -nemo_rl.utils.flops_formulas.bert( - config: nemo_rl.utils.flops_formulas.FLOPSConfig -) -``` - - - - - - -Model FLOPs for BERT family. - - - - - - - - -```python -nemo_rl.utils.flops_formulas.deepseekv3( - config: nemo_rl.utils.flops_formulas.FLOPSConfig -) -``` - - - - - - -Model FLOPs for DeepSeek V3. - - - - - - - - -```python -nemo_rl.utils.flops_formulas.flux( - config: nemo_rl.utils.flops_formulas.FLOPSConfig -) -``` - - - - - - -Model FLOPs for FLUX. - - - - - - - - -```python -nemo_rl.utils.flops_formulas.gpt3( - config: nemo_rl.utils.flops_formulas.FLOPSConfig -) -``` - - - - - - -Model FLOPs for GPT3 family. - - - - - - - - -```python -nemo_rl.utils.flops_formulas.llama( - config: nemo_rl.utils.flops_formulas.FLOPSConfig -) -``` - - - - - - -Model FLOPs for llama3 family. - - - - - - - - -```python -nemo_rl.utils.flops_formulas.mixtral( - config: nemo_rl.utils.flops_formulas.FLOPSConfig -) -``` - - - - - - -Model FLOPs for mixtral family. - - - - - - - - -```python -nemo_rl.utils.flops_formulas.nemotron( - config: nemo_rl.utils.flops_formulas.FLOPSConfig -) -``` - - - - - - -Model FLOPs for nemotron family. - - - - - - - - -```python -nemo_rl.utils.flops_formulas.nemotronh( - config: nemo_rl.utils.flops_formulas.FLOPSConfig -) -``` - - - - - - -Model FLOPs for NemotronH. - - - - - - - - -```python -nemo_rl.utils.flops_formulas.qwen2( - config: nemo_rl.utils.flops_formulas.FLOPSConfig -) -``` - - - - - - -Model FLOPs for Qwen2 family. - - - - - - - - -```python -nemo_rl.utils.flops_formulas.qwen3( - config: nemo_rl.utils.flops_formulas.FLOPSConfig -) -``` - - - - - - -Model FLOPs for Qwen3 family. - - - - - - - - -```python -nemo_rl.utils.flops_formulas.transformer( - config: nemo_rl.utils.flops_formulas.FLOPSConfig -) -``` - - - - - - -Calculate FLOPs for a standard Transformer model. - -Note: This does not cover encoder-decoder models. - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_tracker.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_tracker.mdx deleted file mode 100644 index 1965cd5..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_tracker.mdx +++ /dev/null @@ -1,215 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/utils/flops_tracker -title: nemo_rl.utils.flops_tracker ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`FLOPTracker`](#nemo_rl-utils-flops_tracker-FLOPTracker) | - | - -### Functions - -| Name | Description | -|------|-------------| -| [`convert_config_to_flops_config`](#nemo_rl-utils-flops_tracker-convert_config_to_flops_config) | Convert a pretrained config to a tuple containing a FLOPSConfig and a flops formula. | -| [`get_default_hf_config`](#nemo_rl-utils-flops_tracker-get_default_hf_config) | Get the default Hugging Face config for a model. | -| [`get_theoretical_tflops`](#nemo_rl-utils-flops_tracker-get_theoretical_tflops) | Get the theoretical total flops for a device name. | -| [`is_using_tf32`](#nemo_rl-utils-flops_tracker-is_using_tf32) | Check if the current device is using TF32. | - -### Data - -[`THEORETICAL_TFLOPS`](#nemo_rl-utils-flops_tracker-THEORETICAL_TFLOPS) - -### API - - - - - -```python -class nemo_rl.utils.flops_tracker.FLOPTracker( - model_name: str, - base_config: nemo_rl.utils.flops_formulas.FLOPSConfig | None = None, - flops_formula: typing.Callable[[FLOPSConfig], float] | None = None -) -``` - - - - - - - - - - - - -```python -nemo_rl.utils.flops_tracker.FLOPTracker.from_config( - model_name: str, - config: transformers.configuration_utils.PretrainedConfig -) -> nemo_rl.utils.flops_tracker.FLOPTracker -``` - - - - - - -classmethod - - - - - - - -```python -nemo_rl.utils.flops_tracker.FLOPTracker.reset() -``` - - - - - - - - - - - - -```python -nemo_rl.utils.flops_tracker.FLOPTracker.track( - n_samples: int, - padded_seq_len: int -) -``` - - - - - - - - - - - - -```python -nemo_rl.utils.flops_tracker.FLOPTracker.track_batch( - sequence_lengths: list[int] -) -``` - - - - - - -Track the flops for a batch of sequences. - - - - - - - - - -```python -nemo_rl.utils.flops_tracker.convert_config_to_flops_config( - config: transformers.configuration_utils.PretrainedConfig -) -> tuple[nemo_rl.utils.flops_formulas.FLOPSConfig, typing.Callable] -``` - - - - - - -Convert a pretrained config to a tuple containing a FLOPSConfig and a flops formula. - - - - - - - - -```python -nemo_rl.utils.flops_tracker.get_default_hf_config( - model_name: str -) -> transformers.configuration_utils.PretrainedConfig -``` - - - - - - -Get the default Hugging Face config for a model. - -Both the DTensor and MCore paths use the same default config, we initialize the model config -here to allow computation of theoretical flops which is agnostic to the backend. - - - - - - - - -```python -nemo_rl.utils.flops_tracker.get_theoretical_tflops( - device_name: str, - model_dtype: torch.dtype -) -> float -``` - - - - - - -Get the theoretical total flops for a device name. - - - - - - - - -```python -nemo_rl.utils.flops_tracker.is_using_tf32() -> bool -``` - - - - - - -Check if the current device is using TF32. - - - - - - - - -```python -nemo_rl.utils.flops_tracker.THEORETICAL_TFLOPS = {('NVIDIA A100 80GB PCIe', torch.bfloat16): 624 / 2, ('NVIDIA A100 80GB PCIe', t... -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/logger.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/logger.mdx deleted file mode 100644 index b78a4b3..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/logger.mdx +++ /dev/null @@ -1,1856 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/utils/logger -title: nemo_rl.utils.logger ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`GPUMonitoringConfig`](#nemo_rl-utils-logger-GPUMonitoringConfig) | - | -| [`GpuMetricSnapshot`](#nemo_rl-utils-logger-GpuMetricSnapshot) | - | -| [`Logger`](#nemo_rl-utils-logger-Logger) | Main logger class that delegates to multiple backend loggers. | -| [`LoggerConfig`](#nemo_rl-utils-logger-LoggerConfig) | - | -| [`LoggerInterface`](#nemo_rl-utils-logger-LoggerInterface) | Abstract base class for logger backends. | -| [`MLflowConfig`](#nemo_rl-utils-logger-MLflowConfig) | - | -| [`MLflowLogger`](#nemo_rl-utils-logger-MLflowLogger) | MLflow logger backend. | -| [`RayGpuMonitorLogger`](#nemo_rl-utils-logger-RayGpuMonitorLogger) | Monitor GPU utilization across a Ray cluster and log metrics to a parent logger. | -| [`SwanlabConfig`](#nemo_rl-utils-logger-SwanlabConfig) | - | -| [`SwanlabLogger`](#nemo_rl-utils-logger-SwanlabLogger) | SwanLab logger backend. | -| [`TensorboardConfig`](#nemo_rl-utils-logger-TensorboardConfig) | - | -| [`TensorboardLogger`](#nemo_rl-utils-logger-TensorboardLogger) | Tensorboard logger backend. | -| [`WandbConfig`](#nemo_rl-utils-logger-WandbConfig) | - | -| [`WandbLogger`](#nemo_rl-utils-logger-WandbLogger) | Weights & Biases logger backend. | - -### Functions - -| Name | Description | -|------|-------------| -| [`configure_rich_logging`](#nemo_rl-utils-logger-configure_rich_logging) | Configure rich logging for more visually appealing log output. | -| [`flatten_dict`](#nemo_rl-utils-logger-flatten_dict) | Flatten a nested dictionary. | -| [`get_next_experiment_dir`](#nemo_rl-utils-logger-get_next_experiment_dir) | Create a new experiment directory with an incremented ID. | -| [`print_message_log_samples`](#nemo_rl-utils-logger-print_message_log_samples) | Visualization for message logs and rewards using a more visual approach with emoji indicators and horizontal layout. | - -### Data - -[`_rich_logging_configured`](#nemo_rl-utils-logger-_rich_logging_configured) - -### API - - - - - -```python -class nemo_rl.utils.logger.GPUMonitoringConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - -```python -class nemo_rl.utils.logger.GpuMetricSnapshot -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - -```python -class nemo_rl.utils.logger.Logger( - cfg: nemo_rl.utils.logger.LoggerConfig -) -``` - - - - - - -**Bases:** [LoggerInterface](#nemo_rl-utils-logger-LoggerInterface) - -Main logger class that delegates to multiple backend loggers. - - - - - - - - - - - -```python -nemo_rl.utils.logger.Logger.__del__() -> None -``` - - - - - - -Clean up resources when the logger is destroyed. - - - - - - - -```python -nemo_rl.utils.logger.Logger.log_batched_dict_as_jsonl( - to_log: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] | dict[str, typing.Any], - filename: str -) -> None -``` - - - - - - -Log a list of dictionaries to a JSONL file. - -**Parameters:** - - -BatchedDataDict to log - - - -Filename to log to (within the log directory) - - - - - - - - -```python -nemo_rl.utils.logger.Logger.log_histogram( - histogram: list[typing.Any], - step: int, - name: str -) -> None -``` - - - - - - -Log histogram metrics to all backends if available. - -**Parameters:** - - -List of histogram values - - - -Global step value - - - -Name of the metric - - - - - - - - -```python -nemo_rl.utils.logger.Logger.log_hyperparams( - params: typing.Mapping[str, typing.Any] -) -> None -``` - - - - - - -Log hyperparameters to all enabled backends. - -**Parameters:** - - -Dict of hyperparameters to log - - - - - - - - -```python -nemo_rl.utils.logger.Logger.log_metrics( - metrics: dict[str, typing.Any], - step: int, - prefix: typing.Optional[str] = '', - step_metric: typing.Optional[str] = None, - step_finished: bool = False -) -> None -``` - - - - - - -Log metrics to all enabled backends. - -**Parameters:** - - -Dict of metrics to log - - - -Global step value - - - -Optional prefix for metric names - - - -Optional name of a field in metrics to use as step instead - of the provided step value (currently only needed for wandb) - - - - - - - - -```python -nemo_rl.utils.logger.Logger.log_plot( - figure: matplotlib.pyplot.Figure, - step: int, - name: str -) -> None -``` - - - - - - -Log a matplotlib figure to all backends. - -**Parameters:** - - -Matplotlib figure to log - - - -Global step value - - - -Name of the plot - - - - - - - - -```python -nemo_rl.utils.logger.Logger.log_plot_per_worker_timeline_metrics( - metrics: dict[int, list[typing.Any]], - step: int, - prefix: str, - name: str, - timeline_interval: float -) -> None -``` - - - - - - -Log a plot of per-worker timeline metrics. - -**Parameters:** - - -Dictionary of metrics to log, where the keys are the worker IDs and the values are the lists of metric values - - - -dict[str, list[Any]] = {worker_id: [metric_value_1, metric_value_2, ...]} - - - -Global step value - - - -Name of the plot - - - -Interval between timeline points (in seconds) - - - - - - - - -```python -nemo_rl.utils.logger.Logger.log_plot_token_mult_prob_error( - data: dict[str, typing.Any], - step: int, - name: str -) -> None -``` - - - - - - -Log a plot of log probability errors in samples. - -This function logs & plots the per-token log-probabilities and errors over the sequence -for the sample with the highest multiplicative probability error in the batch. - -**Parameters:** - - -Dictionary of log probability samples - - - -Global step value - - - -Name of the plot - - - - - - - - -```python -nemo_rl.utils.logger.Logger.log_string_list_as_jsonl( - to_log: list[str], - filename: str -) -> None -``` - - - - - - -Log a list of strings to a JSONL file. - -**Parameters:** - - -list of strings to log - - - -Filename to log to (within the log directory) - - - - - - - - - - -```python -class nemo_rl.utils.logger.LoggerConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.utils.logger.LoggerInterface() -``` - - - - - - -Abstract - -Abstract base class for logger backends. - - - - - - -```python -nemo_rl.utils.logger.LoggerInterface.log_histogram( - histogram: list[typing.Any], - step: int, - name: str -) -> None -``` - - - - - - -abstract - -Log histogram metrics. - - - - - - - -```python -nemo_rl.utils.logger.LoggerInterface.log_hyperparams( - params: typing.Mapping[str, typing.Any] -) -> None -``` - - - - - - -abstract - -Log dictionary of hyperparameters. - - - - - - - -```python -nemo_rl.utils.logger.LoggerInterface.log_metrics( - metrics: dict[str, typing.Any], - step: int, - prefix: typing.Optional[str] = '', - step_metric: typing.Optional[str] = None, - step_finished: bool = False -) -> None -``` - - - - - - -abstract - -Log a dictionary of metrics. - - - - - - - -```python -nemo_rl.utils.logger.LoggerInterface.log_plot( - figure: matplotlib.pyplot.Figure, - step: int, - name: str -) -> None -``` - - - - - - -abstract - -Log a matplotlib figure. - - - - - - - - - -```python -class nemo_rl.utils.logger.MLflowConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -class nemo_rl.utils.logger.MLflowLogger( - cfg: nemo_rl.utils.logger.MLflowConfig, - log_dir: typing.Optional[str] = None -) -``` - - - - - - -**Bases:** [LoggerInterface](#nemo_rl-utils-logger-LoggerInterface) - -MLflow logger backend. - - - - - - - - -```python -nemo_rl.utils.logger.MLflowLogger.__del__() -> None -``` - - - - - - -Clean up resources when the logger is destroyed. - - - - - - - -```python -nemo_rl.utils.logger.MLflowLogger.log_histogram( - histogram: list[typing.Any], - step: int, - name: str -) -> None -``` - - - - - - -Log histogram metrics to MLflow. - - - - - - - -```python -nemo_rl.utils.logger.MLflowLogger.log_hyperparams( - params: typing.Mapping[str, typing.Any] -) -> None -``` - - - - - - -Log hyperparameters to MLflow. - -**Parameters:** - - -Dictionary of hyperparameters to log - - - - - - - - -```python -nemo_rl.utils.logger.MLflowLogger.log_metrics( - metrics: dict[str, typing.Any], - step: int, - prefix: typing.Optional[str] = '', - step_metric: typing.Optional[str] = None, - step_finished: bool = False -) -> None -``` - - - - - - -Log metrics to MLflow. - -**Parameters:** - - -Dict of metrics to log - - - -Global step value - - - -Optional prefix for metric names - - - -Optional step metric name (ignored in MLflow) - - - - - - - - -```python -nemo_rl.utils.logger.MLflowLogger.log_plot( - figure: matplotlib.pyplot.Figure, - step: int, - name: str -) -> None -``` - - - - - - -Log a plot to MLflow. - -**Parameters:** - - -Matplotlib figure to log - - - -Global step value - - - -Name of the plot - - - - - - - - - - -```python -class nemo_rl.utils.logger.RayGpuMonitorLogger( - collection_interval: int | float, - flush_interval: int | float, - metric_prefix: str, - step_metric: str, - parent_logger: typing.Optional[nemo_rl.utils.logger.Logger] = None -) -``` - - - - - - -Monitor GPU utilization across a Ray cluster and log metrics to a parent logger. - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.utils.logger.RayGpuMonitorLogger._collect( - metrics: bool = False, - sku: bool = False -) -> dict[str, typing.Any] -``` - - - - - - -Collect GPU metrics from all Ray nodes. - -**Returns:** `dict[str, Any]` - -Dictionary of collected metrics - - - - - - - -```python -nemo_rl.utils.logger.RayGpuMonitorLogger._collect_gpu_sku() -> dict[str, str] -``` - - - - - - -Collect GPU SKU from all Ray nodes. - -Note: This is an internal API and users are not expected to call this. - -**Returns:** `dict[str, str]` - -Dictionary of SKU types on all Ray nodes - - - - - - - -```python -nemo_rl.utils.logger.RayGpuMonitorLogger._collect_metrics() -> dict[str, typing.Any] -``` - - - - - - -Collect GPU metrics from all Ray nodes. - -**Returns:** `dict[str, Any]` - -Dictionary of collected metrics - - - - - - - -```python -nemo_rl.utils.logger.RayGpuMonitorLogger._collection_loop() -> None -``` - - - - - - -Main collection loop that runs in a separate thread. - - - - - - - -```python -nemo_rl.utils.logger.RayGpuMonitorLogger._fetch_and_parse_metrics( - node_idx: int, - metric_address: str, - parser_fn: typing.Callable -) -``` - - - - - - -Fetch metrics from a node and parse GPU metrics. - -**Parameters:** - - -Index of the node - - - -Address of the metrics endpoint - - -**Returns:** - -Dictionary of GPU metrics - - - - - - - -```python -nemo_rl.utils.logger.RayGpuMonitorLogger._parse_gpu_sku( - sample: prometheus_client.samples.Sample, - node_idx: int -) -> dict[str, str] -``` - - - - - - -Parse a GPU metric sample into a standardized format. - -**Parameters:** - - -Prometheus metric sample - - - -Index of the node - - -**Returns:** `dict[str, str]` - -Dictionary with metric name and value - - - - - - - -```python -nemo_rl.utils.logger.RayGpuMonitorLogger._parse_metric( - sample: prometheus_client.samples.Sample, - node_idx: int -) -> dict[str, typing.Any] -``` - - - - - - -Parse a metric sample into a standardized format. - -**Parameters:** - - -Prometheus metric sample - - - -Index of the node - - -**Returns:** `dict[str, Any]` - -Dictionary with metric name and value - - - - - - - -```python -nemo_rl.utils.logger.RayGpuMonitorLogger.flush() -> None -``` - - - - - - -Flush collected metrics to the parent logger. - - - - - - - -```python -nemo_rl.utils.logger.RayGpuMonitorLogger.start() -> None -``` - - - - - - -Start the GPU monitoring thread. - - - - - - - -```python -nemo_rl.utils.logger.RayGpuMonitorLogger.stop() -> None -``` - - - - - - -Stop the GPU monitoring thread. - - - - - - - - - -```python -class nemo_rl.utils.logger.SwanlabConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - -```python -class nemo_rl.utils.logger.SwanlabLogger( - cfg: nemo_rl.utils.logger.SwanlabConfig, - log_dir: typing.Optional[str] = None -) -``` - - - - - - -**Bases:** [LoggerInterface](#nemo_rl-utils-logger-LoggerInterface) - -SwanLab logger backend. - - - - - - - - -```python -nemo_rl.utils.logger.SwanlabLogger.log_histogram( - histogram: list[typing.Any], - step: int, - name: str -) -> None -``` - - - - - - -Log histogram metrics to swanlab. - - - - - - - -```python -nemo_rl.utils.logger.SwanlabLogger.log_hyperparams( - params: typing.Mapping[str, typing.Any] -) -> None -``` - - - - - - -Update the Swanlab run configuration with the provided hyperparameters. - -**Parameters:** - - -Mapping of hyperparameter names to values to store in the run configuration. - - - - - - - - -```python -nemo_rl.utils.logger.SwanlabLogger.log_metrics( - metrics: dict[str, typing.Any], - step: int, - prefix: typing.Optional[str] = '', - step_metric: typing.Optional[str] = None, - step_finished: bool = False -) -> None -``` - - - - - - -Log metrics to the associated Swanlab run. - -**Parameters:** - - -Mapping of metric names to metric values. - - - -Global step value to associate with all logged metrics. - - - -Optional prefix applied to metric names; metric names equal to `step_metric` are not prefixed. - - - -Name of a metric that should be excluded from prefixing. - - - - - - - - -```python -nemo_rl.utils.logger.SwanlabLogger.log_plot( - figure: matplotlib.pyplot.Figure, - step: int, - name: str -) -> None -``` - - - - - - -Log a plot to swanlab. - -**Parameters:** - - -Matplotlib figure to log - - - -Global step value - - - - - - - - - - -```python -class nemo_rl.utils.logger.TensorboardConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - -```python -class nemo_rl.utils.logger.TensorboardLogger( - cfg: nemo_rl.utils.logger.TensorboardConfig, - log_dir: typing.Optional[str] = None -) -``` - - - - - - -**Bases:** [LoggerInterface](#nemo_rl-utils-logger-LoggerInterface) - -Tensorboard logger backend. - - - - - - - - -```python -nemo_rl.utils.logger.TensorboardLogger._coerce_to_scalar( - value: typing.Any -) -> int | float | bool | str | None -``` - - - - - - -staticmethod - -Coerce a value to a Python scalar for TensorBoard logging. - -Returns the coerced value, or None if it can't be converted to a scalar. - - - - - - - -```python -nemo_rl.utils.logger.TensorboardLogger.log_histogram( - histogram: list[typing.Any], - step: int, - name: str -) -> None -``` - - - - - - -Log histogram metrics to Tensorboard. - - - - - - - -```python -nemo_rl.utils.logger.TensorboardLogger.log_hyperparams( - params: typing.Mapping[str, typing.Any] -) -> None -``` - - - - - - -Log hyperparameters to Tensorboard. - -**Parameters:** - - -Dictionary of hyperparameters to log - - - - - - - - -```python -nemo_rl.utils.logger.TensorboardLogger.log_metrics( - metrics: dict[str, typing.Any], - step: int, - prefix: typing.Optional[str] = '', - step_metric: typing.Optional[str] = None, - step_finished: bool = False -) -> None -``` - - - - - - -Log metrics to Tensorboard. - -**Parameters:** - - -Dict of metrics to log - - - -Global step value - - - -Optional prefix for metric names - - - -Optional step metric name (ignored in TensorBoard) - - - - - - - - -```python -nemo_rl.utils.logger.TensorboardLogger.log_plot( - figure: matplotlib.pyplot.Figure, - step: int, - name: str -) -> None -``` - - - - - - -Log a plot to Tensorboard. - -**Parameters:** - - -Dictionary of plot data - - - -Global step value - - - - - - - - - - -```python -class nemo_rl.utils.logger.WandbConfig -``` - - - - - - -**Bases:** `typing.TypedDict` - - - - - - - - - - - - - - - -```python -class nemo_rl.utils.logger.WandbLogger( - cfg: nemo_rl.utils.logger.WandbConfig, - log_dir: typing.Optional[str] = None -) -``` - - - - - - -**Bases:** [LoggerInterface](#nemo_rl-utils-logger-LoggerInterface) - -Weights & Biases logger backend. - - - - - - - - -```python -nemo_rl.utils.logger.WandbLogger._log_code() -``` - - - - - - -Log code that is tracked by git to wandb. - -This function gets a list of all files tracked by git in the project root -and manually uploads them to the current wandb run as an artifact. - - - - - - - -```python -nemo_rl.utils.logger.WandbLogger._log_diffs() -``` - - - - - - -Log git diffs to wandb. - -This function captures and logs two types of diffs: -1. Uncommitted changes (working tree diff against HEAD) -2. All changes (including uncommitted) against the main branch - -Each diff is saved as a text file in a wandb artifact. - - - - - - - -```python -nemo_rl.utils.logger.WandbLogger.define_metric( - name: str, - step_metric: typing.Optional[str] = None -) -> None -``` - - - - - - -Define a metric with custom step metric. - -**Parameters:** - - -Name of the metric or pattern (e.g. 'ray/*') - - - -Optional name of the step metric to use - - - - - - - - -```python -nemo_rl.utils.logger.WandbLogger.log_histogram( - histogram: list[typing.Any], - step: int, - name: str -) -> None -``` - - - - - - -Log histogram metrics to wandb. - -**Parameters:** - - -List of histogram values - - - -Global step value - - - -Name of the metric - - - - - - - - -```python -nemo_rl.utils.logger.WandbLogger.log_hyperparams( - params: typing.Mapping[str, typing.Any] -) -> None -``` - - - - - - -Log hyperparameters to wandb. - -**Parameters:** - - -Dict of hyperparameters to log - - - - - - - - -```python -nemo_rl.utils.logger.WandbLogger.log_metrics( - metrics: dict[str, typing.Any], - step: int, - prefix: typing.Optional[str] = '', - step_metric: typing.Optional[str] = None, - step_finished: bool = False -) -> None -``` - - - - - - -Log metrics to wandb. - -**Parameters:** - - -Dict of metrics to log - - - -Global step value - - - -Optional prefix for metric names - - - -Optional name of a field in metrics to use as step instead - of the provided step value - - - - - - - - -```python -nemo_rl.utils.logger.WandbLogger.log_plot( - figure: matplotlib.pyplot.Figure, - step: int, - name: str -) -> None -``` - - - - - - -Log a plot to wandb. - -**Parameters:** - - -Matplotlib figure to log - - - -Global step value - - - - - - - - - - -```python -nemo_rl.utils.logger.configure_rich_logging( - level: str = 'INFO', - show_time: bool = True, - show_path: bool = True -) -> None -``` - - - - - - -Configure rich logging for more visually appealing log output. - -**Parameters:** - - -The logging level to use - - - -Whether to show timestamps in logs - - - -Whether to show file paths in logs - - - - - - - - - -```python -nemo_rl.utils.logger.flatten_dict( - d: typing.Mapping[str, typing.Any], - sep: str = '.' -) -> dict[str, typing.Any] -``` - - - - - - -Flatten a nested dictionary. - -Handles nested dictionaries and lists by creating keys with separators. -For lists, the index is used as part of the key. - -**Parameters:** - - -Dictionary to flatten - - - -Separator to use between nested keys - - -**Returns:** `dict[str, Any]` - -Flattened dictionary with compound keys - -**Examples:** - - - -```python ->>> from nemo_rl.utils.logger import flatten_dict ->>> flatten_dict({"a": 1, "b": {"c": 2}}) -{'a': 1, 'b.c': 2} - ->>> flatten_dict({"a": [1, 2], "b": {"c": [3, 4]}}) -{'a.0': 1, 'a.1': 2, 'b.c.0': 3, 'b.c.1': 4} - ->>> flatten_dict({"a": [{"b": 1}, {"c": 2}]}) -{'a.0.b': 1, 'a.1.c': 2} -``` - - - - - - - - - - -```python -nemo_rl.utils.logger.get_next_experiment_dir( - base_log_dir: str -) -> str -``` - - - - - - -Create a new experiment directory with an incremented ID. - -**Parameters:** - - -The base log directory path - - -**Returns:** `str` - -Path to the new experiment directory with incremented ID - - - - - - - - -```python -nemo_rl.utils.logger.print_message_log_samples( - message_logs: list[nemo_rl.data.interfaces.LLMMessageLogType], - rewards: list[float], - num_samples: int = 5, - step: int = 0 -) -> None -``` - - - - - - -Visualization for message logs and rewards using a more visual approach with emoji indicators and horizontal layout. - -**Parameters:** - - -List of message logs to sample from - - - -List of rewards corresponding to each message log - - - -Number of samples to display (default: 5) - - - -Current training step (for display purposes) - - - - - - - - - -```python -nemo_rl.utils.logger._rich_logging_configured = False -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/memory_tracker.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/memory_tracker.mdx deleted file mode 100644 index e06cd16..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/memory_tracker.mdx +++ /dev/null @@ -1,122 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/utils/memory_tracker -title: nemo_rl.utils.memory_tracker ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`MemoryTracker`](#nemo_rl-utils-memory_tracker-MemoryTracker) | - | -| [`MemoryTrackerDataPoint`](#nemo_rl-utils-memory_tracker-MemoryTrackerDataPoint) | - | - -### API - - - - - -```python -class nemo_rl.utils.memory_tracker.MemoryTracker() -``` - - - - - - -**Bases:** `BaseModel` - - - - - - - -```python -nemo_rl.utils.memory_tracker.MemoryTracker.model_post_init( - context -) -``` - - - - - - - - - - - - -```python -nemo_rl.utils.memory_tracker.MemoryTracker.snapshot_start_of_stage( - new_stage: str, - all_current_variables: typing.List[str] -) -> None -``` - - - - - - - - - - - - - - -```python -class nemo_rl.utils.memory_tracker.MemoryTrackerDataPoint() -``` - - - - - - -**Bases:** `BaseModel` - - - - - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.utils.memory_tracker.MemoryTrackerDataPoint.get_snapshot_str() -> str -``` - - - - - - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/native_checkpoint.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/native_checkpoint.mdx deleted file mode 100644 index 073652c..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/native_checkpoint.mdx +++ /dev/null @@ -1,351 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/utils/native_checkpoint -title: nemo_rl.utils.native_checkpoint ---- - -Checkpoint management utilities for HF models. - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`ModelState`](#nemo_rl-utils-native_checkpoint-ModelState) | Helper class for tracking model state in distributed checkpointing. | -| [`OptimizerState`](#nemo_rl-utils-native_checkpoint-OptimizerState) | Helper class for tracking optimizer state in distributed checkpointing. | - -### Functions - -| Name | Description | -|------|-------------| -| [`convert_dcp_to_hf`](#nemo_rl-utils-native_checkpoint-convert_dcp_to_hf) | Convert a Torch DCP checkpoint to a Hugging Face checkpoint. | -| [`load_checkpoint`](#nemo_rl-utils-native_checkpoint-load_checkpoint) | Load a model weights and optionally optimizer state. | -| [`save_checkpoint`](#nemo_rl-utils-native_checkpoint-save_checkpoint) | Save a checkpoint of the model and optionally optimizer state. | - -### API - - - - - -```python -class nemo_rl.utils.native_checkpoint.ModelState( - model: torch.nn.Module -) -``` - - - - - - -**Bases:** `Stateful` - -Helper class for tracking model state in distributed checkpointing. - -This class is compliant with the Stateful protocol, allowing DCP to automatically -call state_dict/load_state_dict as needed in the dcp.save/load APIs. - -**Parameters:** - - -The PyTorch model to track. - - - - - - - -```python -nemo_rl.utils.native_checkpoint.ModelState.load_state_dict( - state_dict: dict[str, typing.Any] -) -> None -``` - - - - - - -Load the state dictionary into the model. - -**Parameters:** - - -State dictionary to load. - - - - - - - - -```python -nemo_rl.utils.native_checkpoint.ModelState.state_dict() -> dict[str, typing.Any] -``` - - - - - - -Get the model's state dictionary. - -**Returns:** `dict[str, Any]` - -Dictionary containing the model's state dict with CPU offloading enabled. - - - - - - - - - -```python -class nemo_rl.utils.native_checkpoint.OptimizerState( - model: torch.nn.Module, - optimizer: torch.optim.Optimizer, - scheduler: typing.Optional[typing.Any] = None -) -``` - - - - - - -**Bases:** `Stateful` - -Helper class for tracking optimizer state in distributed checkpointing. - -This class is compliant with the Stateful protocol, allowing DCP to automatically -call state_dict/load_state_dict as needed in the dcp.save/load APIs. - -**Parameters:** - - -The PyTorch model associated with the optimizer. - - - -The optimizer to track. - - - -Optional learning rate scheduler. - - - - - - - -```python -nemo_rl.utils.native_checkpoint.OptimizerState.load_state_dict( - state_dict: dict[str, typing.Any] -) -> None -``` - - - - - - -Load the state dictionaries into the optimizer and scheduler. - -**Parameters:** - - -State dictionary containing optimizer and scheduler states to load. - - - - - - - - -```python -nemo_rl.utils.native_checkpoint.OptimizerState.state_dict() -> dict[str, typing.Any] -``` - - - - - - -Get the optimizer and scheduler state dictionaries. - -**Returns:** `dict[str, Any]` - -Dictionary containing the optimizer and scheduler state dicts with CPU offloading enabled. - - - - - - - - - -```python -nemo_rl.utils.native_checkpoint.convert_dcp_to_hf( - dcp_ckpt_path: str, - hf_ckpt_path: str, - model_name_or_path: str, - tokenizer_name_or_path: str, - overwrite: bool = False, - hf_overrides: typing.Optional[dict[str, typing.Any]] = {} -) -> str -``` - - - - - - -Convert a Torch DCP checkpoint to a Hugging Face checkpoint. - -This is not an optimized utility. If checkpoint is too large, consider saving DCP during training -and using this utility to convert to HF format. - -**Parameters:** - - -Path to DCP checkpoint - - - -Path to save HF checkpoint - - - -Model name or path for config - - - -Tokenizer name or path. - Defaults to model_name_or_path if None. - - - -Whether to overwrite existing checkpoint. Defaults to False. - - -**Returns:** `str` - -Path to the saved HF checkpoint - -**Raises:** - -- `FileExistsError`: If HF checkpoint already exists and overwrite is False - - - - - - - - -```python -nemo_rl.utils.native_checkpoint.load_checkpoint( - model: torch.nn.Module, - weights_path: str, - optimizer: typing.Optional[torch.optim.Optimizer] = None, - scheduler: typing.Optional[typing.Any] = None, - optimizer_path: typing.Optional[str] = None -) -> None -``` - - - - - - -Load a model weights and optionally optimizer state. - -**Parameters:** - - -The PyTorch model whose weights to update - - - -Path to load model weights from - - - -Optional optimizer to load state into - - - -Optional scheduler to load state into - - - -Path to load optimizer state from (required if optimizer provided) - - - - - - - - - -```python -nemo_rl.utils.native_checkpoint.save_checkpoint( - model: torch.nn.Module, - weights_path: str, - optimizer: typing.Optional[torch.optim.Optimizer] = None, - scheduler: typing.Optional[typing.Any] = None, - optimizer_path: typing.Optional[str] = None, - tokenizer: typing.Optional[typing.Any] = None, - tokenizer_path: typing.Optional[str] = None -) -> None -``` - - - - - - -Save a checkpoint of the model and optionally optimizer state. - -**Parameters:** - - -The PyTorch model to save - - - -Path to save model weights - - - -Optional optimizer to save - - - -Optional scheduler to save - - - -Path to save optimizer state (required if optimizer provided) - - - -Optional tokenizer to save - - - -Path to save tokenizer state (required if tokenizer provided) - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/nsys.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/nsys.mdx deleted file mode 100644 index 1c32117..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/nsys.mdx +++ /dev/null @@ -1,138 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/utils/nsys -title: nemo_rl.utils.nsys ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`ProfilablePolicy`](#nemo_rl-utils-nsys-ProfilablePolicy) | - | - -### Functions - -| Name | Description | -|------|-------------| -| [`maybe_gpu_profile_step`](#nemo_rl-utils-nsys-maybe_gpu_profile_step) | - | -| [`wrap_with_nvtx_name`](#nemo_rl-utils-nsys-wrap_with_nvtx_name) | A decorator to wrap a function with an NVTX range with the given name. | - -### Data - -[`NRL_NSYS_PROFILE_STEP_RANGE`](#nemo_rl-utils-nsys-NRL_NSYS_PROFILE_STEP_RANGE) - -[`NRL_NSYS_WORKER_PATTERNS`](#nemo_rl-utils-nsys-NRL_NSYS_WORKER_PATTERNS) - -### API - - - - - -```python -class nemo_rl.utils.nsys.ProfilablePolicy() -``` - - - - - - -Protocol - - - - - -```python -nemo_rl.utils.nsys.ProfilablePolicy.start_gpu_profiling() -> None -``` - - - - - - - - - - - - -```python -nemo_rl.utils.nsys.ProfilablePolicy.stop_gpu_profiling() -> None -``` - - - - - - - - - - - - - - -```python -nemo_rl.utils.nsys.maybe_gpu_profile_step( - policy: nemo_rl.utils.nsys.ProfilablePolicy, - step: int -) -``` - - - - - - - - - - - - - -```python -nemo_rl.utils.nsys.wrap_with_nvtx_name( - name: str -) -``` - - - - - - -A decorator to wrap a function with an NVTX range with the given name. - - - - - - - - -```python -nemo_rl.utils.nsys.NRL_NSYS_PROFILE_STEP_RANGE = os.environ.get('NRL_NSYS_PROFILE_STEP_RANGE', '') -``` - - - - - - - - - -```python -nemo_rl.utils.nsys.NRL_NSYS_WORKER_PATTERNS = os.environ.get('NRL_NSYS_WORKER_PATTERNS', '') -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/nvml.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/nvml.mdx deleted file mode 100644 index a8ada45..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/nvml.mdx +++ /dev/null @@ -1,100 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/utils/nvml -title: nemo_rl.utils.nvml ---- - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`device_id_to_physical_device_id`](#nemo_rl-utils-nvml-device_id_to_physical_device_id) | Convert a logical device ID to a physical device ID considering CUDA_VISIBLE_DEVICES. | -| [`get_device_uuid`](#nemo_rl-utils-nvml-get_device_uuid) | Get the UUID of a CUDA device using NVML. | -| [`get_free_memory_bytes`](#nemo_rl-utils-nvml-get_free_memory_bytes) | Get the free memory of a CUDA device in bytes using NVML. | -| [`nvml_context`](#nemo_rl-utils-nvml-nvml_context) | Context manager for NVML initialization and shutdown. | - -### API - - - - - -```python -nemo_rl.utils.nvml.device_id_to_physical_device_id( - device_id: int -) -> int -``` - - - - - - -Convert a logical device ID to a physical device ID considering CUDA_VISIBLE_DEVICES. - - - - - - - - -```python -nemo_rl.utils.nvml.get_device_uuid( - device_idx: int -) -> str -``` - - - - - - -Get the UUID of a CUDA device using NVML. - - - - - - - - -```python -nemo_rl.utils.nvml.get_free_memory_bytes( - device_idx: int -) -> float -``` - - - - - - -Get the free memory of a CUDA device in bytes using NVML. - - - - - - - - -```python -nemo_rl.utils.nvml.nvml_context() -> typing.Generator[None, None, None] -``` - - - - - - -Context manager for NVML initialization and shutdown. - -**Raises:** - -- `RuntimeError`: If NVML initialization fails - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/packed_tensor.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/packed_tensor.mdx deleted file mode 100644 index 786e38c..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/packed_tensor.mdx +++ /dev/null @@ -1,140 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/utils/packed_tensor -title: nemo_rl.utils.packed_tensor ---- - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`get_num_buffers`](#nemo_rl-utils-packed_tensor-get_num_buffers) | - | -| [`get_target_packed_tensor_size`](#nemo_rl-utils-packed_tensor-get_target_packed_tensor_size) | - | -| [`packed_broadcast_consumer`](#nemo_rl-utils-packed_tensor-packed_broadcast_consumer) | Consume a packed tensor and unpack it into a list of tensors. | -| [`packed_broadcast_producer`](#nemo_rl-utils-packed_tensor-packed_broadcast_producer) | Broadcast a list of tensors in a packed manner. | - -### API - - - - - -```python -nemo_rl.utils.packed_tensor.get_num_buffers() -``` - - - - - - - - - - - - - -```python -nemo_rl.utils.packed_tensor.get_target_packed_tensor_size() -``` - - - - - - - - - - - - - -```python -nemo_rl.utils.packed_tensor.packed_broadcast_consumer( - iterator, - group, - src, - post_unpack_func -) -``` - - - - - - -Consume a packed tensor and unpack it into a list of tensors. - -**Parameters:** - - -iterator of model parameters. Returns a tuple of (name, tensor) - - - -process group (vllm PyNcclCommunicator) - - - -source rank (0 in current implementation) - - - -function to apply to each tensor after unpacking - - -**Returns:** - -None - - - - - - - - -```python -nemo_rl.utils.packed_tensor.packed_broadcast_producer( - iterator, - group, - src, - post_iter_func -) -``` - - - - - - -Broadcast a list of tensors in a packed manner. - -**Parameters:** - - -iterator of model parameters. Returns a tuple of (name, tensor) - - - -process group (vllm PyNcclCommunicator) - - - -source rank (0 in current implementation) - - - -function to apply to each tensor before packing, should return a tensor - - -**Returns:** - -None - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/prefetch_venvs.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/prefetch_venvs.mdx deleted file mode 100644 index b0fb043..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/prefetch_venvs.mdx +++ /dev/null @@ -1,108 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/utils/prefetch_venvs -title: nemo_rl.utils.prefetch_venvs ---- - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`create_frozen_environment_symlinks`](#nemo_rl-utils-prefetch_venvs-create_frozen_environment_symlinks) | Create python-{ClassName} wrapper scripts in /usr/local/bin for frozen environment support. | -| [`prefetch_venvs`](#nemo_rl-utils-prefetch_venvs-prefetch_venvs) | Prefetch all virtual environments that will be used by workers. | - -### Data - -[`args`](#nemo_rl-utils-prefetch_venvs-args) - -[`parser`](#nemo_rl-utils-prefetch_venvs-parser) - -### API - - - - - -```python -nemo_rl.utils.prefetch_venvs.create_frozen_environment_symlinks( - venv_configs -) -``` - - - - - - -Create python-{ClassName} wrapper scripts in /usr/local/bin for frozen environment support. - -Only runs in container (when NRL_CONTAINER=1 is set). - -**Parameters:** - - -Dictionary mapping py_executable to list of actor FQNs - - - - - - - - - -```python -nemo_rl.utils.prefetch_venvs.prefetch_venvs( - filters = None, - negative_filters = None -) -``` - - - - - - -Prefetch all virtual environments that will be used by workers. - -**Parameters:** - - -List of strings to match against actor FQNs. If provided, only - actors whose FQN contains at least one of the filter strings will - be prefetched. If None, all venvs are prefetched. - - - -List of strings to exclude from prefetching. Actors whose - FQN contains any of these strings will be skipped. - - - - - - - - - -```python -nemo_rl.utils.prefetch_venvs.args = parser.parse_args() -``` - - - - - - - - - -```python -nemo_rl.utils.prefetch_venvs.parser = argparse.ArgumentParser(description='Prefetch virtual environments for Ray actor... -``` - - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/timer.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/timer.mdx deleted file mode 100644 index 13c7047..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/timer.mdx +++ /dev/null @@ -1,441 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/utils/timer -title: nemo_rl.utils.timer ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`TimeoutChecker`](#nemo_rl-utils-timer-TimeoutChecker) | - | -| [`Timer`](#nemo_rl-utils-timer-Timer) | A utility for timing code execution. | - -### Functions - -| Name | Description | -|------|-------------| -| [`convert_to_seconds`](#nemo_rl-utils-timer-convert_to_seconds) | Converts a time string in the format 'DD:HH:MM:SS' to total seconds. | - -### API - - - - - -```python -class nemo_rl.utils.timer.TimeoutChecker( - timeout: typing.Optional[str] = '00:03:45:00', - fit_last_save_time: bool = False -) -``` - - - - - - - - - - - - - - - - - - - - - -```python -nemo_rl.utils.timer.TimeoutChecker.check_save() -``` - - - - - - - - - - - - -```python -nemo_rl.utils.timer.TimeoutChecker.mark_iteration() -``` - - - - - - - - - - - - -```python -nemo_rl.utils.timer.TimeoutChecker.start_iterations() -``` - - - - - - - - - - - - - - -```python -class nemo_rl.utils.timer.Timer() -``` - - - - - - -A utility for timing code execution. - -Supports two usage patterns: -1. Explicit start/stop: timer.start("label"), timer.stop("label") -2. Context manager: with timer.time("label"): ... - -The timer keeps track of multiple timing measurements for each label, -and supports different reductions on these measurements (mean, median, -min, max, std dev). - -Example usage: - - -```python -timer = Timer() - -# Method 1: start/stop -timer.start("load_data") -data = load_data() -timer.stop("load_data") - -# Method 2: context manager -with timer.time("model_forward"): - model_outputs = model(inputs) - -# Multiple timing measurements for the same operation -for batch in dataloader: - with timer.time("model_forward_multiple"): - outputs = model(batch) - -# Get all times for one label -model_forward_times = timer.get_elapsed("model_forward_multiple") - -# Get reductions for one label -mean_forward_time = timer.reduce("model_forward_multiple") -max_forward_time = timer.reduce("model_forward_multiple", "max") -``` - - - - - - - - - - - - - - - - -```python -nemo_rl.utils.timer.Timer.get_elapsed( - label: str -) -> list[float] -``` - - - - - - -Get all elapsed time measurements for a specific label. - -**Parameters:** - - -The timing label to get elapsed times for - - -**Returns:** `list[float]` - -A list of all elapsed time measurements in seconds - -**Raises:** - -- `KeyError`: If the label doesn't exist - - - - - - - -```python -nemo_rl.utils.timer.Timer.get_latest_elapsed( - label: str -) -> float -``` - - - - - - -Get the most recent elapsed time measurement for a specific label. - -**Parameters:** - - -The timing label to get the latest elapsed time for - - -**Returns:** `float` - -The most recent elapsed time measurement in seconds - -**Raises:** - -- `KeyError`: If the label doesn't exist -- `IndexError`: If the label exists but has no measurements - - - - - - - -```python -nemo_rl.utils.timer.Timer.get_timing_metrics( - reduction_op: typing.Union[str, dict[str, str]] = 'mean' -) -> dict[str, float | list[float]] -``` - - - - - - -Get all timing measurements with optional reduction. - -**Parameters:** - - -Either a string specifying a reduction operation to apply to all labels, - or a dictionary mapping specific labels to reduction operations. - Valid reduction operations are: "mean", "median", "min", "max", "std", "sum", "count". - If a label is not in the dictionary, no reduction is applied and all measurements are returned. - - -**Returns:** `dict[str, float | list[float]]` - -A dictionary mapping labels to either: - -**Raises:** - -- `ValueError`: If an invalid reduction operation is provided - - - - - - - -```python -nemo_rl.utils.timer.Timer.reduce( - label: str, - operation: str = 'mean' -) -> float -``` - - - - - - -Apply a reduction function to timing measurements for the specified label. - -**Parameters:** - - -The timing label to get reduction for - - - -The type of reduction to apply. Valid options are: -- "mean": Average time (default) -- "median": Median time -- "min": Minimum time -- "max": Maximum time -- "std": Standard deviation -- "sum": Total time -- "count": Number of measurements - - -**Returns:** `float` - -A single float with the reduction result - -**Raises:** - -- `KeyError`: If the label doesn't exist -- `ValueError`: If an invalid operation is provided - - - - - - - -```python -nemo_rl.utils.timer.Timer.reset( - label: typing.Optional[str] = None -) -> None -``` - - - - - - -Reset timings for the specified label or all labels. - -**Parameters:** - - -Optional label to reset. If None, resets all timers. - - - - - - - - -```python -nemo_rl.utils.timer.Timer.start( - label: str -) -> None -``` - - - - - - -Start timing for the given label. - - - - - - - -```python -nemo_rl.utils.timer.Timer.stop( - label: str -) -> float -``` - - - - - - -Stop timing for the given label and return the elapsed time. - -**Parameters:** - - -The label to stop timing for - - -**Returns:** `float` - -The elapsed time in seconds - -**Raises:** - -- `ValueError`: If the timer for the given label is not running - - - - - - - -```python -nemo_rl.utils.timer.Timer.time( - label: str -) -> typing.Generator[None, None, None] -``` - - - - - - -Context manager for timing a block of code. - -**Parameters:** - - -The label to use for this timing - - - - - - - - - - -```python -nemo_rl.utils.timer.convert_to_seconds( - time_string: str -) -> int -``` - - - - - - -Converts a time string in the format 'DD:HH:MM:SS' to total seconds. - -**Parameters:** - - -Time duration string, e.g., '00:03:45:00'. - - -**Returns:** `int` - -Total time in seconds. - - - diff --git a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/venvs.mdx b/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/venvs.mdx deleted file mode 100644 index f2e6973..0000000 --- a/fern/static/nemo-rl-docs/nemo-rl/nemo_rl/utils/venvs.mdx +++ /dev/null @@ -1,177 +0,0 @@ ---- -layout: overview -slug: nemo-rl/nemo_rl/utils/venvs -title: nemo_rl.utils.venvs ---- - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`_env_builder`](#nemo_rl-utils-venvs-_env_builder) | - | -| [`create_local_venv`](#nemo_rl-utils-venvs-create_local_venv) | Create a virtual environment using uv and execute a command within it. | -| [`create_local_venv_on_each_node`](#nemo_rl-utils-venvs-create_local_venv_on_each_node) | Create a virtual environment on each Ray node. | - -### Data - -[`DEFAULT_VENV_DIR`](#nemo_rl-utils-venvs-DEFAULT_VENV_DIR) - -[`dir_path`](#nemo_rl-utils-venvs-dir_path) - -[`git_root`](#nemo_rl-utils-venvs-git_root) - -[`logger`](#nemo_rl-utils-venvs-logger) - -### API - - - - - -```python -nemo_rl.utils.venvs._env_builder( - py_executable: str, - venv_name: str, - node_idx: int, - force_rebuild: bool = False -) -``` - - - - - - - - - - - - - -```python -nemo_rl.utils.venvs.create_local_venv( - py_executable: str, - venv_name: str, - force_rebuild: bool = False -) -> str -``` - - - - - - -Create a virtual environment using uv and execute a command within it. - -The output can be used as a py_executable for a Ray worker assuming the worker -nodes also have access to the same file system as the head node. - -This function is cached to avoid multiple calls to uv to create the same venv, -which avoids duplicate logging. - -**Parameters:** - - -Command to run with the virtual environment (e.g., "uv.sh run --locked") - - - -Name of the virtual environment (e.g., "foobar.Worker") - - - -If True, force rebuild the venv even if it already exists - - -**Returns:** `str` - -Path to the python executable in the created virtual environment - - - - - - - - -```python -nemo_rl.utils.venvs.create_local_venv_on_each_node( - py_executable: str, - venv_name: str -) -``` - - - - - - -Create a virtual environment on each Ray node. - -**Parameters:** - - -Command to run with the virtual environment - - - -Name of the virtual environment - - -**Returns:** - -Path to the python executable in the created virtual environment - - - - - - - - -```python -nemo_rl.utils.venvs.DEFAULT_VENV_DIR = os.path.join(git_root, 'venvs') -``` - - - - - - - - - -```python -nemo_rl.utils.venvs.dir_path = os.path.dirname(os.path.abspath(__file__)) -``` - - - - - - - - - -```python -nemo_rl.utils.venvs.git_root = os.path.abspath(os.path.join(dir_path, '../..')) -``` - - - - - - - - - -```python -nemo_rl.utils.venvs.logger = logging.getLogger(__name__) -``` - - - - diff --git a/fern/static/ttl-docs/_navigation.yml b/fern/static/ttl-docs/_navigation.yml deleted file mode 100644 index e4a8b13..0000000 --- a/fern/static/ttl-docs/_navigation.yml +++ /dev/null @@ -1,149 +0,0 @@ -# AUTO-GENERATED by `fern docs md generate` — DO NOT EDIT -- type: section - title: _mlir_libs - slug: ttl/ttl/_mlir_libs - children: - - type: section - title: _site_initialize_1 - slug: ttl/ttl/_mlir_libs/_site_initialize_1 - children: - - type: page - title: _site_initialize_1 - slug: ttl/ttl/_mlir_libs/_site_initialize_1 - pageId: ttl/ttl/_mlir_libs/_site_initialize_1.mdx -- type: section - title: _src - slug: ttl/ttl/_src - children: - - type: section - title: auto_profile - slug: ttl/ttl/_src/auto_profile - children: - - type: page - title: auto_profile - slug: ttl/ttl/_src/auto_profile - pageId: ttl/ttl/_src/auto_profile.mdx - - type: section - title: tensor_registry - slug: ttl/ttl/_src/tensor_registry - children: - - type: page - title: tensor_registry - slug: ttl/ttl/_src/tensor_registry - pageId: ttl/ttl/_src/tensor_registry.mdx - - type: section - title: ttl_ast - slug: ttl/ttl/_src/ttl_ast - children: - - type: page - title: ttl_ast - slug: ttl/ttl/_src/ttl_ast - pageId: ttl/ttl/_src/ttl_ast.mdx -- type: section - title: circular_buffer - slug: ttl/ttl/circular_buffer - children: - - type: page - title: circular_buffer - slug: ttl/ttl/circular_buffer - pageId: ttl/ttl/circular_buffer.mdx -- type: section - title: constants - slug: ttl/ttl/constants - children: - - type: page - title: constants - slug: ttl/ttl/constants - pageId: ttl/ttl/constants.mdx -- type: section - title: diagnostics - slug: ttl/ttl/diagnostics - children: - - type: page - title: diagnostics - slug: ttl/ttl/diagnostics - pageId: ttl/ttl/diagnostics.mdx -- type: section - title: dialects - slug: ttl/ttl/dialects - children: - - type: section - title: _ods_common - slug: ttl/ttl/dialects/_ods_common - children: - - type: page - title: _ods_common - slug: ttl/ttl/dialects/_ods_common - pageId: ttl/ttl/dialects/_ods_common.mdx - - type: section - title: ttl - slug: ttl/ttl/dialects/ttl - children: - - type: page - title: ttl - slug: ttl/ttl/dialects/ttl - pageId: ttl/ttl/dialects/ttl.mdx -- type: section - title: dtype_utils - slug: ttl/ttl/dtype_utils - children: - - type: page - title: dtype_utils - slug: ttl/ttl/dtype_utils - pageId: ttl/ttl/dtype_utils.mdx -- type: section - title: kernel_runner - slug: ttl/ttl/kernel_runner - children: - - type: page - title: kernel_runner - slug: ttl/ttl/kernel_runner - pageId: ttl/ttl/kernel_runner.mdx -- type: section - title: layouts - slug: ttl/ttl/layouts - children: - - type: page - title: layouts - slug: ttl/ttl/layouts - pageId: ttl/ttl/layouts.mdx -- type: section - title: operators - slug: ttl/ttl/operators - children: - - type: page - title: operators - slug: ttl/ttl/operators - pageId: ttl/ttl/operators.mdx -- type: section - title: ttl - slug: ttl/ttl/ttl - children: - - type: page - title: ttl - slug: ttl/ttl/ttl - pageId: ttl/ttl/ttl.mdx -- type: section - title: ttl_api - slug: ttl/ttl/ttl_api - children: - - type: page - title: ttl_api - slug: ttl/ttl/ttl_api - pageId: ttl/ttl/ttl_api.mdx -- type: section - title: ttl_math - slug: ttl/ttl/ttl_math - children: - - type: page - title: ttl_math - slug: ttl/ttl/ttl_math - pageId: ttl/ttl/ttl_math.mdx -- type: section - title: ttl_utils - slug: ttl/ttl/ttl_utils - children: - - type: page - title: ttl_utils - slug: ttl/ttl/ttl_utils - pageId: ttl/ttl/ttl_utils.mdx diff --git a/fern/static/ttl-docs/ttl/ttl.mdx b/fern/static/ttl-docs/ttl/ttl.mdx deleted file mode 100644 index 30235a7..0000000 --- a/fern/static/ttl-docs/ttl/ttl.mdx +++ /dev/null @@ -1,60 +0,0 @@ ---- -layout: overview -slug: ttl/ttl -title: ttl ---- - -## Subpackages - -- **[`ttl._mlir_libs`](/ttl/ttl/_mlir_libs)** -- **[`ttl._src`](/ttl/ttl/_src)** -- **[`ttl.dialects`](/ttl/ttl/dialects)** - -## Submodules - -- **[`ttl.circular_buffer`](/ttl/ttl/circular_buffer)** -- **[`ttl.constants`](/ttl/ttl/constants)** -- **[`ttl.diagnostics`](/ttl/ttl/diagnostics)** -- **[`ttl.dtype_utils`](/ttl/ttl/dtype_utils)** -- **[`ttl.ir`](/ttl/ttl/ir)** -- **[`ttl.kernel_runner`](/ttl/ttl/kernel_runner)** -- **[`ttl.layouts`](/ttl/ttl/layouts)** -- **[`ttl.operators`](/ttl/ttl/operators)** -- **[`ttl.ttl`](/ttl/ttl/ttl)** -- **[`ttl.ttl_api`](/ttl/ttl/ttl_api)** -- **[`ttl.ttl_math`](/ttl/ttl/ttl_math)** -- **[`ttl.ttl_utils`](/ttl/ttl/ttl_utils)** - -## Package Contents - -### Data - -[`__all__`](#ttl-__all__) - -[`__version__`](#ttl-__version__) - -### API - - - - - -```python -ttl.__all__ = ['kernel', 'compute', 'datamovement', 'Program', 'CircularBuffer', 'TensorBlock'... -``` - - - - - - - - - -```python -ttl.__version__ = '0.1.0' -``` - - - - diff --git a/fern/static/ttl-docs/ttl/ttl/_mlir_libs.mdx b/fern/static/ttl-docs/ttl/ttl/_mlir_libs.mdx deleted file mode 100644 index 7dd8cd3..0000000 --- a/fern/static/ttl-docs/ttl/ttl/_mlir_libs.mdx +++ /dev/null @@ -1,9 +0,0 @@ ---- -layout: overview -slug: ttl/ttl/_mlir_libs -title: ttl._mlir_libs ---- - -## Submodules - -- **[`ttl._mlir_libs._site_initialize_1`](/ttl/ttl/_mlir_libs/_site_initialize_1)** diff --git a/fern/static/ttl-docs/ttl/ttl/_mlir_libs/_site_initialize_1.mdx b/fern/static/ttl-docs/ttl/ttl/_mlir_libs/_site_initialize_1.mdx deleted file mode 100644 index 414f0c4..0000000 --- a/fern/static/ttl-docs/ttl/ttl/_mlir_libs/_site_initialize_1.mdx +++ /dev/null @@ -1,35 +0,0 @@ ---- -layout: overview -slug: ttl/ttl/_mlir_libs/_site_initialize_1 -title: ttl._mlir_libs._site_initialize_1 ---- - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`register_dialects`](#ttl-_mlir_libs-_site_initialize_1-register_dialects) | Called by MLIR site initialization to add TTL dialects to the registry. | - -### API - - - - - -```python -ttl._mlir_libs._site_initialize_1.register_dialects( - registry -) -``` - - - - - - -Called by MLIR site initialization to add TTL dialects to the registry. - - - diff --git a/fern/static/ttl-docs/ttl/ttl/_src.mdx b/fern/static/ttl-docs/ttl/ttl/_src.mdx deleted file mode 100644 index 20050a7..0000000 --- a/fern/static/ttl-docs/ttl/ttl/_src.mdx +++ /dev/null @@ -1,11 +0,0 @@ ---- -layout: overview -slug: ttl/ttl/_src -title: ttl._src ---- - -## Submodules - -- **[`ttl._src.auto_profile`](/ttl/ttl/_src/auto_profile)** -- **[`ttl._src.tensor_registry`](/ttl/ttl/_src/tensor_registry)** -- **[`ttl._src.ttl_ast`](/ttl/ttl/_src/ttl_ast)** diff --git a/fern/static/ttl-docs/ttl/ttl/_src/auto_profile.mdx b/fern/static/ttl-docs/ttl/ttl/_src/auto_profile.mdx deleted file mode 100644 index 6e2407b..0000000 --- a/fern/static/ttl-docs/ttl/ttl/_src/auto_profile.mdx +++ /dev/null @@ -1,479 +0,0 @@ ---- -layout: overview -slug: ttl/ttl/_src/auto_profile -title: ttl._src.auto_profile ---- - -Auto-profiling infrastructure for tt-lang kernels. - -Enabled via TTLANG_AUTO_PROFILE=1 environment variable. -Automatically instruments every operation with signposts and generates -a visual profile report showing cycle counts per source line. - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`Colors`](#ttl-_src-auto_profile-Colors) | ANSI color codes for terminal output. | -| [`ProfileResult`](#ttl-_src-auto_profile-ProfileResult) | Represents profiling results for a single signpost. | -| [`SourceLineMapper`](#ttl-_src-auto_profile-SourceLineMapper) | Maps signpost markers back to source code lines. | - -### Functions - -| Name | Description | -|------|-------------| -| [`build_cb_wait_to_dma_map`](#ttl-_src-auto_profile-build_cb_wait_to_dma_map) | Build mapping from cb_wait locations to DMA barrier locations. | -| [`build_dma_producer_to_cb_map`](#ttl-_src-auto_profile-build_dma_producer_to_cb_map) | Build mapping from DMA barrier locations to CB index. | -| [`generate_signpost_name`](#ttl-_src-auto_profile-generate_signpost_name) | Generate before/after signpost names for an operation. | -| [`get_line_mapper`](#ttl-_src-auto_profile-get_line_mapper) | Get the global line mapper instance. | -| [`is_auto_profile_enabled`](#ttl-_src-auto_profile-is_auto_profile_enabled) | Check if auto-profiling is enabled via environment variable. | -| [`load_cb_flow_graph`](#ttl-_src-auto_profile-load_cb_flow_graph) | Load CB flow graph JSON from same directory as CSV. | -| [`parse_device_profile_csv`](#ttl-_src-auto_profile-parse_device_profile_csv) | Parse the device profile CSV and extract signpost timing data. | -| [`parse_signpost_name`](#ttl-_src-auto_profile-parse_signpost_name) | Parse op name and implicit flag from signpost name. | -| [`print_profile_report`](#ttl-_src-auto_profile-print_profile_report) | Print a profile report organized by thread. | - -### Data - -[`_global_line_mapper`](#ttl-_src-auto_profile-_global_line_mapper) - -### API - - - - - -```python -class ttl._src.auto_profile.Colors() -``` - - - - - - -ANSI color codes for terminal output. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -ttl._src.auto_profile.Colors.cb_bg( - cb_index: int -) -> str -``` - - - - - - -classmethod - -Get background color for a CB index, or empty if out of range. - - - - - - - - - -```python -class ttl._src.auto_profile.ProfileResult( - signpost: str, - thread: str, - cycles: int, - lineno: int, - source: str -) -``` - - - - - - -Represents profiling results for a single signpost. - - - - - - - - - - -```python -class ttl._src.auto_profile.SourceLineMapper() -``` - - - - - - -Maps signpost markers back to source code lines. - - - - - - - - - - - - - - -```python -ttl._src.auto_profile.SourceLineMapper.get_line_info( - signpost_name: str -) -> typing.Optional[typing.Tuple[int, str]] -``` - - - - - - -Get line number and source for a signpost. - - - - - - - -```python -ttl._src.auto_profile.SourceLineMapper.register_signpost( - signpost_name: str, - lineno: int, - source: str -) -``` - - - - - - -Register a signpost with its source line information. - - - - - - - -```python -ttl._src.auto_profile.SourceLineMapper.set_source( - source_lines: typing.List[str] -) -``` - - - - - - -Set the source code lines for display. - - - - - - - - - -```python -ttl._src.auto_profile.build_cb_wait_to_dma_map( - cb_flow: typing.Optional[typing.Dict] -) -> typing.Dict[typing.Tuple[str, int], typing.Tuple[str, int, int]] -``` - - - - - - -Build mapping from cb_wait locations to DMA barrier locations. - -Only maps consumers waiting for DMA reads (data flowing into CB). -cb_wait ops waiting for compute output (where DMA is a write) are not mapped. - -**Returns:** `Dict[Tuple[str, int], Tuple[str, int, int]]` - -Dict mapping (kernel, line) of cb_wait -> (barrier_kernel, barrier_line, cb_index) - - - - - - - - -```python -ttl._src.auto_profile.build_dma_producer_to_cb_map( - cb_flow: typing.Optional[typing.Dict] -) -> typing.Dict[typing.Tuple[str, int], int] -``` - - - - - - -Build mapping from DMA barrier locations to CB index. - -**Returns:** `Dict[Tuple[str, int], int]` - -Dict mapping (kernel, line) of DMA read barrier -> cb_index - - - - - - - - -```python -ttl._src.auto_profile.generate_signpost_name( - operation: str, - lineno: int, - col: int -) -> typing.Tuple[str, str] -``` - - - - - - -Generate before/after signpost names for an operation. - -**Returns:** `Tuple[str, str]` - -Tuple of (before_name, after_name) - - - - - - - - -```python -ttl._src.auto_profile.get_line_mapper() -> ttl._src.auto_profile.SourceLineMapper -``` - - - - - - -Get the global line mapper instance. - - - - - - - - -```python -ttl._src.auto_profile.is_auto_profile_enabled() -> bool -``` - - - - - - -Check if auto-profiling is enabled via environment variable. - - - - - - - - -```python -ttl._src.auto_profile.load_cb_flow_graph( - csv_path: pathlib.Path -) -> typing.Optional[typing.Dict] -``` - - - - - - -Load CB flow graph JSON from same directory as CSV. - - - - - - - - -```python -ttl._src.auto_profile.parse_device_profile_csv( - csv_path: pathlib.Path, - line_mapper: ttl._src.auto_profile.SourceLineMapper -) -> typing.List[ttl._src.auto_profile.ProfileResult] -``` - - - - - - -Parse the device profile CSV and extract signpost timing data. - -**Parameters:** - - -Path to profile_log_device.csv - - - -Mapper to correlate signposts to source lines - - -**Returns:** `List[ProfileResult]` - -List of ProfileResult objects sorted by line number - - - - - - - - -```python -ttl._src.auto_profile.parse_signpost_name( - signpost: str -) -> typing.Tuple[typing.Optional[str], bool] -``` - - - - - - -Parse op name and implicit flag from signpost name. - -Returns (op_name, is_implicit) where op_name is None for line-only signposts. -Examples: - "line_52_before" -> (None, False) - "line_52_cb_wait_before" -> ("cb_wait", False) - "line_52_implicit_cb_pop_before" -> ("cb_pop", True) - - - - - - - - -```python -ttl._src.auto_profile.print_profile_report( - results: typing.List[ttl._src.auto_profile.ProfileResult], - all_source_lines: typing.Dict[str, typing.List[str]], - thread_to_kernel: typing.Dict[str, str], - line_mapper: typing.Optional[ttl._src.auto_profile.SourceLineMapper] = None, - cb_wait_to_dma: typing.Optional[typing.Dict[typing.Tuple[str, int], typing.Tuple[str, int, int]]] = None, - dma_producer_to_cb: typing.Optional[typing.Dict[typing.Tuple[str, int], int]] = None, - kernel_line_offsets: typing.Optional[typing.Dict[str, int]] = None -) -``` - - - - - - -Print a profile report organized by thread. - -Shows full source context with cycle annotations where available. -Each thread displays its corresponding kernel's source code. - -**Parameters:** - - -List of ProfileResult from CSV parsing - - - -Dict mapping kernel name to source lines - - - -Dict mapping RISC thread name to kernel name - - - -Optional SourceLineMapper with line offset info - - - -Optional mapping from (kernel, line) -> (dma_kernel, dma_line, cb_index) - - - -Optional mapping from (kernel, line) -> cb_index for DMA producers - - - -Optional mapping from kernel name to line offset - - - - - - - - - -```python -ttl._src.auto_profile._global_line_mapper = SourceLineMapper() -``` - - - - diff --git a/fern/static/ttl-docs/ttl/ttl/_src/tensor_registry.mdx b/fern/static/ttl-docs/ttl/ttl/_src/tensor_registry.mdx deleted file mode 100644 index 0c2cdf6..0000000 --- a/fern/static/ttl-docs/ttl/ttl/_src/tensor_registry.mdx +++ /dev/null @@ -1,169 +0,0 @@ ---- -layout: overview -slug: ttl/ttl/_src/tensor_registry -title: ttl._src.tensor_registry ---- - -Registry for tensor global names, used to track tensor parameter names. - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`get_tensor_global_index`](#ttl-_src-tensor_registry-get_tensor_global_index) | Get the global index for a tensor. | -| [`get_tensor_global_name`](#ttl-_src-tensor_registry-get_tensor_global_name) | Get the global name for a tensor, checking registry first then attribute. | -| [`get_tensor_source`](#ttl-_src-tensor_registry-get_tensor_source) | Get the source location where a tensor was assigned, if tracked. | -| [`register_tensor_name`](#ttl-_src-tensor_registry-register_tensor_name) | Register a global name and index for a tensor. | -| [`register_tensor_source`](#ttl-_src-tensor_registry-register_tensor_source) | Register the source location where a tensor variable was assigned. | - -### Data - -[`_tensor_index_registry`](#ttl-_src-tensor_registry-_tensor_index_registry) - -[`_tensor_name_registry`](#ttl-_src-tensor_registry-_tensor_name_registry) - -[`_tensor_source_registry`](#ttl-_src-tensor_registry-_tensor_source_registry) - -### API - - - - - -```python -ttl._src.tensor_registry.get_tensor_global_index( - tensor -) -> int -``` - - - - - - -Get the global index for a tensor. - - - - - - - - -```python -ttl._src.tensor_registry.get_tensor_global_name( - tensor -) -> str -``` - - - - - - -Get the global name for a tensor, checking registry first then attribute. - - - - - - - - -```python -ttl._src.tensor_registry.get_tensor_source( - tensor -) -> typing.Optional[typing.Tuple[str, int]] -``` - - - - - - -Get the source location where a tensor was assigned, if tracked. - - - - - - - - -```python -ttl._src.tensor_registry.register_tensor_name( - tensor, - name: str, - index: int = -1 -) -> None -``` - - - - - - -Register a global name and index for a tensor. - - - - - - - - -```python -ttl._src.tensor_registry.register_tensor_source( - tensor, - source_file: str, - line: int -) -> None -``` - - - - - - -Register the source location where a tensor variable was assigned. - - - - - - - - -```python -ttl._src.tensor_registry._tensor_index_registry: Dict[int, int] = {} -``` - - - - - - - - - -```python -ttl._src.tensor_registry._tensor_name_registry: Dict[int, str] = {} -``` - - - - - - - - - -```python -ttl._src.tensor_registry._tensor_source_registry: Dict[int, Tuple[str, int]] = {} -``` - - - - diff --git a/fern/static/ttl-docs/ttl/ttl/_src/ttl_ast.mdx b/fern/static/ttl-docs/ttl/ttl/_src/ttl_ast.mdx deleted file mode 100644 index cc33586..0000000 --- a/fern/static/ttl-docs/ttl/ttl/_src/ttl_ast.mdx +++ /dev/null @@ -1,731 +0,0 @@ ---- -layout: overview -slug: ttl/ttl/_src/ttl_ast -title: ttl._src.ttl_ast ---- - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`CompilerContext`](#ttl-_src-ttl_ast-CompilerContext) | Immutable compilation context for TTL kernels. | -| [`TTLGenericCompiler`](#ttl-_src-ttl_ast-TTLGenericCompiler) | Compiler that generates TTL dialect ops from Python AST. | - -### Functions - -| Name | Description | -|------|-------------| -| [`_build_tensor_type`](#ttl-_src-ttl_ast-_build_tensor_type) | Build MLIR tensor type for a ttnn tensor with TTNNLayoutAttr. | -| [`_get_annotation_name`](#ttl-_src-ttl_ast-_get_annotation_name) | Extract the type name from an annotation node. | -| [`_make_file_loc`](#ttl-_src-ttl_ast-_make_file_loc) | Create an MLIR file location from an AST node. | -| [`_raise_tensor_error`](#ttl-_src-ttl_ast-_raise_tensor_error) | Raise TTLangCompileError with tensor source location if available. | -| [`syntax`](#ttl-_src-ttl_ast-syntax) | - | - -### API - - - - - -```python -class ttl._src.ttl_ast.CompilerContext( - grid: typing.List[int], - memory_space: str, - tiled: bool -) -``` - - - - - - -Dataclass - -Immutable compilation context for TTL kernels. - - - - - - - - - - - - - - - - -```python -class ttl._src.ttl_ast.TTLGenericCompiler( - name, - kernel_type = None, - captures = {}, - args = (), - kwargs = {} -) -``` - - - - - - -**Bases:** `TTCompilerBase` - -Compiler that generates TTL dialect ops from Python AST. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler._build_index_or_range( - node -) -``` - - - - - - -Convert AST node to (start_value, is_range) tuple. - -For slice syntax (start:end), returns (start_value, True). -For index syntax (value), returns (value, False). - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler._close_final_signpost() -``` - - - - - - -Close the final signpost at the end of function body. - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler._emit_cb_from_capture( - cb -) -``` - - - - - - -Emit ttl.bind_cb for a captured CircularBuffer instance. - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler._emit_entry( - node -) -``` - - - - - - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler._emit_line_signpost_if_needed( - node -) -``` - - - - - - -Emit signposts at line boundaries for auto-profiling. - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler._emit_op_signposts( - op_name: str, - node, - op_fn, - implicit = False -) -``` - - - - - - -Emit signposts for CB operations with op name included. - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler._emit_signpost( - name: str -) -``` - - - - - - -Emit a signpost operation into the MLIR. - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler._get_cb_tensor_type( - cb_val, - node = None -) -``` - - - - - - -Extract the tensor type from a TTL CB type. - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler._is_ttl_math_access( - node -) -``` - - - - - - -Check if node is ttl.math.XXX access pattern. - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler._is_ttl_module_access( - node -) -``` - - - - - - -Check if node is ttl.XXX access pattern. - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler._loc_for_node( - node -) -``` - - - - - - -Return file location for node if debug_locations enabled, else name location. - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler._raise_error( - node, - message: str -) -``` - - - - - - -Raise a TTLangCompileError with source location from AST node. - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler._resolve_ttl_function( - node, - func_args, - kwargs -) -``` - - - - - - -Resolve and call a ttl.XXX or ttl.math.XXX function. - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler._to_index_value( - node -) -``` - - - - - - -Convert AST node to MLIR index Value. - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler._try_emit_auto_signposts( - node, - visit_fn -) -``` - - - - - - -Emit line-based signposts if auto-profiling is enabled. - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler.visit_Assign( - node -) -``` - - - - - - -Handle tuple unpacking for TTL functions like core(dims=2). - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler.visit_AsyncFunctionDef( - node -) -``` - - - - - - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler.visit_Attribute( - node, - func_args = [], - kwargs = {} -) -``` - - - - - - -Override to set location context and catch errors for method calls. - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler.visit_BinOp( - node -) -``` - - - - - - -Override to inject auto-profiling and provide better error messages. - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler.visit_Call( - node -) -``` - - - - - - -Override to set location context, catch errors, and inject auto-profiling. - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler.visit_Constant( - node -) -``` - - - - - - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler.visit_FunctionDef( - node -) -``` - - - - - - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler.visit_List( - node -) -``` - - - - - - -Parse a list of constants. Returns a Python list, not MLIR values. - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler.visit_Name( - node -) -``` - - - - - - -Override to check function globals for simple constants. - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler.visit_Subscript( - node -) -``` - - - - - - -Handle tensor[row, col] or tensor[r0:r1, c0:c1] indexing. - - - - - - - -```python -ttl._src.ttl_ast.TTLGenericCompiler.visit_With( - node -) -``` - - - - - - -Handle 'with' for CircularBuffer acquire/release. - -Acquire ops (wait/reserve) are generated left-to-right. -Release ops (pop/push) are generated in reverse order at scope end. - - - - - - - - - -```python -ttl._src.ttl_ast._build_tensor_type( - ctx, - tensor, - grid, - tiled, - memory_space -) -``` - - - - - - -Build MLIR tensor type for a ttnn tensor with TTNNLayoutAttr. - - - - - - - - -```python -ttl._src.ttl_ast._get_annotation_name( - annotation -) -``` - - - - - - -Extract the type name from an annotation node. - -Handles both simple names (CircularBuffer) and qualified names (ttl.CircularBuffer). -Returns the simple type name (e.g., 'CircularBuffer') in both cases. - - - - - - - - -```python -ttl._src.ttl_ast._make_file_loc( - ctx, - source_file: str, - node, - line_offset: int = 0 -) -> Location -``` - - - - - - -Create an MLIR file location from an AST node. - - - - - - - - -```python -ttl._src.ttl_ast._raise_tensor_error( - tensor, - message: str -) -``` - - - - - - -Raise TTLangCompileError with tensor source location if available. - - - - - - - - -```python -ttl._src.ttl_ast.syntax( - syntax_name -) -``` - - - - - - - - diff --git a/fern/static/ttl-docs/ttl/ttl/circular_buffer.mdx b/fern/static/ttl-docs/ttl/ttl/circular_buffer.mdx deleted file mode 100644 index 38c6508..0000000 --- a/fern/static/ttl-docs/ttl/ttl/circular_buffer.mdx +++ /dev/null @@ -1,283 +0,0 @@ ---- -layout: overview -slug: ttl/ttl/circular_buffer -title: ttl.circular_buffer ---- - -Circular buffer operations for inter-thread communication. - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`CircularBuffer`](#ttl-circular_buffer-CircularBuffer) | Circular buffer for inter-thread communication. | - -### Functions - -| Name | Description | -|------|-------------| -| [`_get_cb_tensor_type`](#ttl-circular_buffer-_get_cb_tensor_type) | Extract the tensor type from a TTL CB type. | -| [`_next_cb_index`](#ttl-circular_buffer-_next_cb_index) | Get next CB index and increment counter. | -| [`_reset_cb_counter`](#ttl-circular_buffer-_reset_cb_counter) | Reset the CB index counter. Called at kernel start. | -| [`get_cb_count`](#ttl-circular_buffer-get_cb_count) | Return number of CBs allocated so far. | -| [`make_circular_buffer_like`](#ttl-circular_buffer-make_circular_buffer_like) | Create a circular buffer with properties derived from a tensor. | - -### Data - -[`_cb_index_counter`](#ttl-circular_buffer-_cb_index_counter) - -### API - - - - - -```python -class ttl.circular_buffer.CircularBuffer( - tensor: typing.Any, - shape: typing.Tuple[int, int], - buffer_factor: int -) -``` - - - - - - -Circular buffer for inter-thread communication. - -Circular buffers provide producer-consumer synchronization between -compute and data movement threads. - -Can be instantiated via make_circular_buffer_like() in kernel body, -then captured by thread closures. Methods generate TTL ops during compilation. - - - - - - - - -```python -ttl.circular_buffer.CircularBuffer.pop( - ast_self: ttl.circular_buffer.CircularBuffer -) -> None -``` - - - - - - -Signal that data has been consumed (consumer release). - -Use in consumer threads after wait() to signal that data has been -consumed and space is available for producers. - - - - - - - -```python -ttl.circular_buffer.CircularBuffer.push( - ast_self: ttl.circular_buffer.CircularBuffer -) -> None -``` - - - - - - -Signal that data is ready in the circular buffer (producer release). - -Use in producer threads after reserve() to signal that data has been -written and is ready for consumers. - - - - - - - -```python -ttl.circular_buffer.CircularBuffer.reserve( - ast_self: ttl.circular_buffer.CircularBuffer -) -> ttl.ttl_api.TensorBlock -``` - - - - - - -Reserve space in the circular buffer (producer acquire). - -Use in producer threads to acquire space for writing. Must be followed -by push() to signal data is ready. - -**Returns:** `TensorBlock` - -The reserved space with CB association. - - - - - - - -```python -ttl.circular_buffer.CircularBuffer.wait( - ast_self: ttl.circular_buffer.CircularBuffer -) -> ttl.ttl_api.TensorBlock -``` - - - - - - -Wait for data from the circular buffer (consumer acquire). - -Use in consumer threads to acquire data. Must be followed by pop() -to signal consumption is complete. - -**Returns:** `TensorBlock` - -The acquired data with CB association. - - - - - - - - - -```python -ttl.circular_buffer._get_cb_tensor_type( - cb_val -) -``` - - - - - - -Extract the tensor type from a TTL CB type. - - - - - - - - -```python -ttl.circular_buffer._next_cb_index() -``` - - - - - - -Get next CB index and increment counter. - - - - - - - - -```python -ttl.circular_buffer._reset_cb_counter() -``` - - - - - - -Reset the CB index counter. Called at kernel start. - - - - - - - - -```python -ttl.circular_buffer.get_cb_count() -``` - - - - - - -Return number of CBs allocated so far. - - - - - - - - -```python -ttl.circular_buffer.make_circular_buffer_like( - tensor: typing.Any, - shape: typing.Tuple[int, int], - buffer_factor: int = 2 -) -> ttl.circular_buffer.CircularBuffer -``` - - - - - - -Create a circular buffer with properties derived from a tensor. - -**Parameters:** - - -Tensor that determines the CB's data type - - - -(rows, cols) in tiles for wait/reserve operations - - - -Capacity multiplier (default 2 for double-buffering) - - -**Returns:** `CircularBuffer` - -CircularBuffer for use in thread function closures - - - - - - - - -```python -ttl.circular_buffer._cb_index_counter = 0 -``` - - - - diff --git a/fern/static/ttl-docs/ttl/ttl/constants.mdx b/fern/static/ttl-docs/ttl/ttl/constants.mdx deleted file mode 100644 index 5aa63c2..0000000 --- a/fern/static/ttl-docs/ttl/ttl/constants.mdx +++ /dev/null @@ -1,41 +0,0 @@ ---- -layout: overview -slug: ttl/ttl/constants -title: ttl.constants ---- - -Constants used throughout the DSL. - -## Module Contents - -### Data - -[`DEFAULT_TILE_SIZE`](#ttl-constants-DEFAULT_TILE_SIZE) - -[`SUPPORTED_MEMORY_SPACES`](#ttl-constants-SUPPORTED_MEMORY_SPACES) - -### API - - - - - -```python -ttl.constants.DEFAULT_TILE_SIZE = 32 -``` - - - - - - - - - -```python -ttl.constants.SUPPORTED_MEMORY_SPACES = frozenset(['L1', 'DRAM']) -``` - - - - diff --git a/fern/static/ttl-docs/ttl/ttl/diagnostics.mdx b/fern/static/ttl-docs/ttl/ttl/diagnostics.mdx deleted file mode 100644 index 709211b..0000000 --- a/fern/static/ttl-docs/ttl/ttl/diagnostics.mdx +++ /dev/null @@ -1,466 +0,0 @@ ---- -layout: overview -slug: ttl/ttl/diagnostics -title: ttl.diagnostics ---- - -Diagnostic utilities for formatting compiler errors with source context. - -This module provides Rust/Swift-style error formatting that displays -source code snippets with ASCII arrows pointing to the error location. - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`SourceDiagnostic`](#ttl-diagnostics-SourceDiagnostic) | Format errors with source context and ASCII arrows. | -| [`TTLangCompileError`](#ttl-diagnostics-TTLangCompileError) | Exception for tt-lang compilation errors with source context. | - -### Functions - -| Name | Description | -|------|-------------| -| [`_extract_core_message`](#ttl-diagnostics-_extract_core_message) | Extract the core error message from MLIR diagnostic output. | -| [`_extract_note`](#ttl-diagnostics-_extract_note) | Extract any note from the MLIR error message. | -| [`_read_file_lines`](#ttl-diagnostics-_read_file_lines) | Read source lines from a file if it exists. | -| [`_verbose_errors_enabled`](#ttl-diagnostics-_verbose_errors_enabled) | Check if verbose MLIR error output is enabled. | -| [`extract_location_from_mlir_error`](#ttl-diagnostics-extract_location_from_mlir_error) | Extract source location from an MLIR error message. | -| [`find_variable_assignment`](#ttl-diagnostics-find_variable_assignment) | Find the line where a variable was assigned, searching backwards. | -| [`format_mlir_error`](#ttl-diagnostics-format_mlir_error) | Format an MLIR error with source context if location is available. | -| [`format_python_error`](#ttl-diagnostics-format_python_error) | Format a Python error with source context. | -| [`parse_mlir_location`](#ttl-diagnostics-parse_mlir_location) | Parse an MLIR location string to extract file, line, and column. | - -### API - - - - - -```python -class ttl.diagnostics.SourceDiagnostic( - source_lines: typing.List[str], - filename: str -) -``` - - - - - - -Format errors with source context and ASCII arrows. - -Produces error messages in the style of modern compilers (Rust, Swift): - - error: type mismatch in add operation - --> kernel.py:43:16 - | - 43 | result = l + r - | ^^^ expected bf16, got f32 - | - - - - - - -```python -ttl.diagnostics.SourceDiagnostic.format_error( - line: int, - col: int, - message: str, - label: str = 'error', - span_length: int = 1, - note: typing.Optional[str] = None -) -> str -``` - - - - - - -Format an error with source context. - -**Parameters:** - - -1-based line number - - - -1-based column number - - - -Main error message - - - -Error label (e.g., "error", "warning") - - - -Length of the underline (^^^) - - - -Optional additional note - - -**Returns:** `str` - -Formatted error string with source context - - - - - - - -```python -ttl.diagnostics.SourceDiagnostic.format_error_chain( - errors: typing.List[typing.Tuple[int, int, str, typing.Optional[str]]] -) -> str -``` - - - - - - -Format multiple related errors. - -**Parameters:** - - -List of (line, col, message, note) tuples - - -**Returns:** `str` - -Formatted error chain - - - - - - - - - -```python -class ttl.diagnostics.TTLangCompileError( - message: str, - source_file: typing.Optional[str] = None, - line: typing.Optional[int] = None, - col: typing.Optional[int] = None, - source_lines: typing.Optional[typing.List[str]] = None -) -``` - - - - - - -Exception - -**Bases:** `Exception` - -Exception for tt-lang compilation errors with source context. - -This exception carries enough information to produce pretty error messages -pointing to the exact source location where the error occurred. - - - - - - -```python -ttl.diagnostics.TTLangCompileError.format() -> str -``` - - - - - - -Format error with source context if available. - - - - - - - - - -```python -ttl.diagnostics._extract_core_message( - error_msg: str -) -> str -``` - - - - - - -Extract the core error message from MLIR diagnostic output. - -This extracts: "expects transfer handle to be synchronized with ttl.wait" - - - - - - - - -```python -ttl.diagnostics._extract_note( - error_msg: str -) -> typing.Optional[str] -``` - - - - - - -Extract any note from the MLIR error message. - - - - - - - - -```python -ttl.diagnostics._read_file_lines( - filepath: str -) -> typing.Optional[typing.List[str]] -``` - - - - - - -Read source lines from a file if it exists. - - - - - - - - -```python -ttl.diagnostics._verbose_errors_enabled() -> bool -``` - - - - - - -Check if verbose MLIR error output is enabled. - - - - - - - - -```python -ttl.diagnostics.extract_location_from_mlir_error( - error_msg: str -) -> typing.Optional[typing.Tuple[str, int, int]] -``` - - - - - - -Extract source location from an MLIR error message. - -**Parameters:** - - -Full MLIR error message - - -**Returns:** `Optional[Tuple[str, int, int]]` - -Tuple of (filename, line, col) or None if no location found - - - - - - - - -```python -ttl.diagnostics.find_variable_assignment( - source_lines: typing.List[str], - var_name: str, - before_line: int -) -> typing.Optional[int] -``` - - - - - - -Find the line where a variable was assigned, searching backwards. - -**Parameters:** - - -List of source lines (0-indexed) - - - -Variable name to search for - - - -Search backwards from this 1-based line number - - -**Returns:** `Optional[int]` - -1-based line number where assignment was found, or None - - - - - - - - -```python -ttl.diagnostics.format_mlir_error( - error_msg: str, - source_lines: typing.Optional[typing.List[str]] = None, - source_file: typing.Optional[str] = None -) -> str -``` - - - - - - -Format an MLIR error with source context if location is available. - -**Parameters:** - - -The MLIR error message - - - -Original Python source lines (optional, will read from file if needed) - - - -Source filename (optional, extracted from error if not provided) - - -**Returns:** `str` - -Formatted error message, with source context if available - - - - - - - - -```python -ttl.diagnostics.format_python_error( - error: Exception, - source_file: str, - line: int, - source_lines: typing.Optional[typing.List[str]] = None -) -> str -``` - - - - - - -Format a Python error with source context. - -**Parameters:** - - -The Python exception - - - -Source file path - - - -Line number in source file - - - -Source lines (will read from file if not provided) - - -**Returns:** `str` - -Formatted error message with source context - - - - - - - - -```python -ttl.diagnostics.parse_mlir_location( - loc_str: str -) -> typing.Optional[typing.Tuple[str, int, int]] -``` - - - - - - -Parse an MLIR location string to extract file, line, and column. - -MLIR locations can appear in several formats: -- loc("filename":line:col) -- loc("filename":line:col to :line:col) -- loc(#loc1) with #loc1 = loc("filename":line:col) - -**Parameters:** - - -MLIR location string - - -**Returns:** `Optional[Tuple[str, int, int]]` - -Tuple of (filename, line, col) or None if not parseable - - - diff --git a/fern/static/ttl-docs/ttl/ttl/dialects.mdx b/fern/static/ttl-docs/ttl/ttl/dialects.mdx deleted file mode 100644 index 2eb4b50..0000000 --- a/fern/static/ttl-docs/ttl/ttl/dialects.mdx +++ /dev/null @@ -1,12 +0,0 @@ ---- -layout: overview -slug: ttl/ttl/dialects -title: ttl.dialects ---- - -TTLang dialect modules. - -## Submodules - -- **[`ttl.dialects._ods_common`](/ttl/ttl/dialects/_ods_common)** -- **[`ttl.dialects.ttl`](/ttl/ttl/dialects/ttl)** diff --git a/fern/static/ttl-docs/ttl/ttl/dialects/_ods_common.mdx b/fern/static/ttl-docs/ttl/ttl/dialects/_ods_common.mdx deleted file mode 100644 index 7b2c71e..0000000 --- a/fern/static/ttl-docs/ttl/ttl/dialects/_ods_common.mdx +++ /dev/null @@ -1,39 +0,0 @@ ---- -layout: overview -slug: ttl/ttl/dialects/_ods_common -title: ttl.dialects._ods_common ---- - -## Module Contents - -### Data - -[`__all__`](#ttl-dialects-_ods_common-__all__) - -[`_cext`](#ttl-dialects-_ods_common-_cext) - -### API - - - - - -```python -ttl.dialects._ods_common.__all__ = ['_cext'] -``` - - - - - - - - - -```python -ttl.dialects._ods_common._cext = _upstream._cext -``` - - - - diff --git a/fern/static/ttl-docs/ttl/ttl/dialects/ttl.mdx b/fern/static/ttl-docs/ttl/ttl/dialects/ttl.mdx deleted file mode 100644 index d8c69a1..0000000 --- a/fern/static/ttl-docs/ttl/ttl/dialects/ttl.mdx +++ /dev/null @@ -1,81 +0,0 @@ ---- -layout: overview -slug: ttl/ttl/dialects/ttl -title: ttl.dialects.ttl ---- - -TTL (TT-Lang) dialect Python bindings. - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`ensure_dialects_registered`](#ttl-dialects-ttl-ensure_dialects_registered) | Ensure TTL dialect is registered with the given MLIR context. | - -### Data - -[`CircularBufferType`](#ttl-dialects-ttl-CircularBufferType) - -[`SliceAttr`](#ttl-dialects-ttl-SliceAttr) - -[`__all__`](#ttl-dialects-ttl-__all__) - -### API - - - - - -```python -ttl.dialects.ttl.ensure_dialects_registered( - ctx -) -``` - - - - - - -Ensure TTL dialect is registered with the given MLIR context. - - - - - - - - -```python -ttl.dialects.ttl.CircularBufferType = ir.CircularBufferType -``` - - - - - - - - - -```python -ttl.dialects.ttl.SliceAttr = ir.SliceAttr -``` - - - - - - - - - -```python -ttl.dialects.ttl.__all__ = [*[name for name in (globals().keys()) if not name.startswith('_')]] -``` - - - - diff --git a/fern/static/ttl-docs/ttl/ttl/dtype_utils.mdx b/fern/static/ttl-docs/ttl/ttl/dtype_utils.mdx deleted file mode 100644 index fce940c..0000000 --- a/fern/static/ttl-docs/ttl/ttl/dtype_utils.mdx +++ /dev/null @@ -1,212 +0,0 @@ ---- -layout: overview -slug: ttl/ttl/dtype_utils -title: ttl.dtype_utils ---- - -Data type conversion utilities between PyTorch, TTNN, and MLIR types. - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`is_ttnn_tensor`](#ttl-dtype_utils-is_ttnn_tensor) | Check if tensor is a ttnn.Tensor. | -| [`tensor_dtype_to_ttcore_datatype`](#ttl-dtype_utils-tensor_dtype_to_ttcore_datatype) | Convert tensor dtype to ttcore.DataType, supporting both torch and ttnn dtypes. | -| [`tile_bytes_from_dtype`](#ttl-dtype_utils-tile_bytes_from_dtype) | Calculate tile size in bytes from ttnn dtype. | -| [`torch_dtype_to_ttcore_datatype`](#ttl-dtype_utils-torch_dtype_to_ttcore_datatype) | Convert PyTorch dtype to ttcore.DataType enum. | -| [`torch_dtype_to_ttnn_datatype`](#ttl-dtype_utils-torch_dtype_to_ttnn_datatype) | Convert PyTorch dtype to ttnn.DataType enum. | -| [`ttnn_dtype_to_ttcore_datatype`](#ttl-dtype_utils-ttnn_dtype_to_ttcore_datatype) | Convert ttnn.DataType to ttcore.DataType enum. | - -### API - - - - - -```python -ttl.dtype_utils.is_ttnn_tensor( - tensor -) -> bool -``` - - - - - - -Check if tensor is a ttnn.Tensor. - - - - - - - - -```python -ttl.dtype_utils.tensor_dtype_to_ttcore_datatype( - dtype -) -``` - - - - - - -Convert tensor dtype to ttcore.DataType, supporting both torch and ttnn dtypes. - -**Parameters:** - - -Either torch dtype or ttnn.DataType - - -**Returns:** - -ttcore.DataType enum value - - - - - - - - -```python -ttl.dtype_utils.tile_bytes_from_dtype( - dtype -) -> int -``` - - - - - - -Calculate tile size in bytes from ttnn dtype. - -For tiled tensors, each tile is 32x32 elements. The byte size depends on -the data type's element size plus any format-specific overhead. - -**Parameters:** - - -ttnn.DataType enum value - - -**Returns:** `int` - -Tile size in bytes - -**Raises:** - -- `ValueError`: If dtype is not supported - - - - - - - - -```python -ttl.dtype_utils.torch_dtype_to_ttcore_datatype( - torch_dtype -) -``` - - - - - - -Convert PyTorch dtype to ttcore.DataType enum. - -**Parameters:** - - -PyTorch dtype (torch.float32, torch.int32, etc.) - - -**Returns:** - -ttcore.DataType enum value - -**Raises:** - -- `ValueError`: If dtype is not supported - - - - - - - - -```python -ttl.dtype_utils.torch_dtype_to_ttnn_datatype( - torch_dtype -) -``` - - - - - - -Convert PyTorch dtype to ttnn.DataType enum. - -**Parameters:** - - -PyTorch dtype (torch.float32, torch.bfloat16, etc.) - - -**Returns:** - -ttnn.DataType enum value - -**Raises:** - -- `ImportError`: If ttnn is not available -- `ValueError`: If dtype is not supported - - - - - - - - -```python -ttl.dtype_utils.ttnn_dtype_to_ttcore_datatype( - ttnn_dtype -) -``` - - - - - - -Convert ttnn.DataType to ttcore.DataType enum. - -**Parameters:** - - -ttnn.DataType enum value - - -**Returns:** - -ttcore.DataType enum value - -**Raises:** - -- `ValueError`: If dtype is not supported - - - diff --git a/fern/static/ttl-docs/ttl/ttl/kernel_runner.mdx b/fern/static/ttl-docs/ttl/ttl/kernel_runner.mdx deleted file mode 100644 index a87c92a..0000000 --- a/fern/static/ttl-docs/ttl/ttl/kernel_runner.mdx +++ /dev/null @@ -1,274 +0,0 @@ ---- -layout: overview -slug: ttl/ttl/kernel_runner -title: ttl.kernel_runner ---- - -Shared kernel execution logic for tt-lang. - -Provides functions for building kernel descriptors, CB descriptors, and -executing kernels on device via ttnn.generic_op. Used by both the Python -DSL (CompiledTTNNKernel) and ME2E tests. - -This module provides a single reusable implementation of kernel argument -building and execution. - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`KernelSpec`](#ttl-kernel_runner-KernelSpec) | Specification for a single kernel to execute. | - -### Functions - -| Name | Description | -|------|-------------| -| [`build_cb_descriptors`](#ttl-kernel_runner-build_cb_descriptors) | Build circular buffer descriptors for ttnn.generic_op. | -| [`build_kernel_descriptors`](#ttl-kernel_runner-build_kernel_descriptors) | Build kernel descriptors for ttnn.generic_op. | -| [`build_tensor_accessor_args`](#ttl-kernel_runner-build_tensor_accessor_args) | Build compile-time args for tensor accessors. | -| [`run_kernel_on_device`](#ttl-kernel_runner-run_kernel_on_device) | Execute kernels on device using ttnn.generic_op. | - -### Data - -[`__all__`](#ttl-kernel_runner-__all__) - -### API - - - - - -```python -class ttl.kernel_runner.KernelSpec( - path: str, - thread_type: str, - tensor_indices: typing.List[int], - config: typing.Any -) -``` - - - - - - -Dataclass - -Specification for a single kernel to execute. - - - - - - - - - - - - - - - - -```python -ttl.kernel_runner.build_cb_descriptors( - tensors: typing.List[typing.Any], - cb_configs: typing.List[typing.Any], - core_ranges: typing.Any -) -> typing.List[typing.Any] -``` - - - - - - -Build circular buffer descriptors for ttnn.generic_op. - -**Parameters:** - - -List of ttnn.Tensor objects. Each tensor's position (0, 1, 2, ...) -corresponds to its CB index. For intermediate CBs (not backed by -input/output tensors), pass None in the corresponding position. - - - -List of CircularBuffer objects for each CB, indexed by CB index. -Each CB has shape, buffer_factor, tensor (for dtype), and _cb_index attributes. - - - -ttnn.CoreRangeSet for CB allocation. - - -**Returns:** `List[Any]` - -List of ttnn.CBDescriptor objects. - - - - - - - - -```python -ttl.kernel_runner.build_kernel_descriptors( - kernel_specs: typing.List[ttl.kernel_runner.KernelSpec], - tensors: typing.List[typing.Any], - tensor_accessor_args: typing.List[int], - core_ranges: typing.Any, - grid_cols: int, - grid_rows: int, - num_cbs: int -) -> typing.List[typing.Any] -``` - - - - - - -Build kernel descriptors for ttnn.generic_op. - -**Parameters:** - - -List of kernel specifications. - - - -List of ttnn.Tensor objects. Position in this list determines -the global tensor index. Individual kernels access subsets via -tensor_indices in each KernelSpec. - - - -Flattened compile-time args from all tensors. - - - -ttnn.CoreRangeSet for kernel execution. - - - -Number of grid columns (x dimension). - - - -Number of grid rows (y dimension). - - - -Total number of circular buffers (including intermediate CBs). - - -**Returns:** `List[Any]` - -List of ttnn.KernelDescriptor objects. - - - - - - - - -```python -ttl.kernel_runner.build_tensor_accessor_args( - tensors: typing.List[typing.Any] -) -> typing.List[int] -``` - - - - - - -Build compile-time args for tensor accessors. - -**Parameters:** - - -List of ttnn.Tensor objects on device. - - -**Returns:** `List[int]` - -List of compile-time args (flattened TensorAccessorArgs for all tensors). - - - - - - - - -```python -ttl.kernel_runner.run_kernel_on_device( - kernel_specs: typing.List[ttl.kernel_runner.KernelSpec], - tensors: typing.List[typing.Any], - cb_configs: typing.List[typing.Any], - core_ranges: typing.Any, - program_hash: int = None -) -> typing.Any -``` - - - - - - -Execute kernels on device using ttnn.generic_op. - -This is the main entry point for kernel execution. It builds all -descriptors and runs the program. - -**Parameters:** - - -List of kernel specifications (path, thread_type, tensor_indices, config). - - - -List of ttnn.Tensor objects. Position in this list determines the -global tensor index. Individual kernels access subsets via tensor_indices -in each KernelSpec. - - - -List of CircularBuffer objects for each CB, indexed by CB index. -Includes both tensor-backed CBs and intermediate CBs. Each CB has shape, -buffer_factor, tensor (for dtype), and _cb_index attributes. - - - -ttnn.CoreRangeSet for kernel execution. - - - -Hash for tt-metal program cache (not yet used). - - -**Returns:** `Any` - -Result from ttnn.generic_op (typically None or output tensor). - - - - - - - - -```python -ttl.kernel_runner.__all__ = ['KernelSpec', 'build_tensor_accessor_args', 'build_kernel_descriptors', 'build_... -``` - - - - diff --git a/fern/static/ttl-docs/ttl/ttl/layouts.mdx b/fern/static/ttl-docs/ttl/ttl/layouts.mdx deleted file mode 100644 index 41784be..0000000 --- a/fern/static/ttl-docs/ttl/ttl/layouts.mdx +++ /dev/null @@ -1,126 +0,0 @@ ---- -layout: overview -slug: ttl/ttl/layouts -title: ttl.layouts ---- - -Layout creation utilities for tensor distribution across cores. - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`TTNNLayoutConfig`](#ttl-layouts-TTNNLayoutConfig) | Configuration for TTNN layout creation. Supports L1/DRAM interleaved tiled layouts. | - -### Functions - -| Name | Description | -|------|-------------| -| [`create_ttnn_layout`](#ttl-layouts-create_ttnn_layout) | Create a TTNNLayoutAttr for L1 interleaved tiled tensors. | - -### Data - -[`_TTNN_BUFFER_TYPE_L1`](#ttl-layouts-_TTNN_BUFFER_TYPE_L1) - -[`_TTNN_TENSOR_MEMORY_LAYOUT_INTERLEAVED`](#ttl-layouts-_TTNN_TENSOR_MEMORY_LAYOUT_INTERLEAVED) - -### API - - - - - -```python -class ttl.layouts.TTNNLayoutConfig( - logical_shape: typing.List[int], - grid: typing.List[int], - dtype: str -) -``` - - - - - - -Dataclass - -Configuration for TTNN layout creation. Supports L1/DRAM interleaved tiled layouts. - - - - - - - - - - - - - - - - -```python -ttl.layouts.create_ttnn_layout( - ctx, - config: ttl.layouts.TTNNLayoutConfig -) -``` - - - - - - -Create a TTNNLayoutAttr for L1 interleaved tiled tensors. - -Supports: L1/DRAM memory, Interleaved layout, tiled (32x32 tiles). - -**Parameters:** - - -MLIR context - - - -Configuration with logical_shape, grid, and dtype - - -**Returns:** - -TTNNLayoutAttr - -**Raises:** - -- `ValueError`: If configuration is unsupported - - - - - - - - -```python -ttl.layouts._TTNN_BUFFER_TYPE_L1 = 1 -``` - - - - - - - - - -```python -ttl.layouts._TTNN_TENSOR_MEMORY_LAYOUT_INTERLEAVED = 0 -``` - - - - diff --git a/fern/static/ttl-docs/ttl/ttl/operators.mdx b/fern/static/ttl-docs/ttl/ttl/operators.mdx deleted file mode 100644 index 14fabc9..0000000 --- a/fern/static/ttl-docs/ttl/ttl/operators.mdx +++ /dev/null @@ -1,642 +0,0 @@ ---- -layout: overview -slug: ttl/ttl/operators -title: ttl.operators ---- - -DSL operators for tensor operations and data movement. - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`CopyTransferHandler`](#ttl-operators-CopyTransferHandler) | Transfer handle for asynchronous copy operations. | -| [`TensorBlock`](#ttl-operators-TensorBlock) | Represents a block of tensor data in the TTL dialect. | - -### Functions - -| Name | Description | -|------|-------------| -| [`_get_cb_from_block`](#ttl-operators-_get_cb_from_block) | Extract the CB from a block (result of ttl.attach_cb). | -| [`_get_cb_shape`](#ttl-operators-_get_cb_shape) | Extract the block shape from a CB value. | -| [`_get_constant_int`](#ttl-operators-_get_constant_int) | Extract Python int from MLIR arith.ConstantOp or return as-is if already int. | -| [`_get_current_grid`](#ttl-operators-_get_current_grid) | Get the current grid dimensions. | -| [`_is_block`](#ttl-operators-_is_block) | Check if a value is a block (result of cb.reserve() or cb.wait()). | -| [`_make_tensor_slice`](#ttl-operators-_make_tensor_slice) | Create a ttl.tensor_slice from a tensor, tile indices, and shape. | -| [`_process_tensor_subscript`](#ttl-operators-_process_tensor_subscript) | Process tensor subscript and create tensor slice. | -| [`_set_current_grid`](#ttl-operators-_set_current_grid) | Set the current grid dimensions. Called before compiling threads. | -| [`broadcast`](#ttl-operators-broadcast) | Broadcast over specified dimensions. | -| [`copy`](#ttl-operators-copy) | Initiate an asynchronous data transfer using ttl.copy. | -| [`core`](#ttl-operators-core) | Get the coordinates of the current core. | -| [`grid_size`](#ttl-operators-grid_size) | Get the size of the grid. | -| [`signpost`](#ttl-operators-signpost) | Emit a profiling marker visible in Tracy. | - -### Data - -[`CoreCoordinate`](#ttl-operators-CoreCoordinate) - -[`IndexedTensor`](#ttl-operators-IndexedTensor) - -[`__all__`](#ttl-operators-__all__) - -[`_current_grid`](#ttl-operators-_current_grid) - -### API - - - - - -```python -class ttl.operators.CopyTransferHandler() -``` - - - - - - -Transfer handle for asynchronous copy operations. - -CopyTransferHandler objects are returned by copy() calls and must be -explicitly waited on to ensure transfer completion. - - - - - - -```python -ttl.operators.CopyTransferHandler.wait( - ast_self: ttl.operators.CopyTransferHandler -) -``` - - - - - - -Block until the copy operation completes. - - - - - - - - - -```python -class ttl.operators.TensorBlock( - shape, - dtype -) -``` - - - - - - -Represents a block of tensor data in the TTL dialect. - -TensorBlock supports arithmetic operations through operator -overloading. Operations generate TTL high-level ops that get lowered -to ttl.compute blocks. - - - - - - -```python -ttl.operators.TensorBlock.__add__( - ast_self: ttl.operators.TensorBlock, - rhs: ttl.operators.TensorBlock -) -> ttl.operators.TensorBlock -``` - - - - - - -Element-wise addition using ttl.add. - -**Parameters:** - - -Right operand tensor. Must have the same shape as self. - - -**Returns:** `TensorBlock` - -Result tensor with the same shape as inputs. - - - - - - - -```python -ttl.operators.TensorBlock.__matmul__( - ast_self: ttl.operators.TensorBlock, - rhs: ttl.operators.TensorBlock -) -> ttl.operators.TensorBlock -``` - - - - - - -Matrix multiplication is not yet supported in TTL mode. - - - - - - - -```python -ttl.operators.TensorBlock.__mul__( - ast_self: ttl.operators.TensorBlock, - rhs: ttl.operators.TensorBlock -) -> ttl.operators.TensorBlock -``` - - - - - - -Element-wise multiplication using ttl.mul. - - - - - - - -```python -ttl.operators.TensorBlock.__sub__( - ast_self: ttl.operators.TensorBlock, - rhs: ttl.operators.TensorBlock -) -> ttl.operators.TensorBlock -``` - - - - - - -Element-wise subtraction using ttl.sub. - - - - - - - -```python -ttl.operators.TensorBlock.store( - ast_self: ttl.operators.TensorBlock, - rhs: ttl.operators.TensorBlock -) -> None -``` - - - - - - -Store result tensor to CB by propagating CB association from output view. - - - - - - - - - -```python -ttl.operators._get_cb_from_block( - block -) -``` - - - - - - -Extract the CB from a block (result of ttl.attach_cb). - -The attach_cb op has signature: (tensor, cb) -> tensor -So the CB is operand[1]. - - - - - - - - -```python -ttl.operators._get_cb_shape( - cb_val -) -``` - - - - - - -Extract the block shape from a CB value. - - - - - - - - -```python -ttl.operators._get_constant_int( - val -) -``` - - - - - - -Extract Python int from MLIR arith.ConstantOp or return as-is if already int. - - - - - - - - -```python -ttl.operators._get_current_grid() -> typing.Tuple[int, int] -``` - - - - - - -Get the current grid dimensions. - - - - - - - - -```python -ttl.operators._is_block( - value -) -> bool -``` - - - - - - -Check if a value is a block (result of cb.reserve() or cb.wait()). - -A block is a tensor with an attached CB, produced by ttl.attach_cb. - - - - - - - - -```python -ttl.operators._make_tensor_slice( - tensor, - indices, - slice_shape -) -``` - - - - - - -Create a ttl.tensor_slice from a tensor, tile indices, and shape. - -**Parameters:** - - -The source tensor to slice from - - - -(row, col) tile indices for the slice start position - - - -(rows, cols) shape for the slice in tiles - - - - - - - - - -```python -ttl.operators._process_tensor_subscript( - subscript_tuple, - cb_shape -) -``` - - - - - - -Process tensor subscript and create tensor slice. - -**Parameters:** - - -(tensor, indices) where indices are [(value, is_range), ...] - - - -[rows, cols] shape from the CB - - -**Returns:** - -Tensor slice with shape matching cb_shape - - - - - - - - -```python -ttl.operators._set_current_grid( - grid: typing.Tuple[int, int] -) -> None -``` - - - - - - -Set the current grid dimensions. Called before compiling threads. - - - - - - - - -```python -ttl.operators.broadcast( - input: ttl.operators.TensorBlock, - output: ttl.operators.TensorBlock, - dims: typing.List[int] -) -> ttl.operators.TensorBlock -``` - - - - - - -Broadcast over specified dimensions. - -**Parameters:** - - -Input tensor (CB-attached) - - - -Output tensor (CB-attached, used for output CB tracking) - - - -Dimensions to broadcast over - - -**Returns:** `TensorBlock` - -Result tensor with broadcast values - - - - - - - - -```python -ttl.operators.copy( - src, - dst -) -> ttl.operators.CopyTransferHandler -``` - - - - - - -Initiate an asynchronous data transfer using ttl.copy. - -For multi-tile CBs (shape > 1x1), use range syntax: tensor[0:2, 0:2] -For single-tile CBs (shape 1x1), use index syntax: tensor[0, 0] - -**Parameters:** - - -Source tensor/slice (for reads) or block (for writes) - - - -Destination block (for reads) or tensor/slice (for writes) - - -**Returns:** `CopyTransferHandler` - -CopyTransferHandler handle that must be waited on for completion - - - - - - - - -```python -ttl.operators.core( - dims -) -``` - - - - - - -Get the coordinates of the current core. - -Currently only dims=2 is supported (temporary restriction). - -**Parameters:** - - -Number of dimensions to return (must be 2) - - -**Returns:** - -For dims=2: Tuple (x, y) where x is column coordinate and y is row coordinate - -**Raises:** - -- `ValueError`: If dims is not 2 - - - - - - - - -```python -ttl.operators.grid_size( - dims -) -``` - - - - - - -Get the size of the grid. - -Currently only dims=2 is supported (temporary restriction). - -**Parameters:** - - -Number of dimensions to return (must be 2) - - -**Returns:** - -For dims=2: Tuple (x_size, y_size) where x_size is columns and y_size is rows - -**Raises:** - -- `ValueError`: If dims is not 2 - - - - - - - - -```python -ttl.operators.signpost( - name: str -) -``` - - - - - - -Emit a profiling marker visible in Tracy. - -The marker creates a DeviceZoneScopedN in the generated C++ code, -which will appear in Tracy profiler traces when TT_METAL_DEVICE_PROFILER=1. - -**Parameters:** - - -Name for the profiling region (must be a string literal) - - - - - - - - - -```python -ttl.operators.CoreCoordinate = Tuple[int, int] -``` - - - - - - - - - -```python -ttl.operators.IndexedTensor = Union['TensorBlock', Tuple['TensorBlock', Tuple[int, ...]]] -``` - - - - - - - - - -```python -ttl.operators.__all__ = ['TensorBlock', 'CopyTransferHandler', 'copy', 'core', 'grid_size', 'signpost', ... -``` - - - - - - - - - -```python -ttl.operators._current_grid: Tuple[int, int] = (-1, -1) -``` - - - - diff --git a/fern/static/ttl-docs/ttl/ttl/ttl.mdx b/fern/static/ttl-docs/ttl/ttl/ttl.mdx deleted file mode 100644 index 80fe5c9..0000000 --- a/fern/static/ttl-docs/ttl/ttl/ttl.mdx +++ /dev/null @@ -1,27 +0,0 @@ ---- -layout: overview -slug: ttl/ttl/ttl -title: ttl.ttl ---- - -TTL DSL module providing the unified ttl.* API namespace. - -## Module Contents - -### Data - -[`__all__`](#ttl-ttl-__all__) - -### API - - - - - -```python -ttl.ttl.__all__ = ['kernel', 'compute', 'datamovement', 'Program', 'make_circular_buffer_like', 'c... -``` - - - - diff --git a/fern/static/ttl-docs/ttl/ttl/ttl_api.mdx b/fern/static/ttl-docs/ttl/ttl/ttl_api.mdx deleted file mode 100644 index e890735..0000000 --- a/fern/static/ttl-docs/ttl/ttl/ttl_api.mdx +++ /dev/null @@ -1,907 +0,0 @@ ---- -layout: overview -slug: ttl/ttl/ttl_api -title: ttl.ttl_api ---- - -Main API for the TTL dialect Python DSL. - -## Module Contents - -### Classes - -| Name | Description | -|------|-------------| -| [`CompiledTTNNKernel`](#ttl-ttl_api-CompiledTTNNKernel) | A compiled tt-lang kernel ready for execution via ttnn.generic_op. | -| [`Program`](#ttl-ttl_api-Program) | Immutable container for kernel threads and their arguments. | - -### Functions - -| Name | Description | -|------|-------------| -| [`_clear_thread_registry`](#ttl-ttl_api-_clear_thread_registry) | Clear the thread registry before kernel execution. | -| [`_collect_captures`](#ttl-ttl_api-_collect_captures) | Collect and convert captured variables from function closure. | -| [`_collect_cb_configs`](#ttl-ttl_api-_collect_cb_configs) | Extract CircularBuffer objects from thread closures, indexed by cb_index. | -| [`_compile`](#ttl-ttl_api-_compile) | Internal decorator for compiling kernel threads. | -| [`_compile_kernel`](#ttl-ttl_api-_compile_kernel) | Compile kernel function to MLIR and return CompiledTTNNKernel. | -| [`_compile_ttnn_kernel`](#ttl-ttl_api-_compile_ttnn_kernel) | Compile kernel to CompiledTTNNKernel for execution via ttnn.generic_op. | -| [`_detect_memory_space_from_tensor`](#ttl-ttl_api-_detect_memory_space_from_tensor) | Detect memory space (L1/DRAM) from a ttnn tensor's buffer type. | -| [`_get_registered_threads`](#ttl-ttl_api-_get_registered_threads) | Get all registered threads and clear the registry. | -| [`_get_source_line_offset`](#ttl-ttl_api-_get_source_line_offset) | Get the line offset to convert parsed AST line numbers to actual file lines. | -| [`_get_tensor_cache_info`](#ttl-ttl_api-_get_tensor_cache_info) | Extract cache-relevant info from a tensor: (shape, dtype, memory_space, layout). | -| [`_has_float32_args`](#ttl-ttl_api-_has_float32_args) | Check if any input tensor uses float32 dtype. | -| [`_is_interleaved_tensor`](#ttl-ttl_api-_is_interleaved_tensor) | Check if a ttnn tensor has interleaved memory layout. | -| [`_make_cache_key`](#ttl-ttl_api-_make_cache_key) | Create cache key from tensor properties and runtime compute config parameters. | -| [`_register_thread`](#ttl-ttl_api-_register_thread) | Register a thread function during decoration. | -| [`_resolve_grid`](#ttl-ttl_api-_resolve_grid) | Resolve grid, evaluating callable or 'auto' if needed. | -| [`_run_profiling_pipeline`](#ttl-ttl_api-_run_profiling_pipeline) | Read device profiler data and display profile report. | -| [`_should_execute`](#ttl-ttl_api-_should_execute) | Check if kernel execution should proceed (not compile-only mode). | -| [`_track_tensor_sources`](#ttl-ttl_api-_track_tensor_sources) | Track source locations for tensor arguments. | -| [`_write_kernel_to_tmp`](#ttl-ttl_api-_write_kernel_to_tmp) | Write kernel source to /tmp and return the file path. | -| [`compute`](#ttl-ttl_api-compute) | Decorator for compute thread functions. | -| [`datamovement`](#ttl-ttl_api-datamovement) | Decorator for data movement thread functions. | -| [`pykernel_gen`](#ttl-ttl_api-pykernel_gen) | Decorator for generating TTL kernels from Python functions. | - -### Data - -[`__all__`](#ttl-ttl_api-__all__) - -[`_thread_registry`](#ttl-ttl_api-_thread_registry) - -[`kernel`](#ttl-ttl_api-kernel) - -### API - - - - - -```python -class ttl.ttl_api.CompiledTTNNKernel( - kernel_paths, - kernel_configs, - kernel_arg_specs, - num_tensors, - core_ranges, - kernel_tensor_indices, - cb_configs = None, - program_hash = None, - source_lines = None, - all_source_lines = None, - thread_to_kernel = None, - kernel_line_offsets = None -) -``` - - - - - - -A compiled tt-lang kernel ready for execution via ttnn.generic_op. - -Caches compilation artifacts (kernel paths, CB descriptors) so the kernel -can be executed multiple times with different tensors without recompiling. - - - - - - - - - - - - - - - - - -```python -ttl.ttl_api.CompiledTTNNKernel.__call__( - args = () -) -``` - - - - - - -Execute the kernel with the given tensors. - - - - - - - - - -```python -class ttl.ttl_api.Program( - threads = (), - args = (), - kwargs = None -) -``` - - - - - - -Immutable container for kernel threads and their arguments. - -A Program encapsulates compute and data movement threads along with -the arguments to be passed during execution. After construction, all -fields should be treated as read-only. - - - - - - - - - - - - - - - - - -```python -ttl.ttl_api.Program.__call__( - args = (), - kwargs = {} -) -``` - - - - - - - - - - - - - - -```python -ttl.ttl_api._clear_thread_registry() -> None -``` - - - - - - -Clear the thread registry before kernel execution. - - - - - - - - -```python -ttl.ttl_api._collect_captures( - f: typing.Callable -) -> typing.Dict[str, typing.Union[int, ttl.circular_buffer.CircularBuffer]] -``` - - - - - - -Collect and convert captured variables from function closure. - -**Parameters:** - - -Function with closure to inspect - - -**Returns:** `Dict[str, Union[int, CircularBuffer]]` - -Dictionary mapping variable names to converted values - -**Raises:** - -- `TypeError`: If closure contains unsupported variable types - - - - - - - - -```python -ttl.ttl_api._collect_cb_configs( - threads -) -``` - - - - - - -Extract CircularBuffer objects from thread closures, indexed by cb_index. - -Returns a list of CircularBuffer objects indexed by cb_index. Each CB has -shape, buffer_factor, tensor (for dtype), and _cb_index attributes. - - - - - - - - -```python -ttl.ttl_api._compile( - kernel_type: typing.Optional[str] = None, - verbose: bool = False -) -> typing.Callable -``` - - - - - - -Internal decorator for compiling kernel threads. - -**Parameters:** - - -Type of kernel ("compute" or "datamovement") - - - -Enable verbose compilation output - - -**Returns:** `Callable` - -Decorator function for kernel compilation - - - - - - - - -```python -ttl.ttl_api._compile_kernel( - f: typing.Callable, - args: tuple, - kwargs: dict, - grid: typing.Union[tuple, typing.List[int]], - indexing_maps: typing.List[typing.Callable], - iterator_types: typing.List[str], - num_outs: int, - memory_space: str, - tiled: bool, - program_hash: int, - fp32_dest_acc_en: typing.Optional[bool] = None, - dst_full_sync_en: typing.Optional[bool] = None -) -> typing.Optional[ttl.ttl_api.CompiledTTNNKernel] -``` - - - - - - -Compile kernel function to MLIR and return CompiledTTNNKernel. - -**Parameters:** - - -User kernel function - - - -Positional arguments for the kernel - - - -Keyword arguments for the kernel - - - -Grid dimensions - - - -List of lambda functions for indexing - - - -List of iterator type strings - - - -Number of output arguments - - - -"L1" or "DRAM" - - - -Whether to use tiled layout - - - -Hash for tt-metal program cache - - - -Optional override for fp32_dest_acc_en - - - -Optional override for dst_full_sync_en - - -**Returns:** `Optional[CompiledTTNNKernel]` - -CompiledTTNNKernel ready for execution - - - - - - - - -```python -ttl.ttl_api._compile_ttnn_kernel( - module, - args, - grid, - num_outs, - thread_tensor_indices, - cb_configs = None, - program_hash = None, - fp32_dest_acc_en: typing.Optional[bool] = None, - dst_full_sync_en: typing.Optional[bool] = None, - verbose = True, - source_lines = None, - all_source_lines = None, - kernel_line_offsets = None -) -``` - - - - - - -Compile kernel to CompiledTTNNKernel for execution via ttnn.generic_op. - -Builds kernel paths, configs, and CB descriptors from compiled MLIR module. - -**Parameters:** - - -MLIR module after D2M pipeline (with EmitC kernels) - - - -Input/output tensors (used for shape/dtype info) - - - -Grid dimensions tuple - - - -Number of output tensors - - - -Hash for tt-metal program cache - - - -Print compilation info - - - -Source code lines for auto-profiling reports - - -**Returns:** - -CompiledTTNNKernel ready for execution - - - - - - - - -```python -ttl.ttl_api._detect_memory_space_from_tensor( - tensor, - default: str -) -> str -``` - - - - - - -Detect memory space (L1/DRAM) from a ttnn tensor's buffer type. - - - - - - - - -```python -ttl.ttl_api._get_registered_threads() -> typing.List[typing.Callable] -``` - - - - - - -Get all registered threads and clear the registry. - - - - - - - - -```python -ttl.ttl_api._get_source_line_offset( - f -) -> int -``` - - - - - - -Get the line offset to convert parsed AST line numbers to actual file lines. - - - - - - - - -```python -ttl.ttl_api._get_tensor_cache_info( - tensor -) -> tuple -``` - - - - - - -Extract cache-relevant info from a tensor: (shape, dtype, memory_space, layout). - - - - - - - - -```python -ttl.ttl_api._has_float32_args( - args -) -> bool -``` - - - - - - -Check if any input tensor uses float32 dtype. - -Inspects the tensor arguments to detect float32. This is used to -automatically enable fp32_dest_acc_en configuration for compute kernels. - -**Parameters:** - - -List of tensor arguments (torch or ttnn) - - -**Returns:** `bool` - -True if any tensor uses float32 dtype, False otherwise - - - - - - - - -```python -ttl.ttl_api._is_interleaved_tensor( - tensor -) -> bool -``` - - - - - - -Check if a ttnn tensor has interleaved memory layout. - - - - - - - - -```python -ttl.ttl_api._make_cache_key( - args: tuple, - fp32_dest_acc_en: typing.Optional[bool], - dst_full_sync_en: typing.Optional[bool] -) -> tuple -``` - - - - - - -Create cache key from tensor properties and runtime compute config parameters. - - - - - - - - -```python -ttl.ttl_api._register_thread( - thread_fn: typing.Callable -) -> None -``` - - - - - - -Register a thread function during decoration. - - - - - - - - -```python -ttl.ttl_api._resolve_grid( - grid, - args, - kwargs -) -``` - - - - - - -Resolve grid, evaluating callable or 'auto' if needed. - - - - - - - - -```python -ttl.ttl_api._run_profiling_pipeline( - tensors: tuple, - all_source_lines: typing.Dict[str, typing.List[str]], - thread_to_kernel: typing.Dict[str, str], - kernel_line_offsets: typing.Optional[typing.Dict[str, int]] = None -) -``` - - - - - - -Read device profiler data and display profile report. - -Called after kernel execution when auto-profiling is enabled. - -**Parameters:** - - -Tuple of tensor arguments passed to the kernel - - - -Dict mapping kernel name to source lines - - - -Dict mapping RISC thread name to kernel name - - - - - - - - - -```python -ttl.ttl_api._should_execute() -> bool -``` - - - - - - -Check if kernel execution should proceed (not compile-only mode). - - - - - - - - -```python -ttl.ttl_api._track_tensor_sources( - f_params, - args, - source_file: str -) -> None -``` - - - - - - -Track source locations for tensor arguments. - -Searches backwards from the kernel call site to find where each -tensor variable was assigned, then registers that location. - - - - - - - - -```python -ttl.ttl_api._write_kernel_to_tmp( - name: str, - source: str -) -> str -``` - - - - - - -Write kernel source to /tmp and return the file path. - - - - - - - - -```python -ttl.ttl_api.compute( - verbose: bool = False -) -> typing.Callable -``` - - - - - - -Decorator for compute thread functions. - -Compute threads execute on Tensix cores and perform mathematical operations. - -**Parameters:** - - -Enable verbose compilation output - - -**Returns:** `Callable` - -Decorator for compute kernel compilation - - - - - - - - -```python -ttl.ttl_api.datamovement( - verbose: bool = False -) -> typing.Callable -``` - - - - - - -Decorator for data movement thread functions. - -Data movement threads handle DMA operations between memory hierarchies. - -**Parameters:** - - -Enable verbose compilation output - - -**Returns:** `Callable` - -Decorator for data movement kernel compilation - - - - - - - - -```python -ttl.ttl_api.pykernel_gen( - grid: typing.Optional[typing.Union[tuple, typing.Callable]] = None, - indexing_maps: typing.Optional[typing.List[typing.Callable]] = None, - iterator_types: typing.Optional[typing.List[str]] = None, - num_outs: int = 1, - memory_space: str = 'L1', - tiled: bool = True, - fp32_dest_acc_en: typing.Optional[bool] = None, - dst_full_sync_en: typing.Optional[bool] = None -) -> typing.Callable -``` - - - - - - -Decorator for generating TTL kernels from Python functions. - -This decorator compiles Python functions into TTL dialect operations, -handling thread compilation, stream creation, and pipeline execution. -Kernels are compiled to C++ for execution via ttnn.generic_op. - -**Parameters:** - - -Grid dimensions as tuple (e.g., (2, 2)) or callable - - - -List of lambda functions for indexing (optional) - - - -List of iterator types ("parallel", "reduction") - - - -Number of output arguments - - - -"L1" or "DRAM" - - - -Whether to use tiled layout - - - -Optional override for fp32_dest_acc_en - - - -Optional override for dst_full_sync_en - - -**Returns:** `Callable` - -Decorated function that compiles and executes the kernel - -**Raises:** - -- `AssertionError`: If required parameters are missing or invalid - - - - - - - - -```python -ttl.ttl_api.__all__ = ['pykernel_gen', 'kernel', 'Program', 'compute', 'datamovement', 'TensorBlock', ... -``` - - - - - - - - - -```python -ttl.ttl_api._thread_registry: List[Callable] = [] -``` - - - - - - - - - -```python -ttl.ttl_api.kernel = pykernel_gen -``` - - - - diff --git a/fern/static/ttl-docs/ttl/ttl/ttl_math.mdx b/fern/static/ttl-docs/ttl/ttl/ttl_math.mdx deleted file mode 100644 index 9472960..0000000 --- a/fern/static/ttl-docs/ttl/ttl/ttl_math.mdx +++ /dev/null @@ -1,29 +0,0 @@ ---- -layout: overview -slug: ttl/ttl/ttl_math -title: ttl.ttl_math ---- - -TTL math operations namespace (ttl.math). - -Re-exports elementwise operations from the generated module. - -## Module Contents - -### Data - -[`__all__`](#ttl-ttl_math-__all__) - -### API - - - - - -```python -ttl.ttl_math.__all__ = ['broadcast', *_generated_all] -``` - - - - diff --git a/fern/static/ttl-docs/ttl/ttl/ttl_utils.mdx b/fern/static/ttl-docs/ttl/ttl/ttl_utils.mdx deleted file mode 100644 index 813bd24..0000000 --- a/fern/static/ttl-docs/ttl/ttl/ttl_utils.mdx +++ /dev/null @@ -1,70 +0,0 @@ ---- -layout: overview -slug: ttl/ttl/ttl_utils -title: ttl.ttl_utils ---- - -Utility functions for tt-lang. - -## Module Contents - -### Functions - -| Name | Description | -|------|-------------| -| [`get_thread_type_string`](#ttl-ttl_utils-get_thread_type_string) | Map kernel type to thread type string. | - -### Data - -[`_KERNEL_TYPE_TO_THREAD_TYPE`](#ttl-ttl_utils-_KERNEL_TYPE_TO_THREAD_TYPE) - -### API - - - - - -```python -ttl.ttl_utils.get_thread_type_string( - input: typing.Union[str, object] -) -> str -``` - - - - - - -Map kernel type to thread type string. - -Handles both string kernel types and MLIR ThreadTypeAttr. - -**Parameters:** - - -Either a string kernel type ("compute", "datamovement", "ethernet") - or a ttkernel.ThreadTypeAttr from MLIR IR - - -**Returns:** `str` - -Thread type string: "compute", "noc", "ethernet" - -**Raises:** - -- `ValueError`: If input is a string that's not a valid kernel type - - - - - - - - -```python -ttl.ttl_utils._KERNEL_TYPE_TO_THREAD_TYPE = {'compute': 'compute', 'datamovement': 'noc', 'ethernet': 'ethernet'} -``` - - - - From 13035eef401fc524e544b17445620afff5046668 Mon Sep 17 00:00:00 2001 From: Paarth Gupta Date: Mon, 2 Mar 2026 17:45:53 -0500 Subject: [PATCH 5/6] golden data set for c++ rendering --- fern/docs.yml | 47 +- fern/pages/cub/block_reduce.mdx | 521 ++++++ fern/pages/cub/block_reduce_v3.mdx | 550 ++++++ fern/pages/cub/block_scan.mdx | 1586 ++++++++++++++++ fern/pages/cub/block_scan_v4.mdx | 1597 +++++++++++++++++ fern/pages/cub/simple_struct.mdx | 48 + fern/pages/cub/simple_struct_v4.mdx | 48 + fern/pages/cub/warp_reduce.mdx | 837 +++++++++ fern/pages/cub/warp_reduce_v4.mdx | 847 +++++++++ fern/pages/libcudacxx/concept_example.mdx | 43 + fern/pages/libcudacxx/concept_example_v3.mdx | 53 + fern/pages/libcudacxx/deep_template_class.mdx | 245 +++ .../libcudacxx/deep_template_class_v4.mdx | 247 +++ .../libcudacxx/empty_docstring_class.mdx | 846 +++++++++ .../libcudacxx/empty_docstring_class_v4.mdx | 830 +++++++++ fern/pages/libcudacxx/raises_example.mdx | 506 ++++++ fern/pages/libcudacxx/raises_example_v4.mdx | 495 +++++ fern/pages/thrust/deprecated_example.mdx | 134 ++ fern/pages/thrust/deprecated_example_v4.mdx | 134 ++ fern/pages/thrust/device_vector.mdx | 1188 ++++++++++++ fern/pages/thrust/device_vector_v3.mdx | 1424 +++++++++++++++ fern/pages/thrust/group_member_example.mdx | 408 +++++ fern/pages/thrust/group_member_example_v4.mdx | 408 +++++ fern/pages/thrust/pointer.mdx | 295 +++ fern/pages/thrust/pointer_v4.mdx | 300 ++++ 25 files changed, 13636 insertions(+), 1 deletion(-) create mode 100644 fern/pages/cub/block_reduce.mdx create mode 100644 fern/pages/cub/block_reduce_v3.mdx create mode 100644 fern/pages/cub/block_scan.mdx create mode 100644 fern/pages/cub/block_scan_v4.mdx create mode 100644 fern/pages/cub/simple_struct.mdx create mode 100644 fern/pages/cub/simple_struct_v4.mdx create mode 100644 fern/pages/cub/warp_reduce.mdx create mode 100644 fern/pages/cub/warp_reduce_v4.mdx create mode 100644 fern/pages/libcudacxx/concept_example.mdx create mode 100644 fern/pages/libcudacxx/concept_example_v3.mdx create mode 100644 fern/pages/libcudacxx/deep_template_class.mdx create mode 100644 fern/pages/libcudacxx/deep_template_class_v4.mdx create mode 100644 fern/pages/libcudacxx/empty_docstring_class.mdx create mode 100644 fern/pages/libcudacxx/empty_docstring_class_v4.mdx create mode 100644 fern/pages/libcudacxx/raises_example.mdx create mode 100644 fern/pages/libcudacxx/raises_example_v4.mdx create mode 100644 fern/pages/thrust/deprecated_example.mdx create mode 100644 fern/pages/thrust/deprecated_example_v4.mdx create mode 100644 fern/pages/thrust/device_vector.mdx create mode 100644 fern/pages/thrust/device_vector_v3.mdx create mode 100644 fern/pages/thrust/group_member_example.mdx create mode 100644 fern/pages/thrust/group_member_example_v4.mdx create mode 100644 fern/pages/thrust/pointer.mdx create mode 100644 fern/pages/thrust/pointer_v4.mdx diff --git a/fern/docs.yml b/fern/docs.yml index 28406aa..8c9bef7 100644 --- a/fern/docs.yml +++ b/fern/docs.yml @@ -22,6 +22,9 @@ tabs: API Reference: display-name: API Reference icon: puzzle + C++ Golden Pages: + display-name: C++ Golden Pages + icon: code Library Reference: display-name: Library Reference icon: book @@ -29,7 +32,10 @@ tabs: display-name: TTL Reference icon: book LangChain Core Reference: - display-name: LangChain Core + display-name: LangChain Core + icon: book + Django Core Reference: + display-name: Django Core Reference icon: book libraries: @@ -54,6 +60,7 @@ libraries: path: ./library-docs/langchain-core-docs lang: python + navigation: - tab: home layout: @@ -74,6 +81,12 @@ navigation: - page: Customize your docs path: docs/pages/customization.mdx icon: fa-duotone fa-palette + - page: Steps TOC Test + path: docs/pages/steps-toc-test.mdx + icon: fa-duotone fa-list-ol + - page: Nominal Data Model + path: docs/pages/nominal-data-model.mdx + icon: fa-duotone fa-database - page: Support path: docs/pages/support.mdx icon: fa-duotone fa-headset @@ -97,6 +110,38 @@ navigation: referenced-packages: - user contents: [] + - tab: C++ Golden Pages + layout: + - section: CUB + contents: + - page: BlockReduce + path: pages/cub/block_reduce_v3.mdx + - page: BlockScan + path: pages/cub/block_scan_v4.mdx + - page: WarpReduce + path: pages/cub/warp_reduce_v4.mdx + - page: ArgMax + path: pages/cub/simple_struct_v4.mdx + - section: Thrust + contents: + - page: device_vector + path: pages/thrust/device_vector_v3.mdx + - page: pointer + path: pages/thrust/pointer_v4.mdx + - page: strided_iterator + path: pages/thrust/deprecated_example_v4.mdx + - page: disjoint_unsynchronized_pool_resource + path: pages/thrust/group_member_example_v4.mdx + - section: libcudacxx + contents: + - page: resource_with + path: pages/libcudacxx/concept_example_v3.mdx + - page: counting_iterator + path: pages/libcudacxx/deep_template_class_v4.mdx + - page: buffer + path: pages/libcudacxx/empty_docstring_class_v4.mdx + - page: stream + path: pages/libcudacxx/raises_example_v4.mdx - tab: Library Reference layout: - library: nemo-rl diff --git a/fern/pages/cub/block_reduce.mdx b/fern/pages/cub/block_reduce.mdx new file mode 100644 index 0000000..013402c --- /dev/null +++ b/fern/pages/cub/block_reduce.mdx @@ -0,0 +1,521 @@ +--- +title: cub::BlockReduce +description: "Collective methods for computing parallel reductions across a CUDA thread block." +--- + +# BlockReduce + +The `BlockReduce` class provides collective methods for computing a parallel reduction of items partitioned across a CUDA thread block. + +A [reduction](http://en.wikipedia.org/wiki/Reduce_(higher-order_function)) (or *fold*) uses a binary combining operator to compute a single aggregate from a list of input elements. Threads are assumed to be in row-major order. + +`BlockReduce` can be optionally specialized by algorithm to accommodate different latency/throughput workload profiles: + +1. [`cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY`](/library/api/cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY): + An efficient "raking" reduction algorithm that only supports commutative reduction operators. +2. [`cub::BLOCK_REDUCE_RAKING`](/library/api/cub::BLOCK_REDUCE_RAKING): + An efficient "raking" reduction algorithm that supports commutative and non-commutative reduction operators. +3. [`cub::BLOCK_REDUCE_WARP_REDUCTIONS`](/library/api/cub::BLOCK_REDUCE_WARP_REDUCTIONS): + A quick "tiled warp-reductions" reduction algorithm that supports commutative and non-commutative reduction operators. +4. [`cub::BLOCK_REDUCE_WARP_REDUCTIONS_NONDETERMINISTIC`](/library/api/cub::BLOCK_REDUCE_WARP_REDUCTIONS_NONDETERMINISTIC): + A quick "tiled warp-reductions" reduction algorithm that supports commutative and non-commutative reduction operators. This variant uses atomic operations to reduce the warp-wide reduction results, making it non-deterministic, i.e. the order of reduction operations is not guaranteed to be the same across different invocations of the same kernel. + +### Performance considerations + +- Performance is sensitive to the degree of data movement across the block. +- Very efficient (only one synchronization barrier). +- Incurs zero bank conflicts for most types. +- Computation is slightly more efficient (i.e., having lower instruction overhead) for: + - Summation (vs. generic reduction) + - `BLOCK_THREADS` is a multiple of the architecture's warp size + - Every thread has a valid input (i.e., full vs. partial-tiles) +- See `cub::BlockReduceAlgorithm` for performance details regarding algorithmic alternatives. + +### Example + +The code snippet below illustrates a sum reduction of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +```cpp +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Compute the block-wide sum for thread0 + int aggregate = BlockReduce(temp_storage).Sum(thread_data); +} +``` + + + + +Data type being reduced + + + +The thread block length in threads along the X dimension + + + +**[optional]** [cub::BlockReduceAlgorithm](/library/api/cub::BlockReduceAlgorithm) enumerator specifying the underlying algorithm to use (default: [cub::BLOCK_REDUCE_WARP_REDUCTIONS](/library/api/cub::BLOCK_REDUCE_WARP_REDUCTIONS)) + + + +**[optional]** The thread block length in threads along the Y dimension (default: 1) + + + +**[optional]** The thread block length in threads along the Z dimension (default: 1) + + + + +--- + +## Collective constructors + +### BlockReduce inline + + + + +Collective constructor using a private static allocation of shared memory as temporary storage. + +```cpp +cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ >::BlockReduce() +``` + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + +Collective constructor using the specified memory allocation as temporary storage. + +```cpp +cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ >::BlockReduce(TempStorage &temp_storage) +``` + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +#### Parameters + + +Reference to memory allocation having layout type [TempStorage](/library/api/cub::BlockReduce::TempStorage) + + + + + +--- + +## Generic reductions + +### Reduce inline + + + + +Computes a block-wide reduction for thread0 using the specified binary reduction functor. Each thread contributes one input element. + +```cpp +template +T cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ >::Reduce( + T input, + ReductionOp reduction_op) +``` + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The return value is undefined in threads other than thread0. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +#### Template parameters + + +**[inferred]** Binary reduction functor type having member `T operator()(const T &a, const T &b)` + + +#### Parameters + + +Calling thread's input + + + +Binary reduction functor + + +#### Example + +The code snippet below illustrates a max reduction of 128 integer items that are partitioned across 128 threads. + +```cpp +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Each thread obtains an input item + int thread_data; + ... + + // Compute the block-wide max for thread0 + int aggregate = BlockReduce(temp_storage).Reduce(thread_data, cuda::maximum<>{}); +} +``` + + + + +Computes a block-wide reduction for thread0 using the specified binary reduction functor. Each thread contributes an array of consecutive input elements. + +```cpp +template +T cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ >::Reduce( + T(&inputs)[ITEMS_PER_THREAD], + ReductionOp reduction_op) +``` + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The return value is undefined in threads other than thread0. +- Performance is sensitive to the degree of data movement across the block. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +#### Template parameters + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary reduction functor type having member `T operator()(const T &a, const T &b)` + + +#### Parameters + + +Calling thread's input segment + + + +Binary reduction functor + + +#### Example + +The code snippet below illustrates a max reduction of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +```cpp +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Compute the block-wide max for thread0 + int aggregate = BlockReduce(temp_storage).Reduce(thread_data, cuda::maximum<>{}); +} +``` + + + + +Computes a block-wide reduction for thread0 using the specified binary reduction functor. The first `num_valid` threads each contribute one input element. + +```cpp +template +T cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ >::Reduce( + T input, + ReductionOp reduction_op, + int num_valid) +``` + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The return value is undefined in threads other than thread0. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +#### Template parameters + + +**[inferred]** Binary reduction functor type having member `T operator()(const T &a, const T &b)` + + +#### Parameters + + +Calling thread's input + + + +Binary reduction functor + + + +Number of threads containing valid elements (may be less than BLOCK_THREADS) + + +#### Example + +The code snippet below illustrates a max reduction of a partially-full tile of integer items that are partitioned across 128 threads. + +```cpp +#include // or equivalently + +__global__ void ExampleKernel(int num_valid, ...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Each thread obtains an input item + int thread_data; + if (threadIdx.x < num_valid) thread_data = ... + + // Compute the block-wide max for thread0 + int aggregate = BlockReduce(temp_storage).Reduce(thread_data, cuda::maximum<>{}, num_valid); +} +``` + + + + +--- + +## Summation reductions + +### Sum inline + + + + +Computes a block-wide reduction for thread0 using addition (+) as the reduction operator. Each thread contributes one input element. + +```cpp +T cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ >::Sum(T input) +``` + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The return value is undefined in threads other than thread0. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +#### Parameters + + +Calling thread's input + + +#### Example + +The code snippet below illustrates a sum reduction of 128 integer items that are partitioned across 128 threads. + +```cpp +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Each thread obtains an input item + int thread_data; + ... + + // Compute the block-wide sum for thread0 + int aggregate = BlockReduce(temp_storage).Sum(thread_data); +} +``` + + + + +Computes a block-wide reduction for thread0 using addition (+) as the reduction operator. Each thread contributes an array of consecutive input elements. + +```cpp +template +T cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ >::Sum( + T(&inputs)[ITEMS_PER_THREAD]) +``` + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The return value is undefined in threads other than thread0. +- Performance is sensitive to the degree of data movement across the block. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +#### Template parameters + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + +#### Parameters + + +Calling thread's input segment + + +#### Example + +The code snippet below illustrates a sum reduction of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +```cpp +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Compute the block-wide sum for thread0 + int aggregate = BlockReduce(temp_storage).Sum(thread_data); +} +``` + + + + +Computes a block-wide reduction for thread0 using addition (+) as the reduction operator. The first `num_valid` threads each contribute one input element. + +```cpp +T cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ >::Sum( + T input, + int num_valid) +``` + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The return value is undefined in threads other than thread0. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +#### Parameters + + +Calling thread's input + + + +Number of threads containing valid elements (may be less than BLOCK_THREADS) + + +#### Example + +The code snippet below illustrates a sum reduction of a partially-full tile of integer items that are partitioned across 128 threads. + +```cpp +#include // or equivalently + +__global__ void ExampleKernel(int num_valid, ...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Each thread obtains an input item (up to num_items) + int thread_data; + if (threadIdx.x < num_valid) + thread_data = ... + + // Compute the block-wide sum for thread0 + int aggregate = BlockReduce(temp_storage).Sum(thread_data, num_valid); +} +``` + + + + +--- + +## Utility methods + +### PrivateStorage inline + +Internal storage allocator. + +```cpp +_TempStorage & cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ >::PrivateStorage() +``` + +**Returns:** Reference to [_TempStorage](/library/api/cub::BlockReduce::_TempStorage) + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `InternalBlockReduce` | `::cuda::std::_If< Algorithm==`[`BLOCK_REDUCE_WARP_REDUCTIONS`](/library/api/cub::BLOCK_REDUCE_WARP_REDUCTIONS)`, `[`WarpReductions`](/library/api/cub::BlockReduce::WarpReductions)`, ::cuda::std::_If< Algorithm==`[`BLOCK_REDUCE_WARP_REDUCTIONS_NONDETERMINISTIC`](/library/api/cub::BLOCK_REDUCE_WARP_REDUCTIONS_NONDETERMINISTIC)`, `[`WarpReductionsNondeterministic`](/library/api/cub::BlockReduce::WarpReductionsNondeterministic)`, ::cuda::std::_If< Algorithm==`[`BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY`](/library/api/cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY)`, `[`RakingCommutativeOnly`](/library/api/cub::BlockReduce::RakingCommutativeOnly)`, `[`Raking`](/library/api/cub::BlockReduce::Raking)` > > >` | Internal specialization type. | +| `_TempStorage` | `typename InternalBlockReduce::TempStorage` | Shared memory storage layout type for [BlockReduce](/library/api/cub::BlockReduce). | +| `WarpReductions` | `detail::BlockReduceWarpReductions< T, BlockDimX, BlockDimY, BlockDimZ >` | | +| `WarpReductionsNondeterministic` | `detail::BlockReduceWarpReductions< T, BlockDimX, BlockDimY, BlockDimZ, false >` | | +| `RakingCommutativeOnly` | `detail::BlockReduceRakingCommutativeOnly< T, BlockDimX, BlockDimY, BlockDimZ >` | | +| `Raking` | `detail::BlockReduceRaking< T, BlockDimX, BlockDimY, BlockDimZ >` | | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | The thread block size in threads. | +| `temp_storage` | [`_TempStorage`](/library/api/cub::BlockReduce::_TempStorage) `&` | Shared storage reference. | +| `linear_tid` | `unsigned int` | Linear thread-id. | + +--- + +## Inner classes + +### TempStorage + +```cpp +struct cub::BlockReduce::TempStorage +``` + +The operations exposed by [BlockReduce](/library/api/cub::BlockReduce) require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/pages/cub/block_reduce_v3.mdx b/fern/pages/cub/block_reduce_v3.mdx new file mode 100644 index 0000000..8bd6d0d --- /dev/null +++ b/fern/pages/cub/block_reduce_v3.mdx @@ -0,0 +1,550 @@ +--- +title: cub::BlockReduce +description: "Collective methods for computing parallel reductions across a CUDA thread block." +--- + +The `BlockReduce` class provides collective methods for computing a parallel reduction of items partitioned across a CUDA thread block. + +A [reduction](http://en.wikipedia.org/wiki/Reduce_(higher-order_function)) (or *fold*) uses a binary combining operator to compute a single aggregate from a list of input elements. Threads are assumed to be in row-major order. + +`BlockReduce` can be optionally specialized by algorithm to accommodate different latency/throughput workload profiles: + +1. [`cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY`](/library/api/cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY): + An efficient "raking" reduction algorithm that only supports commutative reduction operators. +2. [`cub::BLOCK_REDUCE_RAKING`](/library/api/cub::BLOCK_REDUCE_RAKING): + An efficient "raking" reduction algorithm that supports commutative and non-commutative reduction operators. +3. [`cub::BLOCK_REDUCE_WARP_REDUCTIONS`](/library/api/cub::BLOCK_REDUCE_WARP_REDUCTIONS): + A quick "tiled warp-reductions" reduction algorithm that supports commutative and non-commutative reduction operators. +4. [`cub::BLOCK_REDUCE_WARP_REDUCTIONS_NONDETERMINISTIC`](/library/api/cub::BLOCK_REDUCE_WARP_REDUCTIONS_NONDETERMINISTIC): + A quick "tiled warp-reductions" reduction algorithm that supports commutative and non-commutative reduction operators. This variant uses atomic operations to reduce the warp-wide reduction results, making it non-deterministic, i.e. the order of reduction operations is not guaranteed to be the same across different invocations of the same kernel. + +## Performance considerations + +- Performance is sensitive to the degree of data movement across the block. +- Very efficient (only one synchronization barrier). +- Incurs zero bank conflicts for most types. +- Computation is slightly more efficient (i.e., having lower instruction overhead) for: + - Summation (vs. generic reduction) + - `BLOCK_THREADS` is a multiple of the architecture's warp size + - Every thread has a valid input (i.e., full vs. partial-tiles) +- See `cub::BlockReduceAlgorithm` for performance details regarding algorithmic alternatives. + +## Example + +The code snippet below illustrates a sum reduction of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Compute the block-wide sum for thread0 + int aggregate = BlockReduce(temp_storage).Sum(thread_data); +} +``` + + + + + +Data type being reduced + + + +The thread block length in threads along the X dimension + + + +**[optional]** [cub::BlockReduceAlgorithm](/library/api/cub::BlockReduceAlgorithm) enumerator specifying the underlying algorithm to use (default: [cub::BLOCK_REDUCE_WARP_REDUCTIONS](/library/api/cub::BLOCK_REDUCE_WARP_REDUCTIONS)) + + + +**[optional]** The thread block length in threads along the Y dimension (default: 1) + + + +**[optional]** The thread block length in threads along the Z dimension (default: 1) + + + + + +--- + +## Collective constructors + +### BlockReduce inline + + + + +Collective constructor using a private static allocation of shared memory as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockReduce::BlockReduce() +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + +Collective constructor using the specified memory allocation as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockReduce::BlockReduce( + TempStorage &temp_storage +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Parameters** + + +Reference to memory allocation having layout type [TempStorage](/library/api/cub::BlockReduce::TempStorage) + + + + + +--- + +## Generic reductions + +### Reduce inline + + + + +Computes a block-wide reduction for thread0 using the specified binary reduction functor. Each thread contributes one input element. + + +```cpp showLineNumbers={false} +template +T cub::BlockReduce::Reduce( + T input, + ReductionOp reduction_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The return value is undefined in threads other than thread0. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Binary reduction functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input + + + +Binary reduction functor + + +**Example** + +The code snippet below illustrates a max reduction of 128 integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Each thread obtains an input item + int thread_data; + ... + + // Compute the block-wide max for thread0 + int aggregate = BlockReduce(temp_storage).Reduce(thread_data, cuda::maximum<>{}); +} +``` + + + + +Computes a block-wide reduction for thread0 using the specified binary reduction functor. Each thread contributes an array of consecutive input elements. + + +```cpp showLineNumbers={false} +template +T cub::BlockReduce::Reduce( + T (&inputs)[ITEMS_PER_THREAD], + ReductionOp reduction_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The return value is undefined in threads other than thread0. +- Performance is sensitive to the degree of data movement across the block. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary reduction functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input segment + + + +Binary reduction functor + + +**Example** + +The code snippet below illustrates a max reduction of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Compute the block-wide max for thread0 + int aggregate = BlockReduce(temp_storage).Reduce(thread_data, cuda::maximum<>{}); +} +``` + + + + +Computes a block-wide reduction for thread0 using the specified binary reduction functor. The first `num_valid` threads each contribute one input element. + + +```cpp showLineNumbers={false} +template +T cub::BlockReduce::Reduce( + T input, + ReductionOp reduction_op, + int num_valid +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The return value is undefined in threads other than thread0. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Binary reduction functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input + + + +Binary reduction functor + + + +Number of threads containing valid elements (may be less than BLOCK_THREADS) + + +**Example** + +The code snippet below illustrates a max reduction of a partially-full tile of integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int num_valid, ...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Each thread obtains an input item + int thread_data; + if (threadIdx.x < num_valid) thread_data = ... + + // Compute the block-wide max for thread0 + int aggregate = BlockReduce(temp_storage).Reduce(thread_data, cuda::maximum<>{}, num_valid); +} +``` + + + + +--- + +## Summation reductions + +### Sum inline + + + + +Computes a block-wide reduction for thread0 using addition (+) as the reduction operator. Each thread contributes one input element. + + +```cpp showLineNumbers={false} +T cub::BlockReduce::Sum( + T input +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The return value is undefined in threads other than thread0. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input + + +**Example** + +The code snippet below illustrates a sum reduction of 128 integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Each thread obtains an input item + int thread_data; + ... + + // Compute the block-wide sum for thread0 + int aggregate = BlockReduce(temp_storage).Sum(thread_data); +} +``` + + + + +Computes a block-wide reduction for thread0 using addition (+) as the reduction operator. Each thread contributes an array of consecutive input elements. + + +```cpp showLineNumbers={false} +template +T cub::BlockReduce::Sum( + T (&inputs)[ITEMS_PER_THREAD] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The return value is undefined in threads other than thread0. +- Performance is sensitive to the degree of data movement across the block. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + +**Parameters** + + +Calling thread's input segment + + +**Example** + +The code snippet below illustrates a sum reduction of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Compute the block-wide sum for thread0 + int aggregate = BlockReduce(temp_storage).Sum(thread_data); +} +``` + + + + +Computes a block-wide reduction for thread0 using addition (+) as the reduction operator. The first `num_valid` threads each contribute one input element. + + +```cpp showLineNumbers={false} +T cub::BlockReduce::Sum( + T input, + int num_valid +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The return value is undefined in threads other than thread0. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input + + + +Number of threads containing valid elements (may be less than BLOCK_THREADS) + + +**Example** + +The code snippet below illustrates a sum reduction of a partially-full tile of integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int num_valid, ...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Each thread obtains an input item (up to num_items) + int thread_data; + if (threadIdx.x < num_valid) + thread_data = ... + + // Compute the block-wide sum for thread0 + int aggregate = BlockReduce(temp_storage).Sum(thread_data, num_valid); +} +``` + + + + +--- + +## Utility methods + +### PrivateStorage inline + +Internal storage allocator. + + +```cpp showLineNumbers={false} +_TempStorage& cub::BlockReduce::PrivateStorage() +``` + + +**Returns:** Reference to [_TempStorage](/library/api/cub::BlockReduce::_TempStorage) + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `InternalBlockReduce` | `::cuda::std::_If< Algorithm==`[`BLOCK_REDUCE_WARP_REDUCTIONS`](/library/api/cub::BLOCK_REDUCE_WARP_REDUCTIONS)`, `[`WarpReductions`](/library/api/cub::BlockReduce::WarpReductions)`, ::cuda::std::_If< Algorithm==`[`BLOCK_REDUCE_WARP_REDUCTIONS_NONDETERMINISTIC`](/library/api/cub::BLOCK_REDUCE_WARP_REDUCTIONS_NONDETERMINISTIC)`, `[`WarpReductionsNondeterministic`](/library/api/cub::BlockReduce::WarpReductionsNondeterministic)`, ::cuda::std::_If< Algorithm==`[`BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY`](/library/api/cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY)`, `[`RakingCommutativeOnly`](/library/api/cub::BlockReduce::RakingCommutativeOnly)`, `[`Raking`](/library/api/cub::BlockReduce::Raking)` > > >` | Internal specialization type. | +| `_TempStorage` | `typename InternalBlockReduce::TempStorage` | Shared memory storage layout type for [BlockReduce](/library/api/cub::BlockReduce). | +| `WarpReductions` | `detail::BlockReduceWarpReductions< T, BlockDimX, BlockDimY, BlockDimZ >` | | +| `WarpReductionsNondeterministic` | `detail::BlockReduceWarpReductions< T, BlockDimX, BlockDimY, BlockDimZ, false >` | | +| `RakingCommutativeOnly` | `detail::BlockReduceRakingCommutativeOnly< T, BlockDimX, BlockDimY, BlockDimZ >` | | +| `Raking` | `detail::BlockReduceRaking< T, BlockDimX, BlockDimY, BlockDimZ >` | | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | The thread block size in threads. | +| `temp_storage` | [`_TempStorage`](/library/api/cub::BlockReduce::_TempStorage) `&` | Shared storage reference. | +| `linear_tid` | `unsigned int` | Linear thread-id. | + +--- + +## Inner classes + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockReduce::TempStorage +``` + + +The operations exposed by [BlockReduce](/library/api/cub::BlockReduce) require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) \ No newline at end of file diff --git a/fern/pages/cub/block_scan.mdx b/fern/pages/cub/block_scan.mdx new file mode 100644 index 0000000..fd3b17c --- /dev/null +++ b/fern/pages/cub/block_scan.mdx @@ -0,0 +1,1586 @@ +--- +title: cub::BlockScan +description: "Collective methods for computing parallel prefix sums/scans across a CUDA thread block." +--- + +The `BlockScan` class provides collective methods for computing a parallel prefix sum/scan of items partitioned across a CUDA thread block. + +Given a list of input elements and a binary reduction operator, a [prefix scan](http://en.wikipedia.org/wiki/Prefix_sum) produces an output list where each element is computed to be the reduction of the elements occurring earlier in the input list. *Prefix sum* connotes a prefix scan with the addition operator. The term *inclusive* indicates that the *i*th output reduction incorporates the *i*th input. The term *exclusive* indicates the *i*th input is not incorporated into the *i*th output reduction. Threads are assumed to be in row-major order. + +`BlockScan` can be optionally specialized by algorithm to accommodate different workload profiles: + +1. [`cub::BLOCK_SCAN_RAKING`](/library/api/cub::BLOCK_SCAN_RAKING): + An efficient (high throughput) "raking reduce-then-scan" prefix scan algorithm. +2. [`cub::BLOCK_SCAN_RAKING_MEMOIZE`](/library/api/cub::BLOCK_SCAN_RAKING_MEMOIZE): + Similar to `cub::BLOCK_SCAN_RAKING`, but having higher throughput at the expense of additional register pressure for intermediate storage. +3. [`cub::BLOCK_SCAN_WARP_SCANS`](/library/api/cub::BLOCK_SCAN_WARP_SCANS): + A quick (low latency) "tiled warpscans" prefix scan algorithm. + +## Performance considerations + +- Uses special instructions when applicable (e.g., warp `SHFL` instructions) +- Uses synchronization-free communication between warp lanes when applicable +- Invokes a minimal number of minimal block-wide synchronization barriers (only one or two depending on algorithm selection) +- Incurs zero bank conflicts for most types +- Computation is slightly more efficient (i.e., having lower instruction overhead) for: + - Prefix sum variants (vs. generic scan) + - `BLOCK_THREADS` is a multiple of the architecture's warp size +- See `cub::BlockScanAlgorithm` for performance details regarding algorithmic alternatives + +## Example + +The code snippet below illustrates an exclusive prefix sum of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Collectively compute the block-wide exclusive prefix sum + BlockScan(temp_storage).ExclusiveSum(thread_data, thread_data); +} +``` + + + + + +Data type being scanned + + + +The thread block length in threads along the X dimension + + + +**[optional]** [cub::BlockScanAlgorithm](/library/api/cub::BlockScanAlgorithm) enumerator specifying the underlying algorithm to use (default: [cub::BLOCK_SCAN_RAKING](/library/api/cub::BLOCK_SCAN_RAKING)) + + + +**[optional]** The thread block length in threads along the Y dimension (default: 1) + + + +**[optional]** The thread block length in threads along the Z dimension (default: 1) + + + + + +--- + +## Collective constructors + +### BlockScan inline + + + + +Collective constructor using a private static allocation of shared memory as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockScan::BlockScan() +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + +Collective constructor using the specified memory allocation as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockScan::BlockScan( + TempStorage &temp_storage +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Parameters** + + +Reference to memory allocation having layout type [TempStorage](/library/api/cub::BlockScan::TempStorage) + + + + + +--- + +## Exclusive prefix sum operations + +### ExclusiveSum inline + + + + +Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. The value of 0 is applied as the initial value, and is assigned to `output` in *thread*0. + + +```cpp showLineNumbers={false} +void cub::BlockScan::ExclusiveSum( + T input, + T &output +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Uses the identity element (zero) as the initial value. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + +**Example** + +The code snippet below illustrates an exclusive prefix sum of 128 integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain input item for each thread + int thread_data; + ... + + // Collectively compute the block-wide exclusive prefix sum + BlockScan(temp_storage).ExclusiveSum(thread_data, thread_data); +} +``` + + + + +Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. The value of 0 is applied as the initial value, and is assigned to `output` in *thread*0. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +void cub::BlockScan::ExclusiveSum( + T input, + T &output, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Uses the identity element (zero) as the initial value. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +block-wide aggregate reduction of input items + + +**Example** + +The code snippet below illustrates an exclusive prefix sum of 128 integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain input item for each thread + int thread_data; + ... + + // Collectively compute the block-wide exclusive prefix sum + int block_aggregate; + BlockScan(temp_storage).ExclusiveSum(thread_data, thread_data, block_aggregate); +} +``` + + + + +Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. Instead of using 0 as the block-wide prefix, the call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveSum( + T input, + T &output, + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Uses the identity element (zero) as the initial value. +- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence + + + + + +Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. The value of 0 is applied as the initial value, and is assigned to `output[0]` in *thread*0. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveSum( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Uses the identity element (zero) as the initial value. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + +**Example** + +The code snippet below illustrates an exclusive prefix sum of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Collectively compute the block-wide exclusive prefix sum + BlockScan(temp_storage).ExclusiveSum(thread_data, thread_data); +} +``` + + + + +Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. The value of 0 is applied as the initial value, and is assigned to `output[0]` in *thread*0. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveSum( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Uses the identity element (zero) as the initial value. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +block-wide aggregate reduction of input items + + + + + +Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. Instead of using 0 as the block-wide prefix, the call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveSum( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Uses the identity element (zero) as the initial value. +- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence + + + + + +--- + +## Exclusive prefix scan operations + +### ExclusiveScan inline + + + + +Computes an exclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes one input element. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveScan( + T input, + T &output, + T initial_value, + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Supports non-commutative scan operators. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +Initial value to seed the exclusive scan (and is assigned to `output` in *thread*0) + + + +Binary scan functor + + +**Example** + +The code snippet below illustrates an exclusive prefix max scan of 128 integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain input item for each thread + int thread_data; + ... + + // Collectively compute the block-wide exclusive prefix max scan + BlockScan(temp_storage).ExclusiveScan(thread_data, thread_data, INT_MIN, cuda::maximum<>{}); +} +``` + + + + +Computes an exclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes one input element. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveScan( + T input, + T &output, + T initial_value, + ScanOp scan_op, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Supports non-commutative scan operators. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +`initial_value` is not applied to the block-wide aggregate. + + +**Template parameters** + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Initial value to seed the exclusive scan (and is assigned to `output` in *thread*0) + + + +Binary scan functor + + + +block-wide aggregate reduction of input items + + + + + +Computes an exclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes one input element. The call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveScan( + T input, + T &output, + ScanOp scan_op, + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +- Supports non-commutative scan operators. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +Binary scan functor + + + +*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence + + + + + +Computes an exclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + T initial_value, + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Supports non-commutative scan operators. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Initial value to seed the exclusive scan (and is assigned to `output[0]` in *thread*0) + + + +Binary scan functor + + + + + +Computes an exclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + T initial_value, + ScanOp scan_op, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Supports non-commutative scan operators. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +`initial_value` is not applied to the block-wide aggregate. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Initial value to seed the exclusive scan (and is assigned to `output[0]` in *thread*0) + + + +Binary scan functor + + + +block-wide aggregate reduction of input items + + + + + +Computes an exclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. The call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + ScanOp scan_op, + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +- Supports non-commutative scan operators. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Binary scan functor + + + +*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence + + + + + +--- + +## Inclusive prefix sum operations + +### InclusiveSum inline + + + + +Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. + + +```cpp showLineNumbers={false} +void cub::BlockScan::InclusiveSum( + T input, + T &output +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + +**Example** + +The code snippet below illustrates an inclusive prefix sum of 128 integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain input item for each thread + int thread_data; + ... + + // Collectively compute the block-wide inclusive prefix sum + BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); +} +``` + + + + +Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +void cub::BlockScan::InclusiveSum( + T input, + T &output, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +block-wide aggregate reduction of input items + + + + + +Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. The call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveSum( + T input, + T &output, + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence + + + + + +Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveSum( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + + + +Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveSum( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +block-wide aggregate reduction of input items + + + + + +Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. Instead of using 0 as the block-wide prefix, the call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveSum( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence + + + + + +--- + +## Inclusive prefix scan operations + +### InclusiveScan inline + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes one input element. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T input, + T &output, + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Supports non-commutative scan operators. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +Binary scan functor + + +**Example** + +The code snippet below illustrates an inclusive prefix max scan of 128 integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain input item for each thread + int thread_data; + ... + + // Collectively compute the block-wide inclusive prefix max scan + BlockScan(temp_storage).InclusiveScan(thread_data, thread_data, cuda::maximum<>{}); +} +``` + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes one input element. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T input, + T &output, + ScanOp scan_op, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Supports non-commutative scan operators. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +Binary scan functor + + + +block-wide aggregate reduction of input items + + +**Example** + +The code snippet below illustrates an inclusive prefix max scan of 128 integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain input item for each thread + int thread_data; + ... + + // Collectively compute the block-wide inclusive prefix max scan + int block_aggregate; + BlockScan(temp_storage).InclusiveScan(thread_data, thread_data, cuda::maximum<>{}, block_aggregate); +} +``` + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes one input element. The call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T input, + T &output, + ScanOp scan_op, + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +- Supports non-commutative scan operators. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +Binary scan functor + + + +*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence + + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Supports non-commutative scan operators. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Binary scan functor + + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + T initial_value, + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Supports non-commutative scan operators. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Initial value to seed the inclusive scan + + + +Binary scan functor + + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + ScanOp scan_op, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Supports non-commutative scan operators. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Binary scan functor + + + +block-wide aggregate reduction of input items + + +**Example** + +The code snippet below illustrates an inclusive prefix max scan of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Collectively compute the block-wide inclusive prefix max scan + int block_aggregate; + BlockScan(temp_storage).InclusiveScan(thread_data, thread_data, cuda::maximum<>{}, block_aggregate); +} +``` + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + T initial_value, + ScanOp scan_op, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Supports non-commutative scan operators. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +`initial_value` is not applied to the block-wide aggregate. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Initial value to seed the inclusive scan + + + +Binary scan functor + + + +block-wide aggregate reduction of input items + + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. The call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + ScanOp scan_op, + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +- Supports non-commutative scan operators. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Binary scan functor + + + +*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence + + + + + +--- + +## Utility methods + +### PrivateStorage inline + +Internal storage allocator. + + +```cpp showLineNumbers={false} +_TempStorage& cub::BlockScan::PrivateStorage() +``` + + +**Returns:** Reference to [_TempStorage](/library/api/cub::BlockScan::_TempStorage) + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `InternalBlockScan` | `::cuda::std::_If< SAFE_ALGORITHM==`[`BLOCK_SCAN_WARP_SCANS`](/library/api/cub::BLOCK_SCAN_WARP_SCANS)`, `[`WarpScans`](/library/api/cub::BlockScan::WarpScans)`, `[`Raking`](/library/api/cub::BlockScan::Raking)` >` | Define the delegate type for the desired algorithm. | +| `_TempStorage` | `typename InternalBlockScan::TempStorage` | Shared memory storage layout type for [BlockScan](/library/api/cub::BlockScan). | +| `WarpScans` | `detail::BlockScanWarpScans< T, BlockDimX, BlockDimY, BlockDimZ >` | | +| `Raking` | `detail::BlockScanRaking< T, BlockDimX, BlockDimY, BlockDimZ, (SAFE_ALGORITHM==`[`BLOCK_SCAN_RAKING_MEMOIZE`](/library/api/cub::BLOCK_SCAN_RAKING_MEMOIZE)`) >` | | + +--- + +## Inner classes + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockScan::TempStorage +``` + + +The operations exposed by [BlockScan](/library/api/cub::BlockScan) require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/pages/cub/block_scan_v4.mdx b/fern/pages/cub/block_scan_v4.mdx new file mode 100644 index 0000000..0b94a1f --- /dev/null +++ b/fern/pages/cub/block_scan_v4.mdx @@ -0,0 +1,1597 @@ +--- +title: cub::BlockScan +description: "Collective methods for computing parallel prefix sums/scans across a CUDA thread block." +--- + +The `BlockScan` class provides collective methods for computing a parallel prefix sum/scan of items partitioned across a CUDA thread block. + +Given a list of input elements and a binary reduction operator, a [prefix scan](http://en.wikipedia.org/wiki/Prefix_sum) produces an output list where each element is computed to be the reduction of the elements occurring earlier in the input list. *Prefix sum* connotes a prefix scan with the addition operator. The term *inclusive* indicates that the *i*th output reduction incorporates the *i*th input. The term *exclusive* indicates the *i*th input is not incorporated into the *i*th output reduction. Threads are assumed to be in row-major order. + +`BlockScan` can be optionally specialized by algorithm to accommodate different workload profiles: + +1. [`cub::BLOCK_SCAN_RAKING`](/library/api/cub::BLOCK_SCAN_RAKING): + An efficient (high throughput) "raking reduce-then-scan" prefix scan algorithm. +2. [`cub::BLOCK_SCAN_RAKING_MEMOIZE`](/library/api/cub::BLOCK_SCAN_RAKING_MEMOIZE): + Similar to `cub::BLOCK_SCAN_RAKING`, but having higher throughput at the expense of additional register pressure for intermediate storage. +3. [`cub::BLOCK_SCAN_WARP_SCANS`](/library/api/cub::BLOCK_SCAN_WARP_SCANS): + A quick (low latency) "tiled warpscans" prefix scan algorithm. + +## Performance considerations + +- Uses special instructions when applicable (e.g., warp `SHFL` instructions) +- Uses synchronization-free communication between warp lanes when applicable +- Invokes a minimal number of minimal block-wide synchronization barriers (only one or two depending on algorithm selection) +- Incurs zero bank conflicts for most types +- Computation is slightly more efficient (i.e., having lower instruction overhead) for: + - Prefix sum variants (vs. generic scan) + - `BLOCK_THREADS` is a multiple of the architecture's warp size +- See `cub::BlockScanAlgorithm` for performance details regarding algorithmic alternatives + +## Example + +The code snippet below illustrates an exclusive prefix sum of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Collectively compute the block-wide exclusive prefix sum + BlockScan(temp_storage).ExclusiveSum(thread_data, thread_data); +} +``` + + + + + +Data type being scanned + + + +The thread block length in threads along the X dimension + + + +**[optional]** [cub::BlockScanAlgorithm](/library/api/cub::BlockScanAlgorithm) enumerator specifying the underlying algorithm to use (default: [cub::BLOCK_SCAN_RAKING](/library/api/cub::BLOCK_SCAN_RAKING)) + + + +**[optional]** The thread block length in threads along the Y dimension (default: 1) + + + +**[optional]** The thread block length in threads along the Z dimension (default: 1) + + + + + +--- + +## Collective constructors + +### BlockScan inline + + + + +Collective constructor using a private static allocation of shared memory as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockScan::BlockScan() +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + +Collective constructor using the specified memory allocation as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockScan::BlockScan( + TempStorage &temp_storage +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Parameters** + + +Reference to memory allocation having layout type [TempStorage](/library/api/cub::BlockScan::TempStorage) + + + + + +--- + +## Exclusive prefix sum operations + +### ExclusiveSum inline + + + + +Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. The value of 0 is applied as the initial value, and is assigned to `output` in *thread*0. + + +```cpp showLineNumbers={false} +void cub::BlockScan::ExclusiveSum( + T input, + T &output +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Uses the identity element (zero) as the initial value. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + +**Example** + +The code snippet below illustrates an exclusive prefix sum of 128 integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain input item for each thread + int thread_data; + ... + + // Collectively compute the block-wide exclusive prefix sum + BlockScan(temp_storage).ExclusiveSum(thread_data, thread_data); +} +``` + + + + +Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. The value of 0 is applied as the initial value, and is assigned to `output` in *thread*0. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +void cub::BlockScan::ExclusiveSum( + T input, + T &output, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Uses the identity element (zero) as the initial value. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +block-wide aggregate reduction of input items + + +**Example** + +The code snippet below illustrates an exclusive prefix sum of 128 integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain input item for each thread + int thread_data; + ... + + // Collectively compute the block-wide exclusive prefix sum + int block_aggregate; + BlockScan(temp_storage).ExclusiveSum(thread_data, thread_data, block_aggregate); +} +``` + + + + +Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. Instead of using 0 as the block-wide prefix, the call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveSum( + T input, + T &output, + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Uses the identity element (zero) as the initial value. +- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence + + + + + +Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. The value of 0 is applied as the initial value, and is assigned to `output[0]` in *thread*0. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveSum( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Uses the identity element (zero) as the initial value. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + +**Example** + +The code snippet below illustrates an exclusive prefix sum of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Collectively compute the block-wide exclusive prefix sum + BlockScan(temp_storage).ExclusiveSum(thread_data, thread_data); +} +``` + + + + +Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. The value of 0 is applied as the initial value, and is assigned to `output[0]` in *thread*0. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveSum( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Uses the identity element (zero) as the initial value. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +block-wide aggregate reduction of input items + + + + + +Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. Instead of using 0 as the block-wide prefix, the call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveSum( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Uses the identity element (zero) as the initial value. +- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence + + + + + +--- + +## Exclusive prefix scan operations + +### ExclusiveScan inline + + + + +Computes an exclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes one input element. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveScan( + T input, + T &output, + T initial_value, + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Supports non-commutative scan operators. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +Initial value to seed the exclusive scan (and is assigned to `output` in *thread*0) + + + +Binary scan functor + + +**Example** + +The code snippet below illustrates an exclusive prefix max scan of 128 integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain input item for each thread + int thread_data; + ... + + // Collectively compute the block-wide exclusive prefix max scan + BlockScan(temp_storage).ExclusiveScan(thread_data, thread_data, INT_MIN, cuda::maximum<>{}); +} +``` + + + + +Computes an exclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes one input element. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveScan( + T input, + T &output, + T initial_value, + ScanOp scan_op, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Supports non-commutative scan operators. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +`initial_value` is not applied to the block-wide aggregate. + + +**Template parameters** + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Initial value to seed the exclusive scan (and is assigned to `output` in *thread*0) + + + +Binary scan functor + + + +block-wide aggregate reduction of input items + + + + + +Computes an exclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes one input element. The call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveScan( + T input, + T &output, + ScanOp scan_op, + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +- Supports non-commutative scan operators. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +Binary scan functor + + + +*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence + + + + + +Computes an exclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + T initial_value, + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Supports non-commutative scan operators. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Initial value to seed the exclusive scan (and is assigned to `output[0]` in *thread*0) + + + +Binary scan functor + + + + + +Computes an exclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + T initial_value, + ScanOp scan_op, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Supports non-commutative scan operators. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +`initial_value` is not applied to the block-wide aggregate. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Initial value to seed the exclusive scan (and is assigned to `output[0]` in *thread*0) + + + +Binary scan functor + + + +block-wide aggregate reduction of input items + + + + + +Computes an exclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. The call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + ScanOp scan_op, + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +- Supports non-commutative scan operators. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Binary scan functor + + + +*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence + + + + + +--- + +## Inclusive prefix sum operations + +### InclusiveSum inline + + + + +Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. + + +```cpp showLineNumbers={false} +void cub::BlockScan::InclusiveSum( + T input, + T &output +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + +**Example** + +The code snippet below illustrates an inclusive prefix sum of 128 integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain input item for each thread + int thread_data; + ... + + // Collectively compute the block-wide inclusive prefix sum + BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); +} +``` + + + + +Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +void cub::BlockScan::InclusiveSum( + T input, + T &output, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +block-wide aggregate reduction of input items + + + + + +Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. The call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveSum( + T input, + T &output, + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence + + + + + +Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveSum( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + + + +Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveSum( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +block-wide aggregate reduction of input items + + + + + +Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. Instead of using 0 as the block-wide prefix, the call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveSum( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence + + + + + +--- + +## Inclusive prefix scan operations + +### InclusiveScan inline + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes one input element. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T input, + T &output, + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Supports non-commutative scan operators. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +Binary scan functor + + +**Example** + +The code snippet below illustrates an inclusive prefix max scan of 128 integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain input item for each thread + int thread_data; + ... + + // Collectively compute the block-wide inclusive prefix max scan + BlockScan(temp_storage).InclusiveScan(thread_data, thread_data, cuda::maximum<>{}); +} +``` + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes one input element. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T input, + T &output, + ScanOp scan_op, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Supports non-commutative scan operators. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +Binary scan functor + + + +block-wide aggregate reduction of input items + + +**Example** + +The code snippet below illustrates an inclusive prefix max scan of 128 integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain input item for each thread + int thread_data; + ... + + // Collectively compute the block-wide inclusive prefix max scan + int block_aggregate; + BlockScan(temp_storage).InclusiveScan(thread_data, thread_data, cuda::maximum<>{}, block_aggregate); +} +``` + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes one input element. The call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T input, + T &output, + ScanOp scan_op, + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +- Supports non-commutative scan operators. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +Binary scan functor + + + +*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence + + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Supports non-commutative scan operators. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Binary scan functor + + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + T initial_value, + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Supports non-commutative scan operators. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Initial value to seed the inclusive scan + + + +Binary scan functor + + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + ScanOp scan_op, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Supports non-commutative scan operators. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Binary scan functor + + + +block-wide aggregate reduction of input items + + +**Example** + +The code snippet below illustrates an inclusive prefix max scan of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Collectively compute the block-wide inclusive prefix max scan + int block_aggregate; + BlockScan(temp_storage).InclusiveScan(thread_data, thread_data, cuda::maximum<>{}, block_aggregate); +} +``` + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + T initial_value, + ScanOp scan_op, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- Supports non-commutative scan operators. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +`initial_value` is not applied to the block-wide aggregate. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Initial value to seed the inclusive scan + + + +Binary scan functor + + + +block-wide aggregate reduction of input items + + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. The call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + ScanOp scan_op, + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +- Supports non-commutative scan operators. +- Data is in a blocked arrangement across threads. +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Binary scan functor + + + +*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence + + + + + +--- + +## Utility methods + +### PrivateStorage inline + +Internal storage allocator. + + +```cpp showLineNumbers={false} +_TempStorage& cub::BlockScan::PrivateStorage() +``` + + +**Returns:** Reference to [_TempStorage](/library/api/cub::BlockScan::_TempStorage) + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `InternalBlockScan` | `::cuda::std::_If< SAFE_ALGORITHM==`[`BLOCK_SCAN_WARP_SCANS`](/library/api/cub::BLOCK_SCAN_WARP_SCANS)`, `[`WarpScans`](/library/api/cub::BlockScan::WarpScans)`, `[`Raking`](/library/api/cub::BlockScan::Raking)` >` | Define the delegate type for the desired algorithm. | +| `_TempStorage` | `typename InternalBlockScan::TempStorage` | Shared memory storage layout type for [BlockScan](/library/api/cub::BlockScan). | +| `WarpScans` | `detail::BlockScanWarpScans< T, BlockDimX, BlockDimY, BlockDimZ >` | | +| `Raking` | `detail::BlockScanRaking< T, BlockDimX, BlockDimY, BlockDimZ, (SAFE_ALGORITHM==`[`BLOCK_SCAN_RAKING_MEMOIZE`](/library/api/cub::BLOCK_SCAN_RAKING_MEMOIZE)`) >` | | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | The thread block size in threads. | +| `SAFE_ALGORITHM` static constexpr | [`BlockScanAlgorithm`](/library/api/cub::BlockScanAlgorithm) | Ensure the template parameterization meets the requirements of the specified algorithm. Currently, the BLOCK_SCAN_WARP_SCANS policy cannot be used with thread block sizes not a multiple of the architectural warp size. | +| `temp_storage` | [`_TempStorage`](/library/api/cub::BlockScan::_TempStorage) `&` | Shared storage reference. | +| `linear_tid` | `unsigned int` | Linear thread-id. | + +--- + +## Inner classes + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockScan::TempStorage +``` + + +The operations exposed by [BlockScan](/library/api/cub::BlockScan) require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/pages/cub/simple_struct.mdx b/fern/pages/cub/simple_struct.mdx new file mode 100644 index 0000000..a956942 --- /dev/null +++ b/fern/pages/cub/simple_struct.mdx @@ -0,0 +1,48 @@ +--- +title: cub::ArgMax +description: "Arg max functor that keeps the value and offset of the first occurrence of the larger item." +--- + +Arg max functor (keeps the value and offset of the first occurrence of the larger item). + +`ArgMax` is a binary functor that operates on [`KeyValuePair`](/library/api/cub::KeyValuePair) instances, returning the pair with the larger value. In case of ties, the pair with the smaller offset is preferred. + +--- + +## Methods + +### operator() inline const + +Boolean max operator, preferring the item having the smaller offset in case of ties. + + +```cpp showLineNumbers={false} +template +KeyValuePair cub::ArgMax::operator()( + const KeyValuePair &a, + const KeyValuePair &b +) const +``` + + +**Template parameters** + + +**[inferred]** Value type + + + +**[inferred]** Offset type + + +**Parameters** + + +First input key-value pair + + + +Second input key-value pair + + +**Returns:** The [`KeyValuePair`](/library/api/cub::KeyValuePair) with the larger value (ties broken by smaller offset) diff --git a/fern/pages/cub/simple_struct_v4.mdx b/fern/pages/cub/simple_struct_v4.mdx new file mode 100644 index 0000000..a956942 --- /dev/null +++ b/fern/pages/cub/simple_struct_v4.mdx @@ -0,0 +1,48 @@ +--- +title: cub::ArgMax +description: "Arg max functor that keeps the value and offset of the first occurrence of the larger item." +--- + +Arg max functor (keeps the value and offset of the first occurrence of the larger item). + +`ArgMax` is a binary functor that operates on [`KeyValuePair`](/library/api/cub::KeyValuePair) instances, returning the pair with the larger value. In case of ties, the pair with the smaller offset is preferred. + +--- + +## Methods + +### operator() inline const + +Boolean max operator, preferring the item having the smaller offset in case of ties. + + +```cpp showLineNumbers={false} +template +KeyValuePair cub::ArgMax::operator()( + const KeyValuePair &a, + const KeyValuePair &b +) const +``` + + +**Template parameters** + + +**[inferred]** Value type + + + +**[inferred]** Offset type + + +**Parameters** + + +First input key-value pair + + + +Second input key-value pair + + +**Returns:** The [`KeyValuePair`](/library/api/cub::KeyValuePair) with the larger value (ties broken by smaller offset) diff --git a/fern/pages/cub/warp_reduce.mdx b/fern/pages/cub/warp_reduce.mdx new file mode 100644 index 0000000..0ccffc2 --- /dev/null +++ b/fern/pages/cub/warp_reduce.mdx @@ -0,0 +1,837 @@ +--- +title: cub::WarpReduce +description: "Collective methods for computing parallel reductions across a CUDA thread warp." +--- + +The `WarpReduce` class provides collective methods for computing a parallel reduction of items partitioned across a CUDA thread warp. + +A [reduction](http://en.wikipedia.org/wiki/Reduce_(higher-order_function)) (or *fold*) uses a binary combining operator to compute a single aggregate from a list of input elements. + +- Supports "logical" warps smaller than the physical warp size (e.g., logical warps of 8 threads) +- The number of entrant threads must be a multiple of `LogicalWarpThreads` + +## Performance considerations + +- Uses special instructions when applicable (e.g., warp `SHFL` instructions) +- Uses synchronization-free communication between warp lanes when applicable +- Incurs zero bank conflicts for most types +- Computation is slightly more efficient (i.e., having lower instruction overhead) for: + - Summation (vs. generic reduction) + - The architecture's warp size is a whole multiple of `LogicalWarpThreads` + +## Example + +The code snippet below illustrates four concurrent warp sum reductions within a block of 128 threads (one per each of the 32-thread warps). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + // Allocate WarpReduce shared memory for 4 warps + __shared__ typename WarpReduce::TempStorage temp_storage[4]; + // Obtain one input item per thread + int thread_data = ... + // Return the warp-wide sums to each lane0 (threads 0, 32, 64, and 96) + int warp_id = threadIdx.x / 32; + int aggregate = WarpReduce(temp_storage[warp_id]).Sum(thread_data); +} +``` + +Suppose the set of input `thread_data` across the block of threads is `{0, 1, 2, 3, ..., 127}`. The corresponding output `aggregate` in threads 0, 32, 64, and 96 will be `496`, `1520`, `2544`, and `3568`, respectively (and is undefined in other threads). + + + + + +Data type being reduced + + + +**[optional]** The number of threads per "logical" warp (may be less than the number of hardware warp threads). Default is the warp size of the targeted CUDA compute architecture (e.g., 32 threads for SM3x). + + + + + +--- + +## Collective constructors + +### WarpReduce inline + +Collective constructor using the specified memory allocation as temporary storage. Logical warp and lane identifiers are constructed from `threadIdx.x`. + + +```cpp showLineNumbers={false} +cub::WarpReduce::WarpReduce( + TempStorage &temp_storage +) +``` + + +**Parameters** + + +Reference to memory allocation having layout type [TempStorage](/library/api/cub::WarpReduce::TempStorage) + + +--- + +## Summation reductions + +### Sum inline + + + + +Computes a warp-wide sum in the calling warp. The output is valid in warp *lane*0. + + +```cpp showLineNumbers={false} +T cub::WarpReduce::Sum( + T input +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input + + +**Example** + +The code snippet below illustrates four concurrent warp sum reductions within a block of 128 threads (one per each of the 32-thread warps). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + // Allocate WarpReduce shared memory for 4 warps + __shared__ typename WarpReduce::TempStorage temp_storage[4]; + // Obtain one input item per thread + int thread_data = ... + // Return the warp-wide sums to each lane0 + int warp_id = threadIdx.x / 32; + int aggregate = WarpReduce(temp_storage[warp_id]).Sum(thread_data); +} +``` + + + + +Computes a warp-wide sum in the calling warp. Each thread contributes a fixed-size array of consecutive input elements. The output is valid in warp *lane*0. + + +```cpp showLineNumbers={false} +template , int> = 0> +T cub::WarpReduce::Sum( + const InputType &input +) +``` + + +**Template parameters** + + +**[inferred]** Input type, must be a fixed-size random access range + + +**Parameters** + + +Calling thread's input + + + + + +Computes a partially-full warp-wide sum in the calling warp. The output is valid in warp *lane*0. + +All threads across the calling warp must agree on the same value for `valid_items`. Otherwise the result is undefined. + + +```cpp showLineNumbers={false} +T cub::WarpReduce::Sum( + T input, + int valid_items +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input + + + +Total number of valid items in the calling thread's logical warp (may be less than `LogicalWarpThreads`) + + +**Example** + +The code snippet below illustrates a sum reduction within a single, partially-full block of 32 threads (one warp). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(int *d_data, int valid_items) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + + // Obtain one input item per thread if in range + int thread_data; + if (threadIdx.x < valid_items) + thread_data = d_data[threadIdx.x]; + + // Return the warp-wide sums to each lane0 + int aggregate = WarpReduce(temp_storage).Sum(thread_data, valid_items); +} +``` + + + + +--- + +## Max reductions + +### Max inline + + + + +Computes a warp-wide maximum in the calling warp. The output is valid in warp *lane*0. + + +```cpp showLineNumbers={false} +T cub::WarpReduce::Max( + T input +) +``` + + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input + + + + + +Computes a warp-wide maximum in the calling warp. Each thread contributes a fixed-size array of consecutive input elements. The output is valid in warp *lane*0. + + +```cpp showLineNumbers={false} +template , int> = 0> +T cub::WarpReduce::Max( + const InputType &input +) +``` + + +**Template parameters** + + +**[inferred]** Input type, must be a fixed-size random access range + + +**Parameters** + + +Calling thread's input + + + + + +Computes a partially-full warp-wide maximum in the calling warp. The output is valid in warp *lane*0. + +All threads across the calling warp must agree on the same value for `valid_items`. Otherwise the result is undefined. + + +```cpp showLineNumbers={false} +T cub::WarpReduce::Max( + T input, + int valid_items +) +``` + + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input + + + +Total number of valid items in the calling thread's logical warp (may be less than `LogicalWarpThreads`) + + + + + +--- + +## Min reductions + +### Min inline + + + + +Computes a warp-wide minimum in the calling warp. The output is valid in warp *lane*0. + + +```cpp showLineNumbers={false} +T cub::WarpReduce::Min( + T input +) +``` + + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input + + + + + +Computes a warp-wide minimum in the calling warp. Each thread contributes a fixed-size array of consecutive input elements. The output is valid in warp *lane*0. + + +```cpp showLineNumbers={false} +template , int> = 0> +T cub::WarpReduce::Min( + const InputType &input +) +``` + + +**Template parameters** + + +**[inferred]** Input type, must be a fixed-size random access range + + +**Parameters** + + +Calling thread's input + + + + + +Computes a partially-full warp-wide minimum in the calling warp. The output is valid in warp *lane*0. + +All threads across the calling warp must agree on the same value for `valid_items`. Otherwise the result is undefined. + + +```cpp showLineNumbers={false} +T cub::WarpReduce::Min( + T input, + int valid_items +) +``` + + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input + + + +Total number of valid items in the calling thread's logical warp (may be less than `LogicalWarpThreads`) + + + + + +--- + +## Generic reductions + +### Reduce inline + + + + +Computes a warp-wide reduction in the calling warp using the specified binary reduction functor. The output is valid in warp *lane*0. + +Supports non-commutative reduction operators. + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::Reduce( + T input, + ReductionOp reduction_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Binary reduction operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input + + + +Binary reduction operator + + +**Example** + +The code snippet below illustrates four concurrent warp max reductions within a block of 128 threads (one per each of the 32-thread warps). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for 4 warps + __shared__ typename WarpReduce::TempStorage temp_storage[4]; + + // Obtain one input item per thread + int thread_data = ... + + // Return the warp-wide reductions to each lane0 + int warp_id = threadIdx.x / 32; + int aggregate = WarpReduce(temp_storage[warp_id]).Reduce( + thread_data, cuda::maximum<>{}); +} +``` + + + + +Computes a warp-wide reduction in the calling warp using the specified binary reduction functor. Each thread contributes a fixed-size array of consecutive input elements. The output is valid in warp *lane*0. + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::Reduce( + const InputType &input, + ReductionOp reduction_op +) +``` + + +**Template parameters** + + +**[inferred]** Input type, must be a fixed-size random access range + + + +**[inferred]** Binary reduction operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input + + + +Binary reduction operator + + + + + +Computes a partially-full warp-wide reduction in the calling warp using the specified binary reduction functor. The output is valid in warp *lane*0. + +All threads across the calling warp must agree on the same value for `valid_items`. Otherwise the result is undefined. + +Supports non-commutative reduction operators. + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::Reduce( + T input, + ReductionOp reduction_op, + int valid_items +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Binary reduction operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input + + + +Binary reduction operator + + + +Total number of valid items in the calling thread's logical warp (may be less than `LogicalWarpThreads`) + + +**Example** + +The code snippet below illustrates a max reduction within a single, partially-full block of 32 threads (one warp). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(int *d_data, int valid_items) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + + // Obtain one input item per thread if in range + int thread_data; + if (threadIdx.x < valid_items) + thread_data = d_data[threadIdx.x]; + + // Return the warp-wide reductions to each lane0 + int aggregate = WarpReduce(temp_storage).Reduce( + thread_data, cuda::maximum<>{}, valid_items); +} +``` + + + + +--- + +## Segmented reductions + +### HeadSegmentedSum inline + +Computes a segmented sum in the calling warp where segments are defined by head-flags. The sum of each segment is returned to the first lane in that segment (which always includes *lane*0). + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::HeadSegmentedSum( + T input, + FlagT head_flag +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input + + + +Head flag denoting whether or not `input` is the start of a new segment + + +**Example** + +The code snippet below illustrates a head-segmented warp sum reduction within a block of 32 threads (one warp). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + + // Obtain one input item and flag per thread + int thread_data = ... + int head_flag = ... + + // Return the warp-wide sums to each lane0 + int aggregate = WarpReduce(temp_storage).HeadSegmentedSum( + thread_data, head_flag); +} +``` + +--- + +### TailSegmentedSum inline + +Computes a segmented sum in the calling warp where segments are defined by tail-flags. The sum of each segment is returned to the first lane in that segment (which always includes *lane*0). + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::TailSegmentedSum( + T input, + FlagT tail_flag +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input + + + +Tail flag denoting whether or not `input` is the end of the current segment + + +**Example** + +The code snippet below illustrates a tail-segmented warp sum reduction within a block of 32 threads (one warp). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + + // Obtain one input item and flag per thread + int thread_data = ... + int tail_flag = ... + + // Return the warp-wide sums to each lane0 + int aggregate = WarpReduce(temp_storage).TailSegmentedSum( + thread_data, tail_flag); +} +``` + +--- + +### HeadSegmentedReduce inline + +Computes a segmented reduction in the calling warp where segments are defined by head-flags. The reduction of each segment is returned to the first lane in that segment (which always includes *lane*0). + +Supports non-commutative reduction operators. + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::HeadSegmentedReduce( + T input, + FlagT head_flag, + ReductionOp reduction_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Binary reduction operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input + + + +Head flag denoting whether or not `input` is the start of a new segment + + + +Binary reduction operator + + +**Example** + +The code snippet below illustrates a head-segmented warp max reduction within a block of 32 threads (one warp). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + + // Obtain one input item and flag per thread + int thread_data = ... + int head_flag = ... + + // Return the warp-wide reductions to each lane0 + int aggregate = WarpReduce(temp_storage).HeadSegmentedReduce( + thread_data, head_flag, cuda::maximum<>{}); +} +``` + +--- + +### TailSegmentedReduce inline + +Computes a segmented reduction in the calling warp where segments are defined by tail-flags. The reduction of each segment is returned to the first lane in that segment (which always includes *lane*0). + +Supports non-commutative reduction operators. + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::TailSegmentedReduce( + T input, + FlagT tail_flag, + ReductionOp reduction_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Binary reduction operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input + + + +Tail flag denoting whether or not `input` is the end of the current segment + + + +Binary reduction operator + + +**Example** + +The code snippet below illustrates a tail-segmented warp max reduction within a block of 32 threads (one warp). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + + // Obtain one input item and flag per thread + int thread_data = ... + int tail_flag = ... + + // Return the warp-wide reductions to each lane0 + int aggregate = WarpReduce(temp_storage).TailSegmentedReduce( + thread_data, tail_flag, cuda::maximum<>{}); +} +``` + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `_TempStorage` | `typename InternalWarpReduce::TempStorage` | Shared memory storage layout type for [WarpReduce](/library/api/cub::WarpReduce). | + +--- + +## Inner classes + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::WarpReduce::TempStorage +``` + + +The operations exposed by [WarpReduce](/library/api/cub::WarpReduce) require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/pages/cub/warp_reduce_v4.mdx b/fern/pages/cub/warp_reduce_v4.mdx new file mode 100644 index 0000000..2676b1d --- /dev/null +++ b/fern/pages/cub/warp_reduce_v4.mdx @@ -0,0 +1,847 @@ +--- +title: cub::WarpReduce +description: "Collective methods for computing parallel reductions across a CUDA thread warp." +--- + +The `WarpReduce` class provides collective methods for computing a parallel reduction of items partitioned across a CUDA thread warp. + +A [reduction](http://en.wikipedia.org/wiki/Reduce_(higher-order_function)) (or *fold*) uses a binary combining operator to compute a single aggregate from a list of input elements. + +- Supports "logical" warps smaller than the physical warp size (e.g., logical warps of 8 threads) +- The number of entrant threads must be a multiple of `LogicalWarpThreads` + +## Performance considerations + +- Uses special instructions when applicable (e.g., warp `SHFL` instructions) +- Uses synchronization-free communication between warp lanes when applicable +- Incurs zero bank conflicts for most types +- Computation is slightly more efficient (i.e., having lower instruction overhead) for: + - Summation (vs. generic reduction) + - The architecture's warp size is a whole multiple of `LogicalWarpThreads` + +## Example + +The code snippet below illustrates four concurrent warp sum reductions within a block of 128 threads (one per each of the 32-thread warps). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + // Allocate WarpReduce shared memory for 4 warps + __shared__ typename WarpReduce::TempStorage temp_storage[4]; + // Obtain one input item per thread + int thread_data = ... + // Return the warp-wide sums to each lane0 (threads 0, 32, 64, and 96) + int warp_id = threadIdx.x / 32; + int aggregate = WarpReduce(temp_storage[warp_id]).Sum(thread_data); +} +``` + +Suppose the set of input `thread_data` across the block of threads is `{0, 1, 2, 3, ..., 127}`. The corresponding output `aggregate` in threads 0, 32, 64, and 96 will be `496`, `1520`, `2544`, and `3568`, respectively (and is undefined in other threads). + + + + + +Data type being reduced + + + +**[optional]** The number of threads per "logical" warp (may be less than the number of hardware warp threads). Default is the warp size of the targeted CUDA compute architecture (e.g., 32 threads for SM3x). + + + + + +--- + +## Collective constructors + +### WarpReduce inline + +Collective constructor using the specified memory allocation as temporary storage. Logical warp and lane identifiers are constructed from `threadIdx.x`. + + +```cpp showLineNumbers={false} +cub::WarpReduce::WarpReduce( + TempStorage &temp_storage +) +``` + + +**Parameters** + + +Reference to memory allocation having layout type [TempStorage](/library/api/cub::WarpReduce::TempStorage) + + +--- + +## Generic reductions + +### Reduce inline + + + + +Computes a warp-wide reduction in the calling warp using the specified binary reduction functor. The output is valid in warp *lane*0. + +Supports non-commutative reduction operators. + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::Reduce( + T input, + ReductionOp reduction_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Binary reduction operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input + + + +Binary reduction operator + + +**Example** + +The code snippet below illustrates four concurrent warp max reductions within a block of 128 threads (one per each of the 32-thread warps). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for 4 warps + __shared__ typename WarpReduce::TempStorage temp_storage[4]; + + // Obtain one input item per thread + int thread_data = ... + + // Return the warp-wide reductions to each lane0 + int warp_id = threadIdx.x / 32; + int aggregate = WarpReduce(temp_storage[warp_id]).Reduce( + thread_data, cuda::maximum<>{}); +} +``` + + + + +Computes a warp-wide reduction in the calling warp using the specified binary reduction functor. Each thread contributes a fixed-size array of consecutive input elements. The output is valid in warp *lane*0. + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::Reduce( + const InputType &input, + ReductionOp reduction_op +) +``` + + +**Template parameters** + + +**[inferred]** Input type, must be a fixed-size random access range + + + +**[inferred]** Binary reduction operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input + + + +Binary reduction operator + + + + + +Computes a partially-full warp-wide reduction in the calling warp using the specified binary reduction functor. The output is valid in warp *lane*0. + +All threads across the calling warp must agree on the same value for `valid_items`. Otherwise the result is undefined. + +Supports non-commutative reduction operators. + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::Reduce( + T input, + ReductionOp reduction_op, + int valid_items +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Binary reduction operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input + + + +Binary reduction operator + + + +Total number of valid items in the calling thread's logical warp (may be less than `LogicalWarpThreads`) + + +**Example** + +The code snippet below illustrates a max reduction within a single, partially-full block of 32 threads (one warp). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(int *d_data, int valid_items) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + + // Obtain one input item per thread if in range + int thread_data; + if (threadIdx.x < valid_items) + thread_data = d_data[threadIdx.x]; + + // Return the warp-wide reductions to each lane0 + int aggregate = WarpReduce(temp_storage).Reduce( + thread_data, cuda::maximum<>{}, valid_items); +} +``` + + + + +--- + +## Summation reductions + +### Sum inline + + + + +Computes a warp-wide sum in the calling warp. The output is valid in warp *lane*0. + + +```cpp showLineNumbers={false} +T cub::WarpReduce::Sum( + T input +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input + + +**Example** + +The code snippet below illustrates four concurrent warp sum reductions within a block of 128 threads (one per each of the 32-thread warps). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + // Allocate WarpReduce shared memory for 4 warps + __shared__ typename WarpReduce::TempStorage temp_storage[4]; + // Obtain one input item per thread + int thread_data = ... + // Return the warp-wide sums to each lane0 + int warp_id = threadIdx.x / 32; + int aggregate = WarpReduce(temp_storage[warp_id]).Sum(thread_data); +} +``` + + + + +Computes a warp-wide sum in the calling warp. Each thread contributes a fixed-size array of consecutive input elements. The output is valid in warp *lane*0. + + +```cpp showLineNumbers={false} +template , int> = 0> +T cub::WarpReduce::Sum( + const InputType &input +) +``` + + +**Template parameters** + + +**[inferred]** Input type, must be a fixed-size random access range + + +**Parameters** + + +Calling thread's input + + + + + +Computes a partially-full warp-wide sum in the calling warp. The output is valid in warp *lane*0. + +All threads across the calling warp must agree on the same value for `valid_items`. Otherwise the result is undefined. + + +```cpp showLineNumbers={false} +T cub::WarpReduce::Sum( + T input, + int valid_items +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input + + + +Total number of valid items in the calling thread's logical warp (may be less than `LogicalWarpThreads`) + + +**Example** + +The code snippet below illustrates a sum reduction within a single, partially-full block of 32 threads (one warp). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(int *d_data, int valid_items) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + + // Obtain one input item per thread if in range + int thread_data; + if (threadIdx.x < valid_items) + thread_data = d_data[threadIdx.x]; + + // Return the warp-wide sums to each lane0 + int aggregate = WarpReduce(temp_storage).Sum(thread_data, valid_items); +} +``` + + + + +--- + +## Max reductions + +### Max inline + + + + +Computes a warp-wide maximum in the calling warp. The output is valid in warp *lane*0. + + +```cpp showLineNumbers={false} +T cub::WarpReduce::Max( + T input +) +``` + + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input + + + + + +Computes a warp-wide maximum in the calling warp. Each thread contributes a fixed-size array of consecutive input elements. The output is valid in warp *lane*0. + + +```cpp showLineNumbers={false} +template , int> = 0> +T cub::WarpReduce::Max( + const InputType &input +) +``` + + +**Template parameters** + + +**[inferred]** Input type, must be a fixed-size random access range + + +**Parameters** + + +Calling thread's input + + + + + +Computes a partially-full warp-wide maximum in the calling warp. The output is valid in warp *lane*0. + +All threads across the calling warp must agree on the same value for `valid_items`. Otherwise the result is undefined. + + +```cpp showLineNumbers={false} +T cub::WarpReduce::Max( + T input, + int valid_items +) +``` + + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input + + + +Total number of valid items in the calling thread's logical warp (may be less than `LogicalWarpThreads`) + + + + + +--- + +## Min reductions + +### Min inline + + + + +Computes a warp-wide minimum in the calling warp. The output is valid in warp *lane*0. + + +```cpp showLineNumbers={false} +T cub::WarpReduce::Min( + T input +) +``` + + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input + + + + + +Computes a warp-wide minimum in the calling warp. Each thread contributes a fixed-size array of consecutive input elements. The output is valid in warp *lane*0. + + +```cpp showLineNumbers={false} +template , int> = 0> +T cub::WarpReduce::Min( + const InputType &input +) +``` + + +**Template parameters** + + +**[inferred]** Input type, must be a fixed-size random access range + + +**Parameters** + + +Calling thread's input + + + + + +Computes a partially-full warp-wide minimum in the calling warp. The output is valid in warp *lane*0. + +All threads across the calling warp must agree on the same value for `valid_items`. Otherwise the result is undefined. + + +```cpp showLineNumbers={false} +T cub::WarpReduce::Min( + T input, + int valid_items +) +``` + + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input + + + +Total number of valid items in the calling thread's logical warp (may be less than `LogicalWarpThreads`) + + + + + +--- + +## Segmented reductions + +### HeadSegmentedSum inline + +Computes a segmented sum in the calling warp where segments are defined by head-flags. The sum of each segment is returned to the first lane in that segment (which always includes *lane*0). + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::HeadSegmentedSum( + T input, + FlagT head_flag +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input + + + +Head flag denoting whether or not `input` is the start of a new segment + + +**Example** + +The code snippet below illustrates a head-segmented warp sum reduction within a block of 32 threads (one warp). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + + // Obtain one input item and flag per thread + int thread_data = ... + int head_flag = ... + + // Return the warp-wide sums to each lane0 + int aggregate = WarpReduce(temp_storage).HeadSegmentedSum( + thread_data, head_flag); +} +``` + +--- + +### TailSegmentedSum inline + +Computes a segmented sum in the calling warp where segments are defined by tail-flags. The sum of each segment is returned to the first lane in that segment (which always includes *lane*0). + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::TailSegmentedSum( + T input, + FlagT tail_flag +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Parameters** + + +Calling thread's input + + + +Tail flag denoting whether or not `input` is the end of the current segment + + +**Example** + +The code snippet below illustrates a tail-segmented warp sum reduction within a block of 32 threads (one warp). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + + // Obtain one input item and flag per thread + int thread_data = ... + int tail_flag = ... + + // Return the warp-wide sums to each lane0 + int aggregate = WarpReduce(temp_storage).TailSegmentedSum( + thread_data, tail_flag); +} +``` + +--- + +### HeadSegmentedReduce inline + +Computes a segmented reduction in the calling warp where segments are defined by head-flags. The reduction of each segment is returned to the first lane in that segment (which always includes *lane*0). + +Supports non-commutative reduction operators. + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::HeadSegmentedReduce( + T input, + FlagT head_flag, + ReductionOp reduction_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Binary reduction operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input + + + +Head flag denoting whether or not `input` is the start of a new segment + + + +Binary reduction operator + + +**Example** + +The code snippet below illustrates a head-segmented warp max reduction within a block of 32 threads (one warp). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + + // Obtain one input item and flag per thread + int thread_data = ... + int head_flag = ... + + // Return the warp-wide reductions to each lane0 + int aggregate = WarpReduce(temp_storage).HeadSegmentedReduce( + thread_data, head_flag, cuda::maximum<>{}); +} +``` + +--- + +### TailSegmentedReduce inline + +Computes a segmented reduction in the calling warp where segments are defined by tail-flags. The reduction of each segment is returned to the first lane in that segment (which always includes *lane*0). + +Supports non-commutative reduction operators. + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::TailSegmentedReduce( + T input, + FlagT tail_flag, + ReductionOp reduction_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +**Template parameters** + + +**[inferred]** Binary reduction operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input + + + +Tail flag denoting whether or not `input` is the end of the current segment + + + +Binary reduction operator + + +**Example** + +The code snippet below illustrates a tail-segmented warp max reduction within a block of 32 threads (one warp). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + + // Obtain one input item and flag per thread + int thread_data = ... + int tail_flag = ... + + // Return the warp-wide reductions to each lane0 + int aggregate = WarpReduce(temp_storage).TailSegmentedReduce( + thread_data, tail_flag, cuda::maximum<>{}); +} +``` + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `_TempStorage` | `typename InternalWarpReduce::TempStorage` | Shared memory storage layout type for [WarpReduce](/library/api/cub::WarpReduce). | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `is_full_warp` static constexpr | `bool` | | +| `is_power_of_two` static constexpr | `bool` | | +| `temp_storage` | [`_TempStorage`](/library/api/cub::WarpReduce::_TempStorage) `&` | Shared storage reference. | + +--- + +## Inner classes + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::WarpReduce::TempStorage +``` + + +The operations exposed by [WarpReduce](/library/api/cub::WarpReduce) require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/pages/libcudacxx/concept_example.mdx b/fern/pages/libcudacxx/concept_example.mdx new file mode 100644 index 0000000..363fb12 --- /dev/null +++ b/fern/pages/libcudacxx/concept_example.mdx @@ -0,0 +1,43 @@ +--- +title: "cuda::mr::resource_with" +--- + +# resource_with + +The `resource_with` concept verifies that a type Resource satisfies the [`resource`](/library/api/cuda::mr::resource) concept and also satisfies all the provided Properties. + + + + +The resource type to check against the [`resource`](/library/api/cuda::mr::resource) concept. + + + +A variadic pack of property types that the resource must additionally satisfy. + + + + +--- + +## Description + +`resource_with` is a compound concept that combines two requirements: + +1. The type `_Resource` must satisfy [`cuda::mr::resource`](/library/api/cuda::mr::resource), meaning it supports both synchronous and stream-ordered allocation interfaces. +2. The type `_Resource` must satisfy [`cuda::has_property`](/library/api/cuda::has_property) for each property type in `_Properties`. + +This concept is useful when writing generic code that requires a memory resource with specific properties, such as device accessibility or a particular allocation strategy. + +--- + +## Related concepts + +| Concept | Description | +|---|---| +| [`cuda::mr::resource`](/library/api/cuda::mr::resource) | Verifies that a type satisfies the basic requirements of a memory resource with stream-ordered allocations. | +| [`cuda::mr::synchronous_resource`](/library/api/cuda::mr::synchronous_resource) | Verifies that a type satisfies the basic requirements of a synchronous memory resource. | +| [`cuda::mr::synchronous_resource_with`](/library/api/cuda::mr::synchronous_resource_with) | The synchronous counterpart: verifies a synchronous resource that also satisfies a set of properties. | +| [`cuda::has_property`](/library/api/cuda::has_property) | Verifies that a resource satisfies a given property. | +| [`cuda::has_property_with`](/library/api/cuda::has_property_with) | Verifies that a resource satisfies a given stateful property. | +| [`cuda::property_with_value`](/library/api/cuda::property_with_value) | Verifies that a property is stateful and exposes a `value_type` alias. | diff --git a/fern/pages/libcudacxx/concept_example_v3.mdx b/fern/pages/libcudacxx/concept_example_v3.mdx new file mode 100644 index 0000000..bc6ebb7 --- /dev/null +++ b/fern/pages/libcudacxx/concept_example_v3.mdx @@ -0,0 +1,53 @@ +--- +title: "cuda::mr::resource_with" +description: "A concept that verifies a memory resource satisfies both the resource concept and a set of property requirements." +--- + +C++20 concept + +The `resource_with` concept verifies that a type Resource satisfies the [`resource`](/library/api/cuda::mr::resource) concept and also satisfies all the provided Properties. + + +```cpp showLineNumbers={false} +template +concept resource_with = /* see description */; +``` + + + + + + +The resource type to check against the [`resource`](/library/api/cuda::mr::resource) concept. + + + +A variadic pack of property types that the resource must additionally satisfy. + + + + + +--- + +## Description + +`resource_with` is a compound concept that combines two requirements: + +1. The type `_Resource` must satisfy [`cuda::mr::resource`](/library/api/cuda::mr::resource), meaning it supports both synchronous and stream-ordered allocation interfaces. +2. The type `_Resource` must satisfy [`cuda::has_property`](/library/api/cuda::has_property) for each property type in `_Properties`. + +This concept is useful when writing generic code that requires a memory resource with specific properties, such as device accessibility or a particular allocation strategy. + +--- + +## Related concepts + +| Concept | Description | +|---|---| +| [`cuda::mr::resource`](/library/api/cuda::mr::resource) | Verifies that a type satisfies the basic requirements of a memory resource with stream-ordered allocations. | +| [`cuda::mr::synchronous_resource`](/library/api/cuda::mr::synchronous_resource) | Verifies that a type satisfies the basic requirements of a synchronous memory resource. | +| [`cuda::mr::synchronous_resource_with`](/library/api/cuda::mr::synchronous_resource_with) | The synchronous counterpart: verifies a synchronous resource that also satisfies a set of properties. | +| [`cuda::has_property`](/library/api/cuda::has_property) | Verifies that a resource satisfies a given property. | +| [`cuda::has_property_with`](/library/api/cuda::has_property_with) | Verifies that a resource satisfies a given stateful property. | +| [`cuda::property_with_value`](/library/api/cuda::property_with_value) | Verifies that a property is stateful and exposes a `value_type` alias. | \ No newline at end of file diff --git a/fern/pages/libcudacxx/deep_template_class.mdx b/fern/pages/libcudacxx/deep_template_class.mdx new file mode 100644 index 0000000..58c2622 --- /dev/null +++ b/fern/pages/libcudacxx/deep_template_class.mdx @@ -0,0 +1,245 @@ +--- +title: "cuda::counting_iterator" +description: "An iterator that represents a range of sequentially increasing values without storing them in memory." +--- + +A `counting_iterator` represents an iterator into a range of sequentially increasing values. + +This iterator is useful for creating a range filled with a sequence without explicitly storing it in memory. Using `counting_iterator` saves memory capacity and bandwidth. + +```cpp showLineNumbers={false} +#include +``` + + + + + +The value type of the [`counting_iterator`](/libcudacxx/api/cuda::counting_iterator). + + +The remaining template parameters are SFINAE constraints that require `_Start` to model `::cuda::std::weakly_incrementable` and `::cuda::std::copyable`. They are not intended to be specified directly. + + + + +**Inherits from:** `__counting_iterator_category<_Start>` (public) + +--- + +## Example + +The code snippet below demonstrates how to create a `counting_iterator` whose `value_type` is `int`. + +```cpp showLineNumbers={false} +#include +... +// create iterators +cuda::counting_iterator first(10); +cuda::counting_iterator last = first + 3; + +first[0] // returns 10 +first[1] // returns 11 +first[100] // returns 110 + +// sum of [first, last) +std::reduce(first, last); // returns 33 (i.e. 10 + 11 + 12) + +// initialize vector to [0,1,2,..] +cuda::counting_iterator iter(0); +std::vector vec(500); +std::copy(iter, iter + vec.size(), vec.begin()); +``` + +--- + +## Constructors + +### counting_iterator inline constexpr noexcept + + + + +Default-constructs the stored value. + + +```cpp showLineNumbers={false} +cuda::counting_iterator<_Start,,>::counting_iterator() +``` + + + + + +explicit + +Creates a `counting_iterator` from an initial value. + + +```cpp showLineNumbers={false} +cuda::counting_iterator<_Start,,>::counting_iterator( + _Start __value +) +``` + + +**Parameters** + + +The value to store in the [`counting_iterator`](/libcudacxx/api/cuda::counting_iterator). + + + + + +--- + +## Element access + +### operator* inline constexpr const noexcept + +nodiscard + +Returns the value currently stored in the [`counting_iterator`](/libcudacxx/api/cuda::counting_iterator). + + +```cpp showLineNumbers={false} +_Start cuda::counting_iterator<_Start,,>::operator*() const +``` + + +**Returns:** `_Start` + +### operator[] inline constexpr const noexcept + +nodiscard + +Returns the value currently stored in the [`counting_iterator`](/libcudacxx/api/cuda::counting_iterator) advanced by a number of steps. + + +```cpp showLineNumbers={false} +_Start2 cuda::counting_iterator<_Start,,>::operator[]( + difference_type __n +) const +``` + + +**Returns:** `_Start2` + +**Parameters** + + +The amount of elements to advance. + + +--- + +## Increment operators + +### operator++ inline constexpr noexcept + + + + + +```cpp showLineNumbers={false} +counting_iterator& cuda::counting_iterator<_Start,,>::operator++() +``` + + +**Returns:** `counting_iterator &` + + + + + +```cpp showLineNumbers={false} +auto cuda::counting_iterator<_Start,,>::operator++(int) +``` + + +**Returns:** `auto` + + + + +### operator-- inline constexpr noexcept + + + + + +```cpp showLineNumbers={false} +counting_iterator& cuda::counting_iterator<_Start,,>::operator--() +``` + + +**Returns:** `counting_iterator &` + + + + + +```cpp showLineNumbers={false} +counting_iterator cuda::counting_iterator<_Start,,>::operator--(int) +``` + + +**Returns:** `counting_iterator` + + + + +--- + +## Compound assignment operators + +### operator+= inline constexpr noexcept + + +```cpp showLineNumbers={false} +counting_iterator& cuda::counting_iterator<_Start,,>::operator+=( + difference_type __n +) +``` + + +**Returns:** `counting_iterator &` + +**Parameters** + + +The number of positions to advance. + + +### operator-= inline constexpr noexcept + + +```cpp showLineNumbers={false} +counting_iterator& cuda::counting_iterator<_Start,,>::operator-=( + difference_type __n +) +``` + + +**Returns:** `counting_iterator &` + +**Parameters** + + +The number of positions to retreat. + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `iterator_concept` | `::cuda::std::conditional_t<__advanceable<_Start>, ::cuda::std::random_access_iterator_tag, ::cuda::std::conditional_t<__decrementable<_Start>, ::cuda::std::bidirectional_iterator_tag, ::cuda::std::conditional_t<::cuda::std::incrementable<_Start>, ::cuda::std::forward_iterator_tag, ::cuda::std::input_iterator_tag>>>` | +| `value_type` | `_Start` | +| `difference_type` | `_IotaDiffT<_Start>` | +| `reference` | `_Start` | +| `pointer` | `void` | diff --git a/fern/pages/libcudacxx/deep_template_class_v4.mdx b/fern/pages/libcudacxx/deep_template_class_v4.mdx new file mode 100644 index 0000000..9126960 --- /dev/null +++ b/fern/pages/libcudacxx/deep_template_class_v4.mdx @@ -0,0 +1,247 @@ +--- +title: "cuda::counting_iterator" +description: "An iterator that represents a range of sequentially increasing values without storing them in memory." +--- + +A `counting_iterator` represents an iterator into a range of sequentially increasing values. + +This iterator is useful for creating a range filled with a sequence without explicitly storing it in memory. Using `counting_iterator` saves memory capacity and bandwidth. + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `__counting_iterator_category<_Start>` (public) + +## Example + +The code snippet below demonstrates how to create a `counting_iterator` whose `value_type` is `int`. + +```cpp showLineNumbers={false} +#include +... +// create iterators +cuda::counting_iterator first(10); +cuda::counting_iterator last = first + 3; + +first[0] // returns 10 +first[1] // returns 11 +first[100] // returns 110 + +// sum of [first, last) +std::reduce(first, last); // returns 33 (i.e. 10 + 11 + 12) + +// initialize vector to [0,1,2,..] +cuda::counting_iterator iter(0); +std::vector vec(500); +std::copy(iter, iter + vec.size(), vec.begin()); +``` + +The remaining template parameters are SFINAE constraints that require `_Start` to model `::cuda::std::weakly_incrementable` and `::cuda::std::copyable`. They are not intended to be specified directly. + + + + + +The value type of the [`counting_iterator`](/libcudacxx/api/cuda::counting_iterator). + + + + + +--- + +## Constructors + +### counting_iterator inline constexpr noexcept + + + + +Default-constructs the stored value. + + +```cpp showLineNumbers={false} +cuda::counting_iterator<_Start,,>::counting_iterator() +``` + + + + + +explicit + +Creates a `counting_iterator` from an initial value. + + +```cpp showLineNumbers={false} +cuda::counting_iterator<_Start,,>::counting_iterator( + _Start __value +) +``` + + +**Parameters** + + +The value to store in the [`counting_iterator`](/libcudacxx/api/cuda::counting_iterator). + + + + + +--- + +## Element access + +### operator* inline constexpr const noexcept nodiscard + +Returns the value currently stored in the [`counting_iterator`](/libcudacxx/api/cuda::counting_iterator). + + +```cpp showLineNumbers={false} +_Start cuda::counting_iterator<_Start,,>::operator*() const +``` + + +**Returns:** `_Start` + +### operator[] inline constexpr const noexcept nodiscard + +Returns the value currently stored in the [`counting_iterator`](/libcudacxx/api/cuda::counting_iterator) advanced by a number of steps. + + +```cpp showLineNumbers={false} +_Start2 cuda::counting_iterator<_Start,,>::operator[]( + difference_type __n +) const +``` + + +**Returns:** `_Start2` + +**Parameters** + + +The amount of elements to advance. + + +--- + +## Increment operators + +### operator++ inline constexpr noexcept + +Increments the stored value. + + + + + +```cpp showLineNumbers={false} +counting_iterator& cuda::counting_iterator<_Start,,>::operator++() +``` + + +**Returns:** `counting_iterator &` + + + + + +```cpp showLineNumbers={false} +auto cuda::counting_iterator<_Start,,>::operator++(int) +``` + + +**Returns:** `auto` + + + + +### operator-- inline constexpr noexcept + +Decrements the stored value. + + + + + +```cpp showLineNumbers={false} +counting_iterator& cuda::counting_iterator<_Start,,>::operator--() +``` + + +**Returns:** `counting_iterator &` + + + + + +```cpp showLineNumbers={false} +counting_iterator cuda::counting_iterator<_Start,,>::operator--(int) +``` + + +**Returns:** `counting_iterator` + + + + +--- + +## Compound assignment operators + +### operator+= inline constexpr noexcept + +Increments the stored value by a given number of elements. + + +```cpp showLineNumbers={false} +counting_iterator& cuda::counting_iterator<_Start,,>::operator+=( + difference_type __n +) +``` + + +**Returns:** `counting_iterator &` + +**Parameters** + + +The number of positions to advance. + + +### operator-= inline constexpr noexcept + +Decrements the stored value by a given number of elements. + + +```cpp showLineNumbers={false} +counting_iterator& cuda::counting_iterator<_Start,,>::operator-=( + difference_type __n +) +``` + + +**Returns:** `counting_iterator &` + +**Parameters** + + +The number of positions to retreat. + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `iterator_concept` | `::cuda::std::conditional_t<__advanceable<_Start>, ::cuda::std::random_access_iterator_tag, ::cuda::std::conditional_t<__decrementable<_Start>, ::cuda::std::bidirectional_iterator_tag, ::cuda::std::conditional_t<::cuda::std::incrementable<_Start>, ::cuda::std::forward_iterator_tag, ::cuda::std::input_iterator_tag>>>` | +| `value_type` | `_Start` | +| `difference_type` | `_IotaDiffT<_Start>` | +| `reference` | `_Start` | +| `pointer` | `void` | diff --git a/fern/pages/libcudacxx/empty_docstring_class.mdx b/fern/pages/libcudacxx/empty_docstring_class.mdx new file mode 100644 index 0000000..c29fa6b --- /dev/null +++ b/fern/pages/libcudacxx/empty_docstring_class.mdx @@ -0,0 +1,846 @@ +--- +title: "cuda::buffer" +description: "A memory-safe buffer for managing typed, property-annotated device memory with stream-ordered allocation." +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + +The type to be stored in the buffer. + + + +The properties the allocated memory satisfies. + + + + + +--- + +## Constructors + +### Copy and move constructors + + + + +inline explicit + +Copy-constructs from a buffer. + + +```cpp showLineNumbers={false} +cuda::buffer<_Tp, _Properties>::buffer( + const buffer &__other +) +``` + + +**Parameters** + + +The other buffer. + + + + + +inline noexcept + +Move-constructs from a buffer. + + +```cpp showLineNumbers={false} +cuda::buffer<_Tp, _Properties>::buffer( + buffer &&__other +) noexcept +``` + + +**Parameters** + + +The other buffer. After move construction, the other buffer can only be assigned to or destroyed. + + + + + +inline explicit + +Copy-constructs from a buffer with matching properties. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + const buffer<_Tp, _OtherProperties...> &__other +) +``` + + +**Parameters** + + +The other buffer. + + + + + +inline noexcept + +Move-constructs from a buffer with matching properties. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + buffer<_Tp, _OtherProperties...> &&__other +) noexcept +``` + + +**Parameters** + + +The other buffer. After move construction, the other buffer can only be assigned to or destroyed. + + + + + +### Resource constructors + + + + +inline + +Constructs an empty buffer using an environment. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + ::cuda::stream_ref __stream, + _Resource &&__resource, + const _Env &__env = {} +) +``` + + + +No memory is allocated. + + +**Parameters** + + +The environment providing the needed information. + + + + + +inline explicit + +Constructs a buffer of size `__size` using a memory and leaves all elements uninitialized. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + ::cuda::stream_ref __stream, + _Resource &&__resource, + const size_type __size, + ::cuda::no_init_t, + const _Env &__env = {} +) +``` + + + +This constructor does **NOT** initialize any elements. It is the user's responsibility to ensure that the elements within `[vec.begin(), vec.end())` are properly initialized. + + +**Parameters** + + +The size of the buffer. + + + +The environment used to query the memory resource. + + + + + +inline + +Constructs a buffer using a memory resource and copy-constructs all elements from the forward range `[__first, __last)`. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + ::cuda::stream_ref __stream, + _Resource &&__resource, + _Iter __first, + _Iter __last, + const _Env &__env = {} +) +``` + + + +If `__first == __last` then no memory is allocated. + + +**Parameters** + + +The start of the input sequence. + + + +The end of the input sequence. + + + +The environment used to query the memory resource. + + + + + +inline + +Constructs a buffer using a memory resource and copy-constructs all elements from `__ilist`. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + ::cuda::stream_ref __stream, + _Resource &&__resource, + ::cuda::std::initializer_list<_Tp> __ilist, + const _Env &__env = {} +) +``` + + + +If `__ilist.size() == 0` then no memory is allocated. + + +**Parameters** + + +The initializer_list being copied into the buffer. + + + +The environment used to query the memory resource. + + + + + +inline + +Constructs a buffer using a memory resource and an input range. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + ::cuda::stream_ref __stream, + _Resource &&__resource, + _Range &&__range, + const _Env &__env = {} +) +``` + + + +If `__range.size() == 0` then no memory is allocated. + + +**Parameters** + + +The input range to be moved into the buffer. + + + +The environment used to query the memory resource. + + + + + +--- + +## Assignment operators + +### operator= inline + +Move assignment operator. + + +```cpp showLineNumbers={false} +void cuda::buffer<_Tp, _Properties>::operator=( + buffer &&__other +) +``` + + +**Parameters** + + +The buffer to move from. + + +--- + +## Element access + +### get_unsynchronized inline noexcept + + + + +nodiscard + +Returns a reference to the `__n`'th element of the async_vector. + + +```cpp showLineNumbers={false} +reference cuda::buffer<_Tp, _Properties>::get_unsynchronized( + const size_type __n +) noexcept +``` + + +**Returns:** `reference` + +**Parameters** + + +The index of the element. + + + + + +const nodiscard + +Returns a reference to the `__n`'th element of the async_vector. + + +```cpp showLineNumbers={false} +const_reference cuda::buffer<_Tp, _Properties>::get_unsynchronized( + const size_type __n +) const noexcept +``` + + +**Returns:** `const_reference` + +**Parameters** + + +The index of the element. + + + + + +### data inline noexcept + + + + +nodiscard + +Returns a pointer to the first element of the buffer. + + +```cpp showLineNumbers={false} +pointer cuda::buffer<_Tp, _Properties>::data() noexcept +``` + + +**Returns:** `pointer` + + + + +const nodiscard + +Returns a pointer to the first element of the buffer. + + +```cpp showLineNumbers={false} +const_pointer cuda::buffer<_Tp, _Properties>::data() const noexcept +``` + + +**Returns:** `const_pointer` + + + + +--- + +## Iterators + +### begin inline noexcept + + + + +nodiscard + +Returns an iterator to the first element of the buffer. + + +```cpp showLineNumbers={false} +iterator cuda::buffer<_Tp, _Properties>::begin() noexcept +``` + + +**Returns:** `iterator` + + + + +const nodiscard + +Returns an immutable iterator to the first element of the buffer. + + +```cpp showLineNumbers={false} +const_iterator cuda::buffer<_Tp, _Properties>::begin() const noexcept +``` + + +**Returns:** `const_iterator` + + + + +### cbegin inline const noexcept + +nodiscard + +Returns an immutable iterator to the first element of the buffer. + + +```cpp showLineNumbers={false} +const_iterator cuda::buffer<_Tp, _Properties>::cbegin() const noexcept +``` + + +**Returns:** `const_iterator` + +### end inline noexcept + + + + +nodiscard + +Returns an iterator to the element following the last element of the buffer. + + +```cpp showLineNumbers={false} +iterator cuda::buffer<_Tp, _Properties>::end() noexcept +``` + + +**Returns:** `iterator` + + + + +const nodiscard + +Returns an immutable iterator to the element following the last element of the buffer. + + +```cpp showLineNumbers={false} +const_iterator cuda::buffer<_Tp, _Properties>::end() const noexcept +``` + + +**Returns:** `const_iterator` + + + + +### cend inline const noexcept + +nodiscard + +Returns an immutable iterator to the element following the last element of the buffer. + + +```cpp showLineNumbers={false} +const_iterator cuda::buffer<_Tp, _Properties>::cend() const noexcept +``` + + +**Returns:** `const_iterator` + +### rbegin inline noexcept + + + + +nodiscard + +Returns a reverse iterator to the first element of the reversed buffer. + + +```cpp showLineNumbers={false} +reverse_iterator cuda::buffer<_Tp, _Properties>::rbegin() noexcept +``` + + +**Returns:** `reverse_iterator` + + + + +const nodiscard + +Returns an immutable reverse iterator to the first element of the reversed buffer. + + +```cpp showLineNumbers={false} +const_reverse_iterator cuda::buffer<_Tp, _Properties>::rbegin() const noexcept +``` + + +**Returns:** `const_reverse_iterator` + + + + +### crbegin inline const noexcept + +nodiscard + +Returns an immutable reverse iterator to the first element of the reversed buffer. + + +```cpp showLineNumbers={false} +const_reverse_iterator cuda::buffer<_Tp, _Properties>::crbegin() const noexcept +``` + + +**Returns:** `const_reverse_iterator` + +### rend inline noexcept + + + + +nodiscard + +Returns a reverse iterator to the element following the last element of the reversed buffer. + + +```cpp showLineNumbers={false} +reverse_iterator cuda::buffer<_Tp, _Properties>::rend() noexcept +``` + + +**Returns:** `reverse_iterator` + + + + +const nodiscard + +Returns an immutable reverse iterator to the element following the last element of the reversed buffer. + + +```cpp showLineNumbers={false} +const_reverse_iterator cuda::buffer<_Tp, _Properties>::rend() const noexcept +``` + + +**Returns:** `const_reverse_iterator` + + + + +### crend inline const noexcept + +nodiscard + +Returns an immutable reverse iterator to the element following the last element of the reversed buffer. + + +```cpp showLineNumbers={false} +const_reverse_iterator cuda::buffer<_Tp, _Properties>::crend() const noexcept +``` + + +**Returns:** `const_reverse_iterator` + +--- + +## Capacity + +### size inline const noexcept + +nodiscard + +Returns the current number of elements stored in the buffer. + + +```cpp showLineNumbers={false} +size_type cuda::buffer<_Tp, _Properties>::size() const noexcept +``` + + +**Returns:** `size_type` + +### empty inline const noexcept + +nodiscard + +Returns true if the buffer is empty. + + +```cpp showLineNumbers={false} +bool cuda::buffer<_Tp, _Properties>::empty() const noexcept +``` + + +**Returns:** `bool` + +--- + +## Resource and stream management + +### memory_resource inline const noexcept + +nodiscard + + +```cpp showLineNumbers={false} +const __resource_t& cuda::buffer<_Tp, _Properties>::memory_resource() const noexcept +``` + + +**Returns:** `const __resource_t &` + +### stream inline const constexpr noexcept + +nodiscard + +Returns the stored stream. + + +```cpp showLineNumbers={false} +stream_ref cuda::buffer<_Tp, _Properties>::stream() const noexcept +``` + + +**Returns:** [`stream_ref`](/libcudacxx/api/cuda::stream_ref) + + +Stream used to allocate the buffer is initially stored in the buffer, but can be changed with [`set_stream`](/libcudacxx/api/cuda::buffer::set_stream). + + +### set_stream inline constexpr + +Replaces the stored stream. + + +```cpp showLineNumbers={false} +void cuda::buffer<_Tp, _Properties>::set_stream( + stream_ref __new_stream +) +``` + + + +Always synchronizes with the old stream. + + +**Parameters** + + +The new stream. + + +--- + +## Modifiers + +### swap inline noexcept + +Swaps the contents of a buffer with those of `__other`. + + +```cpp showLineNumbers={false} +void cuda::buffer<_Tp, _Properties>::swap( + buffer &__other +) noexcept +``` + + +**Parameters** + + +The buffer to swap with. + + +### destroy inline + + + + +Destroys the buffer, deallocates the buffer and destroys the memory resource. + + +```cpp showLineNumbers={false} +void cuda::buffer<_Tp, _Properties>::destroy( + ::cuda::stream_ref __stream +) +``` + + + +After this explicit destroy call, the buffer can only be assigned to or destroyed. + + +**Parameters** + + +The stream to deallocate the buffer on. + + + + + +Destroys the buffer, deallocates the buffer and destroys the memory resource. + + +```cpp showLineNumbers={false} +void cuda::buffer<_Tp, _Properties>::destroy() +``` + + + +Uses the stored stream to deallocate the buffer. + + + +After this explicit destroy call, the buffer can only be assigned to or destroyed. + + + + + +--- + +## Friend functions + +### swap noexcept + + +```cpp showLineNumbers={false} +void swap( + buffer &__lhs, + buffer &__rhs +) noexcept +``` + + +**Parameters** + + +The first buffer. + + + +The second buffer. + + +### transform_launch_argument noexcept + + + + + +```cpp showLineNumbers={false} +template +::cuda::std::span<_Tp> transform_launch_argument( + ::cuda::stream_ref, + buffer &__self +) noexcept +``` + + + + + + +```cpp showLineNumbers={false} +template +::cuda::std::span transform_launch_argument( + ::cuda::stream_ref, + const buffer &__self +) noexcept +``` + + + + + +### get_property noexcept + + +```cpp showLineNumbers={false} +template +void get_property( + const buffer &, + _Property +) noexcept +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `value_type` | `_Tp` | +| `reference` | `_Tp &` | +| `const_reference` | `const _Tp &` | +| `pointer` | `_Tp *` | +| `const_pointer` | `const _Tp *` | +| `iterator` | `::cuda::heterogeneous_iterator<_Tp, _Properties...>` | +| `const_iterator` | `::cuda::heterogeneous_iterator` | +| `reverse_iterator` | `::cuda::std::reverse_iterator` | +| `const_reverse_iterator` | `::cuda::std::reverse_iterator` | +| `size_type` | `::cuda::std::size_t` | +| `difference_type` | `::cuda::std::ptrdiff_t` | +| `properties_list` | `::cuda::mr::properties_list<_Properties...>` | diff --git a/fern/pages/libcudacxx/empty_docstring_class_v4.mdx b/fern/pages/libcudacxx/empty_docstring_class_v4.mdx new file mode 100644 index 0000000..c169478 --- /dev/null +++ b/fern/pages/libcudacxx/empty_docstring_class_v4.mdx @@ -0,0 +1,830 @@ +--- +title: "cuda::buffer" +description: "A memory-safe buffer for managing typed, property-annotated device memory with stream-ordered allocation." +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + +The type to be stored in the buffer. + + + +The properties the allocated memory satisfies. + + + + + +--- + +## Constructors + +### Copy and move constructors + + + + +inline explicit + +Copy-constructs from a buffer. + + +```cpp showLineNumbers={false} +cuda::buffer<_Tp, _Properties>::buffer( + const buffer &__other +) +``` + + +**Parameters** + + +The other buffer. + + + + + +inline noexcept + +Move-constructs from a buffer. + + +```cpp showLineNumbers={false} +cuda::buffer<_Tp, _Properties>::buffer( + buffer &&__other +) noexcept +``` + + +**Parameters** + + +The other buffer. After move construction, the other buffer can only be assigned to or destroyed. + + + + + +inline explicit + +Copy-constructs from a buffer with matching properties. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + const buffer<_Tp, _OtherProperties...> &__other +) +``` + + +**Parameters** + + +The other buffer. + + + + + +inline noexcept + +Move-constructs from a buffer with matching properties. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + buffer<_Tp, _OtherProperties...> &&__other +) noexcept +``` + + +**Parameters** + + +The other buffer. After move construction, the other buffer can only be assigned to or destroyed. + + + + + +### Resource constructors + + + + +inline + +Constructs an empty buffer using an environment. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + ::cuda::stream_ref __stream, + _Resource &&__resource, + const _Env &__env = {} +) +``` + + + +No memory is allocated. + + +**Parameters** + + +The environment providing the needed information. + + + + + +inline explicit + +Constructs a buffer of size `__size` using a memory and leaves all elements uninitialized. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + ::cuda::stream_ref __stream, + _Resource &&__resource, + const size_type __size, + ::cuda::no_init_t, + const _Env &__env = {} +) +``` + + + +This constructor does **NOT** initialize any elements. It is the user's responsibility to ensure that the elements within `[vec.begin(), vec.end())` are properly initialized. + + +**Parameters** + + +The size of the buffer. + + + +The environment used to query the memory resource. + + + + + +inline + +Constructs a buffer using a memory resource and copy-constructs all elements from the forward range `[__first, __last)`. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + ::cuda::stream_ref __stream, + _Resource &&__resource, + _Iter __first, + _Iter __last, + const _Env &__env = {} +) +``` + + + +If `__first == __last` then no memory is allocated. + + +**Parameters** + + +The start of the input sequence. + + + +The end of the input sequence. + + + +The environment used to query the memory resource. + + + + + +inline + +Constructs a buffer using a memory resource and copy-constructs all elements from `__ilist`. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + ::cuda::stream_ref __stream, + _Resource &&__resource, + ::cuda::std::initializer_list<_Tp> __ilist, + const _Env &__env = {} +) +``` + + + +If `__ilist.size() == 0` then no memory is allocated. + + +**Parameters** + + +The initializer_list being copied into the buffer. + + + +The environment used to query the memory resource. + + + + + +inline + +Constructs a buffer using a memory resource and an input range. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + ::cuda::stream_ref __stream, + _Resource &&__resource, + _Range &&__range, + const _Env &__env = {} +) +``` + + + +If `__range.size() == 0` then no memory is allocated. + + +**Parameters** + + +The input range to be moved into the buffer. + + + +The environment used to query the memory resource. + + + + + +--- + +## Assignment operators + +### operator= inline + +Move assignment operator. + + +```cpp showLineNumbers={false} +void cuda::buffer<_Tp, _Properties>::operator=( + buffer &&__other +) +``` + + +**Parameters** + + +The buffer to move from. + + +--- + +## Element access + +### get_unsynchronized inline noexcept + + + + +nodiscard + +Returns a reference to the `__n`'th element of the async_vector. + + +```cpp showLineNumbers={false} +reference cuda::buffer<_Tp, _Properties>::get_unsynchronized( + const size_type __n +) noexcept +``` + + +**Returns:** `reference` + +**Parameters** + + +The index of the element. + + + + + +const nodiscard + +Returns a reference to the `__n`'th element of the async_vector. + + +```cpp showLineNumbers={false} +const_reference cuda::buffer<_Tp, _Properties>::get_unsynchronized( + const size_type __n +) const noexcept +``` + + +**Returns:** `const_reference` + +**Parameters** + + +The index of the element. + + + + + +### data inline noexcept + + + + +nodiscard + +Returns a pointer to the first element of the buffer. + + +```cpp showLineNumbers={false} +pointer cuda::buffer<_Tp, _Properties>::data() noexcept +``` + + +**Returns:** `pointer` + + + + +const nodiscard + +Returns a pointer to the first element of the buffer. + + +```cpp showLineNumbers={false} +const_pointer cuda::buffer<_Tp, _Properties>::data() const noexcept +``` + + +**Returns:** `const_pointer` + + + + +--- + +## Iterators + +### begin inline noexcept + + + + +nodiscard + +Returns an iterator to the first element of the buffer. + + +```cpp showLineNumbers={false} +iterator cuda::buffer<_Tp, _Properties>::begin() noexcept +``` + + +**Returns:** `iterator` + + + + +const nodiscard + +Returns an immutable iterator to the first element of the buffer. + + +```cpp showLineNumbers={false} +const_iterator cuda::buffer<_Tp, _Properties>::begin() const noexcept +``` + + +**Returns:** `const_iterator` + + + + +### cbegin inline const noexcept nodiscard + +Returns an immutable iterator to the first element of the buffer. + + +```cpp showLineNumbers={false} +const_iterator cuda::buffer<_Tp, _Properties>::cbegin() const noexcept +``` + + +**Returns:** `const_iterator` + +### end inline noexcept + + + + +nodiscard + +Returns an iterator to the element following the last element of the buffer. + + +```cpp showLineNumbers={false} +iterator cuda::buffer<_Tp, _Properties>::end() noexcept +``` + + +**Returns:** `iterator` + + + + +const nodiscard + +Returns an immutable iterator to the element following the last element of the buffer. + + +```cpp showLineNumbers={false} +const_iterator cuda::buffer<_Tp, _Properties>::end() const noexcept +``` + + +**Returns:** `const_iterator` + + + + +### cend inline const noexcept nodiscard + +Returns an immutable iterator to the element following the last element of the buffer. + + +```cpp showLineNumbers={false} +const_iterator cuda::buffer<_Tp, _Properties>::cend() const noexcept +``` + + +**Returns:** `const_iterator` + +### rbegin inline noexcept + + + + +nodiscard + +Returns a reverse iterator to the first element of the reversed buffer. + + +```cpp showLineNumbers={false} +reverse_iterator cuda::buffer<_Tp, _Properties>::rbegin() noexcept +``` + + +**Returns:** `reverse_iterator` + + + + +const nodiscard + +Returns an immutable reverse iterator to the first element of the reversed buffer. + + +```cpp showLineNumbers={false} +const_reverse_iterator cuda::buffer<_Tp, _Properties>::rbegin() const noexcept +``` + + +**Returns:** `const_reverse_iterator` + + + + +### crbegin inline const noexcept nodiscard + +Returns an immutable reverse iterator to the first element of the reversed buffer. + + +```cpp showLineNumbers={false} +const_reverse_iterator cuda::buffer<_Tp, _Properties>::crbegin() const noexcept +``` + + +**Returns:** `const_reverse_iterator` + +### rend inline noexcept + + + + +nodiscard + +Returns a reverse iterator to the element following the last element of the reversed buffer. + + +```cpp showLineNumbers={false} +reverse_iterator cuda::buffer<_Tp, _Properties>::rend() noexcept +``` + + +**Returns:** `reverse_iterator` + + + + +const nodiscard + +Returns an immutable reverse iterator to the element following the last element of the reversed buffer. + + +```cpp showLineNumbers={false} +const_reverse_iterator cuda::buffer<_Tp, _Properties>::rend() const noexcept +``` + + +**Returns:** `const_reverse_iterator` + + + + +### crend inline const noexcept nodiscard + +Returns an immutable reverse iterator to the element following the last element of the reversed buffer. + + +```cpp showLineNumbers={false} +const_reverse_iterator cuda::buffer<_Tp, _Properties>::crend() const noexcept +``` + + +**Returns:** `const_reverse_iterator` + +--- + +## Capacity + +### size inline const noexcept nodiscard + +Returns the current number of elements stored in the buffer. + + +```cpp showLineNumbers={false} +size_type cuda::buffer<_Tp, _Properties>::size() const noexcept +``` + + +**Returns:** `size_type` + +### empty inline const noexcept nodiscard + +Returns true if the buffer is empty. + + +```cpp showLineNumbers={false} +bool cuda::buffer<_Tp, _Properties>::empty() const noexcept +``` + + +**Returns:** `bool` + +--- + +## Resource and stream management + +### memory_resource inline const noexcept nodiscard + + +```cpp showLineNumbers={false} +const __resource_t& cuda::buffer<_Tp, _Properties>::memory_resource() const noexcept +``` + + +**Returns:** `const __resource_t &` + +### stream inline const constexpr noexcept nodiscard + +Returns the stored stream. + + +```cpp showLineNumbers={false} +stream_ref cuda::buffer<_Tp, _Properties>::stream() const noexcept +``` + + +**Returns:** [`stream_ref`](/libcudacxx/api/cuda::stream_ref) + + +Stream used to allocate the buffer is initially stored in the buffer, but can be changed with [`set_stream`](/libcudacxx/api/cuda::buffer::set_stream). + + +### set_stream inline constexpr + +Replaces the stored stream. + + +```cpp showLineNumbers={false} +void cuda::buffer<_Tp, _Properties>::set_stream( + stream_ref __new_stream +) +``` + + + +Always synchronizes with the old stream. + + +**Parameters** + + +The new stream. + + +--- + +## Modifiers + +### swap inline noexcept + +Swaps the contents of a buffer with those of `__other`. + + +```cpp showLineNumbers={false} +void cuda::buffer<_Tp, _Properties>::swap( + buffer &__other +) noexcept +``` + + +**Parameters** + + +The buffer to swap with. + + +### destroy inline + + + + +Destroys the buffer, deallocates the buffer and destroys the memory resource. + + +```cpp showLineNumbers={false} +void cuda::buffer<_Tp, _Properties>::destroy( + ::cuda::stream_ref __stream +) +``` + + + +After this explicit destroy call, the buffer can only be assigned to or destroyed. + + +**Parameters** + + +The stream to deallocate the buffer on. + + + + + +Destroys the buffer, deallocates the buffer and destroys the memory resource. + + +```cpp showLineNumbers={false} +void cuda::buffer<_Tp, _Properties>::destroy() +``` + + + +Uses the stored stream to deallocate the buffer. + + + +After this explicit destroy call, the buffer can only be assigned to or destroyed. + + + + + +--- + +## Friend functions + +### swap noexcept + + +```cpp showLineNumbers={false} +void swap( + buffer &__lhs, + buffer &__rhs +) noexcept +``` + + +**Parameters** + + +The first buffer. + + + +The second buffer. + + +### transform_launch_argument noexcept + + + + + +```cpp showLineNumbers={false} +template +::cuda::std::span<_Tp> transform_launch_argument( + ::cuda::stream_ref, + buffer &__self +) noexcept +``` + + + + + + +```cpp showLineNumbers={false} +template +::cuda::std::span transform_launch_argument( + ::cuda::stream_ref, + const buffer &__self +) noexcept +``` + + + + + +### get_property noexcept + + +```cpp showLineNumbers={false} +template +void get_property( + const buffer &, + _Property +) noexcept +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `value_type` | `_Tp` | +| `reference` | `_Tp &` | +| `const_reference` | `const _Tp &` | +| `pointer` | `_Tp *` | +| `const_pointer` | `const _Tp *` | +| `iterator` | `::cuda::heterogeneous_iterator<_Tp, _Properties...>` | +| `const_iterator` | `::cuda::heterogeneous_iterator` | +| `reverse_iterator` | `::cuda::std::reverse_iterator` | +| `const_reverse_iterator` | `::cuda::std::reverse_iterator` | +| `size_type` | `::cuda::std::size_t` | +| `difference_type` | `::cuda::std::ptrdiff_t` | +| `properties_list` | `::cuda::mr::properties_list<_Properties...>` | diff --git a/fern/pages/libcudacxx/raises_example.mdx b/fern/pages/libcudacxx/raises_example.mdx new file mode 100644 index 0000000..5d77439 --- /dev/null +++ b/fern/pages/libcudacxx/raises_example.mdx @@ -0,0 +1,506 @@ +--- +title: "cuda::stream" +description: "An owning wrapper for cudaStream_t providing RAII-based stream lifecycle management." +--- + +An owning wrapper for `cudaStream_t`. + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** [`cuda::stream_ref`](/libcudacxx/api/cuda::stream_ref) (public) + +--- + +## Constructors + +### stream inline + + + + +explicit + +Constructs a stream on a specified device and with specified priority. + +Priority is defaulted to [`stream::default_priority`](/libcudacxx/api/cuda::stream::default_priority). + + +```cpp showLineNumbers={false} +cuda::stream::stream( + device_ref __dev, + int __priority = default_priority +) +``` + + +**Throws:** `cuda_error` if stream creation fails. + +**Parameters** + + +The device on which to create the stream. + + + +The priority of the stream. + + + + + +explicit noexcept + +Construct a new [`stream`](/libcudacxx/api/cuda::stream) object into the moved-from state. + + +```cpp showLineNumbers={false} +cuda::stream::stream( + no_init_t +) noexcept +``` + + + +[`stream()`](/libcudacxx/api/cuda::stream::stream) returns an invalid stream handle. + + + + + +noexcept + +Move-construct a new [`stream`](/libcudacxx/api/cuda::stream) object. + + +```cpp showLineNumbers={false} +cuda::stream::stream( + stream &&__other +) noexcept +``` + + + +`__other` is in moved-from state. + + +**Parameters** + + +The stream to move from. + + + + + +explicit + + +```cpp showLineNumbers={false} +cuda::stream::stream( + ::cudaStream_t __handle +) +``` + + + + + + +```cpp showLineNumbers={false} +cuda::stream::stream( + const stream & +) = delete +``` + + + + + +### ~stream inline + +Destroy the [`stream`](/libcudacxx/api/cuda::stream) object. + + +```cpp showLineNumbers={false} +cuda::stream::~stream() +``` + + + +If the stream fails to be destroyed, the error is silently ignored. + + +--- + +## Assignment operators + +### operator= inline + + + + +noexcept + +Move-assign a [`stream`](/libcudacxx/api/cuda::stream) object. + + +```cpp showLineNumbers={false} +stream& cuda::stream::operator=( + stream &&__other +) noexcept +``` + + +**Returns:** `stream &` + + +`__other` is in a moved-from state. + + +**Parameters** + + +The stream to move from. + + + + + + +```cpp showLineNumbers={false} +stream& cuda::stream::operator=( + const stream & +) = delete +``` + + + + + +--- + +## Ownership + +### release inline + +nodiscard + +Retrieve the native `cudaStream_t` handle and give up ownership. + + +```cpp showLineNumbers={false} +::cudaStream_t cuda::stream::release() +``` + + +**Returns:** `cudaStream_t` -- the native handle being held by the [`stream`](/libcudacxx/api/cuda::stream) object. + + +The stream object is in a moved-from state. + + +--- + +## Accessors + +### get inline constexpr const noexcept + +nodiscard + +Returns the wrapped `cudaStream_t` handle. + + +```cpp showLineNumbers={false} +value_type cuda::stream_ref::get() const noexcept +``` + + +**Returns:** [`value_type`](/libcudacxx/api/cuda::stream_ref::value_type) + +--- + +## Synchronization + +### sync inline const + +Synchronizes the wrapped stream. + + +```cpp showLineNumbers={false} +void cuda::stream_ref::sync() const +``` + + +**Throws:** `cuda::cuda_error` if synchronization fails. + +### wait inline const + + + + +Deprecated. + + +```cpp showLineNumbers={false} +void cuda::stream_ref::wait() const +``` + + + +Use [`sync()`](/libcudacxx/api/cuda::stream_ref::sync) instead. + + + + + +Make all future work submitted into this stream depend on completion of the specified event. + + +```cpp showLineNumbers={false} +void cuda::stream_ref::wait( + event_ref __ev +) const +``` + + +**Throws:** `cuda_error` if inserting the dependency fails. + +**Parameters** + + +Event that this stream should wait for. + + + + + +Make all future work submitted into this stream depend on completion of all work from the specified stream. + + +```cpp showLineNumbers={false} +void cuda::stream_ref::wait( + stream_ref __other +) const +``` + + +**Throws:** `cuda_error` if inserting the dependency fails. + +**Parameters** + + +Stream that this stream should wait for. + + + + + +--- + +## Query methods + +### is_done inline const + +nodiscard + +Queries if all operations on the stream have completed. + + +```cpp showLineNumbers={false} +bool cuda::stream_ref::is_done() const +``` + + +**Returns:** `true` if all operations have completed, or `false` if not. + +**Throws:** `cuda::cuda_error` if the query fails. + +### ready inline const + +nodiscard + +Queries if all operations on the wrapped stream have completed. + + +```cpp showLineNumbers={false} +bool cuda::stream_ref::ready() const +``` + + +**Returns:** `true` if all operations have completed, or `false` if not. + +**Throws:** `cuda::cuda_error` if the query fails. + +### priority inline const + +nodiscard + +Queries the priority of the wrapped stream. + + +```cpp showLineNumbers={false} +int cuda::stream_ref::priority() const +``` + + +**Returns:** Value representing the priority of the wrapped stream. + +**Throws:** `cuda::cuda_error` if the query fails. + +### id inline const + +nodiscard + +Get the unique ID of the stream. + + +```cpp showLineNumbers={false} +stream_id cuda::stream_ref::id() const +``` + + +**Returns:** The unique ID of the stream. + +**Throws:** `cuda_error` if the ID query fails. + +### query inline constexpr const noexcept + +nodiscard + +Queries the `stream_ref` for itself. + + +```cpp showLineNumbers={false} +stream_ref cuda::stream_ref::query( + const ::cuda::get_stream_t & +) const noexcept +``` + + +**Returns:** [`stream_ref`](/libcudacxx/api/cuda::stream_ref) + +--- + +## Event recording + +### record_event inline const + +nodiscard + +Create a new event and record it into this stream. + + +```cpp showLineNumbers={false} +event cuda::stream_ref::record_event( + event_flags __flags = event_flags::none +) const +``` + + +**Returns:** A new event that was recorded into this stream. + +**Throws:** `cuda_error` if event creation or record failed. + +**Parameters** + + +Flags for event creation. + + +### record_timed_event inline const + +nodiscard + +Create a new timed event and record it into this stream. + + +```cpp showLineNumbers={false} +timed_event cuda::stream_ref::record_timed_event( + event_flags __flags = event_flags::none +) const +``` + + +**Returns:** A new timed event that was recorded into this stream. + +**Throws:** `cuda_error` if event creation or record failed. + +**Parameters** + + +Flags for event creation. + + +--- + +## Device information + +### device inline const + +nodiscard + +Get device under which this stream was created. + + +```cpp showLineNumbers={false} +device_ref cuda::stream_ref::device() const +``` + + +**Returns:** [`device_ref`](/libcudacxx/api/cuda::device_ref) + +**Throws:** `cuda_error` if device check fails. + +--- + +## Static methods + +### from_native_handle inline static + +nodiscard + +Construct an [`stream`](/libcudacxx/api/cuda::stream) object from a native `cudaStream_t` handle and take ownership. + + +```cpp showLineNumbers={false} +static stream cuda::stream::from_native_handle( + ::cudaStream_t __handle +) +``` + + +**Returns:** `stream` + +**Parameters** + + +The native handle. + + +The following overloads are deleted to prevent misuse: + +```cpp showLineNumbers={false} +static stream cuda::stream::from_native_handle(int) = delete; +static stream cuda::stream::from_native_handle(::cuda::std::nullptr_t) = delete; +static stream cuda::stream::from_native_handle(invalid_stream_t) = delete; +``` + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `default_priority` static constexpr | `int` | The default stream priority. | + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `value_type` | `::cudaStream_t` | diff --git a/fern/pages/libcudacxx/raises_example_v4.mdx b/fern/pages/libcudacxx/raises_example_v4.mdx new file mode 100644 index 0000000..ceda85a --- /dev/null +++ b/fern/pages/libcudacxx/raises_example_v4.mdx @@ -0,0 +1,495 @@ +--- +title: "cuda::stream" +description: "An owning wrapper for cudaStream_t providing RAII-based stream lifecycle management." +--- + +An owning wrapper for `cudaStream_t`. + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** [`cuda::stream_ref`](/libcudacxx/api/cuda::stream_ref) (public) + +--- + +## Constructors + +### stream inline + + + + +explicit + +Constructs a stream on a specified device and with specified priority. + +Priority is defaulted to [`stream::default_priority`](/libcudacxx/api/cuda::stream::default_priority). + + +```cpp showLineNumbers={false} +cuda::stream::stream( + device_ref __dev, + int __priority = default_priority +) +``` + + +**Throws:** `cuda_error` if stream creation fails. + +**Parameters** + + +The device on which to create the stream. + + + +The priority of the stream. + + + + + +explicit noexcept + +Construct a new [`stream`](/libcudacxx/api/cuda::stream) object into the moved-from state. + + +```cpp showLineNumbers={false} +cuda::stream::stream( + no_init_t +) noexcept +``` + + + +[`stream()`](/libcudacxx/api/cuda::stream::stream) returns an invalid stream handle. + + + + + +noexcept + +Move-construct a new [`stream`](/libcudacxx/api/cuda::stream) object. + + +```cpp showLineNumbers={false} +cuda::stream::stream( + stream &&__other +) noexcept +``` + + + +`__other` is in moved-from state. + + +**Parameters** + + +The stream to move from. + + + + + +explicit + + +```cpp showLineNumbers={false} +cuda::stream::stream( + ::cudaStream_t __handle +) +``` + + + + + + +```cpp showLineNumbers={false} +cuda::stream::stream( + const stream & +) = delete +``` + + + + + +### ~stream inline + +Destroy the [`stream`](/libcudacxx/api/cuda::stream) object. + + +```cpp showLineNumbers={false} +cuda::stream::~stream() +``` + + + +If the stream fails to be destroyed, the error is silently ignored. + + +--- + +## Assignment operators + +### operator= inline + + + + +noexcept + +Move-assign a [`stream`](/libcudacxx/api/cuda::stream) object. + + +```cpp showLineNumbers={false} +stream& cuda::stream::operator=( + stream &&__other +) noexcept +``` + + +**Returns:** `stream &` + + +`__other` is in a moved-from state. + + +**Parameters** + + +The stream to move from. + + + + + + +```cpp showLineNumbers={false} +stream& cuda::stream::operator=( + const stream & +) = delete +``` + + + + + +--- + +## Ownership + +### release inline nodiscard + +Retrieve the native `cudaStream_t` handle and give up ownership. + + +```cpp showLineNumbers={false} +::cudaStream_t cuda::stream::release() +``` + + +**Returns:** `cudaStream_t` -- the native handle being held by the [`stream`](/libcudacxx/api/cuda::stream) object. + + +The stream object is in a moved-from state. + + +--- + +## Accessors + +### get inline constexpr const noexcept nodiscard + +Returns the wrapped `cudaStream_t` handle. + + +```cpp showLineNumbers={false} +value_type cuda::stream::get() const noexcept +``` + + +**Returns:** [`value_type`](/libcudacxx/api/cuda::stream_ref::value_type) + +--- + +## Synchronization + +### sync inline const + +Synchronizes the wrapped stream. + + +```cpp showLineNumbers={false} +void cuda::stream::sync() const +``` + + +**Throws:** `cuda::cuda_error` if synchronization fails. + +### wait inline const + + + + +Deprecated. + + +```cpp showLineNumbers={false} +void cuda::stream::wait() const +``` + + + +Use [`sync()`](/libcudacxx/api/cuda::stream_ref::sync) instead. + + + + + +Make all future work submitted into this stream depend on completion of the specified event. + + +```cpp showLineNumbers={false} +void cuda::stream::wait( + event_ref __ev +) const +``` + + +**Throws:** `cuda_error` if inserting the dependency fails. + +**Parameters** + + +Event that this stream should wait for. + + + + + +Make all future work submitted into this stream depend on completion of all work from the specified stream. + + +```cpp showLineNumbers={false} +void cuda::stream::wait( + stream_ref __other +) const +``` + + +**Throws:** `cuda_error` if inserting the dependency fails. + +**Parameters** + + +Stream that this stream should wait for. + + + + + +--- + +## Query methods + +### is_done inline const nodiscard + +Queries if all operations on the stream have completed. + + +```cpp showLineNumbers={false} +bool cuda::stream::is_done() const +``` + + +**Returns:** `true` if all operations have completed, or `false` if not. + +**Throws:** `cuda::cuda_error` if the query fails. + +### ready inline const nodiscard + +Queries if all operations on the wrapped stream have completed. + + +```cpp showLineNumbers={false} +bool cuda::stream::ready() const +``` + + +**Returns:** `true` if all operations have completed, or `false` if not. + +**Throws:** `cuda::cuda_error` if the query fails. + +### priority inline const nodiscard + +Queries the priority of the wrapped stream. + + +```cpp showLineNumbers={false} +int cuda::stream::priority() const +``` + + +**Returns:** Value representing the priority of the wrapped stream. + +**Throws:** `cuda::cuda_error` if the query fails. + +### id inline const nodiscard + +Get the unique ID of the stream. + + +```cpp showLineNumbers={false} +stream_id cuda::stream::id() const +``` + + +**Returns:** The unique ID of the stream. + +**Throws:** `cuda_error` if the ID query fails. + +### query inline constexpr const noexcept nodiscard + +Queries the `stream_ref` for itself. + + +```cpp showLineNumbers={false} +stream_ref cuda::stream::query( + const ::cuda::get_stream_t & +) const noexcept +``` + + +**Returns:** [`stream_ref`](/libcudacxx/api/cuda::stream_ref) + +--- + +## Event recording + +### record_event inline const nodiscard + +Create a new event and record it into this stream. + + +```cpp showLineNumbers={false} +event cuda::stream::record_event( + event_flags __flags = event_flags::none +) const +``` + + +**Returns:** A new event that was recorded into this stream. + +**Throws:** `cuda_error` if event creation or record failed. + +**Parameters** + + +Flags for event creation. + + +### record_timed_event inline const nodiscard + +Create a new timed event and record it into this stream. + + +```cpp showLineNumbers={false} +timed_event cuda::stream::record_timed_event( + event_flags __flags = event_flags::none +) const +``` + + +**Returns:** A new timed event that was recorded into this stream. + +**Throws:** `cuda_error` if event creation or record failed. + +**Parameters** + + +Flags for event creation. + + +--- + +## Device information + +### device inline const nodiscard + +Get device under which this stream was created. + + +```cpp showLineNumbers={false} +device_ref cuda::stream::device() const +``` + + +**Returns:** [`device_ref`](/libcudacxx/api/cuda::device_ref) + +**Throws:** `cuda_error` if device check fails. + +--- + +## Static methods + +### from_native_handle inline static nodiscard + + + + +Construct an [`stream`](/libcudacxx/api/cuda::stream) object from a native `cudaStream_t` handle and take ownership. + + +```cpp showLineNumbers={false} +static stream cuda::stream::from_native_handle( + ::cudaStream_t __handle +) +``` + + +**Returns:** `stream` + +**Parameters** + + +The native handle. + + + + + +The following overloads are deleted to prevent misuse: + + +```cpp showLineNumbers={false} +static stream cuda::stream::from_native_handle(int) = delete; +static stream cuda::stream::from_native_handle(::cuda::std::nullptr_t) = delete; +static stream cuda::stream::from_native_handle(invalid_stream_t) = delete; +``` + + + + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `default_priority` static constexpr | `int` | The default stream priority. | + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `value_type` | `::cudaStream_t` | diff --git a/fern/pages/thrust/deprecated_example.mdx b/fern/pages/thrust/deprecated_example.mdx new file mode 100644 index 0000000..354aa6b --- /dev/null +++ b/fern/pages/thrust/deprecated_example.mdx @@ -0,0 +1,134 @@ +--- +title: thrust::strided_iterator +description: "An iterator adaptor that wraps another iterator and moves it by a specified stride each time it is incremented or decremented." +--- + +A [`strided_iterator`](/library/api/thrust::strided_iterator) wraps another iterator and moves it by a specified stride each time it is incremented or decremented. + +```cpp showLineNumbers={false} +#include +``` + + +Use `cuda::strided_iterator` instead. + + + + + + +A random access iterator. + + + +Either a [runtime_value](/library/api/thrust::runtime_value) or a [compile_time_value](/library/api/thrust::compile_time_value) specifying the stride. + + + + + +**Inherits from:** [`thrust::iterator_adaptor< strided_iterator< RandomAccessIterator, StrideHolder >, RandomAccessIterator >`](/library/api/thrust::iterator_adaptor) (public), `StrideHolder` (private) + +--- + +## Constructors + +### strided_iterator + + + + + +```cpp showLineNumbers={false} +thrust::strided_iterator< RandomAccessIterator, StrideHolder >::strided_iterator() = default +``` + + + + + +inline + +Creates a [strided_iterator](/library/api/thrust::strided_iterator) from an existing iterator and a stride. + + +```cpp showLineNumbers={false} +thrust::strided_iterator< RandomAccessIterator, StrideHolder >::strided_iterator( + RandomAccessIterator it, + StrideHolder stride = {} +) +``` + + +**Parameters** + + + + + + + + + + +--- + +## Methods + +### stride_holder inline const + +Returns either the [runtime_value](/library/api/thrust::runtime_value) or the [compile_time_value](/library/api/thrust::compile_time_value) holding the stride's value. + + +```cpp showLineNumbers={false} +const auto & thrust::strided_iterator< RandomAccessIterator, StrideHolder >::stride_holder() const +``` + + +### stride inline const + +Returns the stride's value. + + +```cpp showLineNumbers={false} +difference_type thrust::strided_iterator< RandomAccessIterator, StrideHolder >::stride() const +``` + + +### base inline const + + +```cpp showLineNumbers={false} +RandomAccessIterator const & thrust::iterator_adaptor< strided_iterator< RandomAccessIterator, StrideHolder >, RandomAccessIterator, use_default, use_default, use_default, use_default, use_default >::base() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + +### base_reference inline const + + +```cpp showLineNumbers={false} +RandomAccessIterator const & thrust::iterator_adaptor< strided_iterator< RandomAccessIterator, StrideHolder >, RandomAccessIterator, use_default, use_default, use_default, use_default, use_default >::base_reference() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `base_type` | `RandomAccessIterator` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)'s adapts. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `has_static_stride` static constexpr | `bool` | | diff --git a/fern/pages/thrust/deprecated_example_v4.mdx b/fern/pages/thrust/deprecated_example_v4.mdx new file mode 100644 index 0000000..7a4302b --- /dev/null +++ b/fern/pages/thrust/deprecated_example_v4.mdx @@ -0,0 +1,134 @@ +--- +title: thrust::strided_iterator +description: "An iterator adaptor that wraps another iterator and moves it by a specified stride each time it is incremented or decremented." +--- + +A [`strided_iterator`](/library/api/thrust::strided_iterator) wraps another iterator and moves it by a specified stride each time it is incremented or decremented. + +```cpp showLineNumbers={false} +#include +``` + + +Use `cuda::strided_iterator` instead. + + + + + + +A random access iterator. + + + +Either a [runtime_value](/library/api/thrust::runtime_value) or a [compile_time_value](/library/api/thrust::compile_time_value) specifying the stride. + + + + + +**Inherits from:** [`thrust::iterator_adaptor< strided_iterator< RandomAccessIterator, StrideHolder >, RandomAccessIterator >`](/library/api/thrust::iterator_adaptor) (public), `StrideHolder` (private) + +--- + +## Constructors + +### strided_iterator + + + + + +```cpp showLineNumbers={false} +thrust::strided_iterator< RandomAccessIterator, StrideHolder >::strided_iterator() = default +``` + + + + + +inline + +Creates a [strided_iterator](/library/api/thrust::strided_iterator) from an existing iterator and a stride. + + +```cpp showLineNumbers={false} +thrust::strided_iterator< RandomAccessIterator, StrideHolder >::strided_iterator( + RandomAccessIterator it, + StrideHolder stride = {} +) +``` + + +**Parameters** + + + + + + + + + + +--- + +## Methods + +### stride_holder inline const + +Returns either the [runtime_value](/library/api/thrust::runtime_value) or the [compile_time_value](/library/api/thrust::compile_time_value) holding the stride's value. + + +```cpp showLineNumbers={false} +const auto & thrust::strided_iterator< RandomAccessIterator, StrideHolder >::stride_holder() const +``` + + +### stride inline const + +Returns the stride's value. + + +```cpp showLineNumbers={false} +difference_type thrust::strided_iterator< RandomAccessIterator, StrideHolder >::stride() const +``` + + +### base inline const + + +```cpp showLineNumbers={false} +RandomAccessIterator const & thrust::iterator_adaptor< strided_iterator< RandomAccessIterator, StrideHolder >, RandomAccessIterator, use_default, use_default, use_default, use_default, use_default >::base() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + +### base_reference inline const + + +```cpp showLineNumbers={false} +RandomAccessIterator const & thrust::iterator_adaptor< strided_iterator< RandomAccessIterator, StrideHolder >, RandomAccessIterator, use_default, use_default, use_default, use_default, use_default >::base_reference() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `base_type` | `RandomAccessIterator` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)'s adapts. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `has_static_stride` static constexpr | `bool` | | diff --git a/fern/pages/thrust/device_vector.mdx b/fern/pages/thrust/device_vector.mdx new file mode 100644 index 0000000..8869d0d --- /dev/null +++ b/fern/pages/thrust/device_vector.mdx @@ -0,0 +1,1188 @@ +--- +title: thrust::device_vector +--- + +# device_vector + +A `device_vector` is a container that supports random access to elements, constant time removal of elements at the end, and linear time insertion and removal of elements at the beginning or in the middle. + +The number of elements in a `device_vector` may vary dynamically; memory management is automatic. The memory associated with a `device_vector` resides in the memory accessible to devices. + +```cpp +#include +``` + +**See also:** +[https://en.cppreference.com/w/cpp/container/vector](https://en.cppreference.com/w/cpp/container/vector), +[device_allocator](/library/api/thrust::device_allocator), +[host_vector](/library/api/thrust::host_vector), +[universal_vector](/library/api/thrust::universal_vector) + + + + +The element type of the vector. + + + +**[optional]** The allocator type used for memory management (default: [thrust::device_allocator](/library/api/thrust::device_allocator)``). + + + + +**Inherits from:** `detail::vector_base< T, thrust::device_allocator< T > >` (public) + +--- + +## Constructors + +### Default and allocator constructors + + + + +This constructor creates an empty `device_vector`. + +```cpp +thrust::device_vector< T, Alloc >::device_vector() +``` + + + + +This constructor creates an empty `device_vector`. + +```cpp +thrust::device_vector< T, Alloc >::device_vector(const Alloc &alloc) +``` + +#### Parameters + + +The allocator to use by this `device_vector`. + + + + + +### Size constructors + + + + +explicit + +This constructor creates a `device_vector` with the given size. + +```cpp +thrust::device_vector< T, Alloc >::device_vector(size_type n) +``` + +#### Parameters + + +The number of elements to initially create. + + + + + +This constructor creates a `device_vector` with the given size, performing only default-initialization instead of value-initialization. + +```cpp +thrust::device_vector< T, Alloc >::device_vector(size_type n, default_init_t) +``` + +#### Parameters + + +The number of elements to initially create. + + + + + +This constructor creates a `device_vector` with the given size, without initializing elements. + +```cpp +thrust::device_vector< T, Alloc >::device_vector(size_type n, no_init_t) +``` + +#### Parameters + + +The number of elements to initially create. + + + + + +explicit + +This constructor creates a `device_vector` with the given size. + +```cpp +thrust::device_vector< T, Alloc >::device_vector(size_type n, const Alloc &alloc) +``` + +#### Parameters + + +The number of elements to initially create. + + + +The allocator to use by this `device_vector`. + + + + + +### Fill constructors + + + + +explicit + +This constructor creates a `device_vector` with copies of an exemplar element. + +```cpp +thrust::device_vector< T, Alloc >::device_vector(size_type n, const value_type &value) +``` + +#### Parameters + + +The number of elements to initially create. + + + +An element to copy. + + + + + +explicit + +This constructor creates a `device_vector` with copies of an exemplar element. + +```cpp +thrust::device_vector< T, Alloc >::device_vector(size_type n, const value_type &value, const Alloc &alloc) +``` + +#### Parameters + + +The number of elements to initially create. + + + +An element to copy. + + + +The allocator to use by this `device_vector`. + + + + + +### Copy constructors + + + + +Copy constructor copies from an exemplar `device_vector`. + +```cpp +thrust::device_vector< T, Alloc >::device_vector(const device_vector &v) +``` + +#### Parameters + + +The `device_vector` to copy. + + + + + +Copy constructor copies from an exemplar `device_vector`. + +```cpp +thrust::device_vector< T, Alloc >::device_vector(const device_vector &v, const Alloc &alloc) +``` + +#### Parameters + + +The `device_vector` to copy. + + + +The allocator to use by this `device_vector`. + + + + + +explicit + +Copy constructor copies from an exemplar `device_vector` with different type. + +```cpp +template +thrust::device_vector< T, Alloc >::device_vector(const device_vector< OtherT, OtherAlloc > &v) +``` + +#### Parameters + + +The `device_vector` to copy. + + + + + +Copy constructor copies from an exemplar `std::vector`. + +```cpp +template +thrust::device_vector< T, Alloc >::device_vector(const std::vector< OtherT, OtherAlloc > &v) +``` + +#### Parameters + + +The `std::vector` to copy. + + + + + +Copy construct from a `vector_base` whose element type is convertible to `T`. + +```cpp +template +thrust::device_vector< T, Alloc >::device_vector(const detail::vector_base< OtherT, OtherAlloc > &v) +``` + +#### Parameters + + +The `vector_base` to copy. + + + + + +### Move constructors + + + + +Move constructor moves from another `device_vector`. + +```cpp +thrust::device_vector< T, Alloc >::device_vector(device_vector &&v) +``` + +#### Parameters + + +The `device_vector` to move. + + + + + +Move constructor moves from another `device_vector`. + +```cpp +thrust::device_vector< T, Alloc >::device_vector(device_vector &&v, const Alloc &alloc) +``` + +#### Parameters + + +The `device_vector` to move. + + + +The allocator to use by this `device_vector`. + + + + + +### Initializer list constructors + + + + +This constructor builds a `device_vector` from an initializer_list. + +```cpp +thrust::device_vector< T, Alloc >::device_vector(::cuda::std::initializer_list< T > il) +``` + +#### Parameters + + +The initializer_list. + + + + + +This constructor builds a `device_vector` from an initializer_list. + +```cpp +thrust::device_vector< T, Alloc >::device_vector(::cuda::std::initializer_list< T > il, const Alloc &alloc) +``` + +#### Parameters + + +The initializer_list. + + + +The allocator to use by this `device_vector`. + + + + + +### Range constructors + + + + +This constructor builds a `device_vector` from a range. + +```cpp +template +thrust::device_vector< T, Alloc >::device_vector(InputIterator first, InputIterator last) +``` + +#### Parameters + + +The beginning of the range. + + + +The end of the range. + + + + + +This constructor builds a `device_vector` from a range. + +```cpp +template +thrust::device_vector< T, Alloc >::device_vector(InputIterator first, InputIterator last, const Alloc &alloc) +``` + +#### Parameters + + +The beginning of the range. + + + +The end of the range. + + + +The allocator to use by this `device_vector`. + + + + + +### Destructor + +#### ~device_vector inline + +The destructor erases the elements. + +```cpp +thrust::device_vector< T, Alloc >::~device_vector() +``` + +--- + +## Assignment operators + +### operator= inline + + + + +Copy assign operator copies another `device_vector` with the same type. + +```cpp +device_vector & thrust::device_vector< T, Alloc >::operator=(const device_vector &v) +``` + +**Returns:** `device_vector &` + +#### Parameters + + +The `device_vector` to copy. + + + + + +Move assign operator moves from another `device_vector`. + +```cpp +device_vector & thrust::device_vector< T, Alloc >::operator=(device_vector &&v) +``` + +**Returns:** `device_vector &` + +#### Parameters + + +The `device_vector` to move. + + + + + +Assign operator copies from an exemplar `device_vector` with different type. + +```cpp +template +device_vector & thrust::device_vector< T, Alloc >::operator=(const device_vector< OtherT, OtherAlloc > &v) +``` + +**Returns:** `device_vector &` + +#### Parameters + + +The `device_vector` to copy. + + + + + +Assign operator copies from an exemplar `std::vector`. + +```cpp +template +device_vector & thrust::device_vector< T, Alloc >::operator=(const std::vector< OtherT, OtherAlloc > &v) +``` + +**Returns:** `device_vector &` + +#### Parameters + + +The `std::vector` to copy. + + + + + +Assign a `vector_base` whose element type is convertible to `T`. + +```cpp +template +device_vector & thrust::device_vector< T, Alloc >::operator=(const detail::vector_base< OtherT, OtherAlloc > &v) +``` + +**Returns:** `device_vector &` + +#### Parameters + + +The `vector_base` to copy. + + + + + +Assign an `initializer_list` with a matching element type. + +```cpp +device_vector & thrust::device_vector< T, Alloc >::operator=(::cuda::std::initializer_list< T > il) +``` + +**Returns:** `device_vector &` + +#### Parameters + + +The initializer_list. + + + + + +--- + +## Element access + +### operator[] + + + + +Subscript access to the data contained in this vector. + +```cpp +reference thrust::device_vector< T, Alloc >::operator[](size_type n) +``` + +**Returns:** Read/write reference to data. + +#### Parameters + + +The index of the element for which data should be accessed. + + + + + +const + +Subscript read access to the data contained in this vector. + +```cpp +const_reference thrust::device_vector< T, Alloc >::operator[](size_type n) const +``` + +**Returns:** Read reference to data. + +#### Parameters + + +The index of the element for which data should be accessed. + + + + + +### front + + + + +This method returns a reference pointing to the first element of this vector. + +```cpp +reference thrust::device_vector< T, Alloc >::front() +``` + +**Returns:** The first element of this vector. + + + + +const + +This method returns a `const_reference` referring to the first element of this vector. + +```cpp +const_reference thrust::device_vector< T, Alloc >::front() const +``` + +**Returns:** The first element of this vector. + + + + +### back + + + + +This method returns a reference referring to the last element of this vector. + +```cpp +reference thrust::device_vector< T, Alloc >::back() +``` + +**Returns:** The last element of this vector. + + + + +const + +This method returns a const reference pointing to the last element of this vector. + +```cpp +const_reference thrust::device_vector< T, Alloc >::back() const +``` + +**Returns:** The last element of this vector. + + + + +### data + + + + +This method returns a pointer to this vector's first element. + +```cpp +pointer thrust::device_vector< T, Alloc >::data() +``` + +**Returns:** A pointer to the first element of this vector. + + + + +const + +This method returns a `const_pointer` to this vector's first element. + +```cpp +const_pointer thrust::device_vector< T, Alloc >::data() const +``` + +**Returns:** A `const_pointer` to the first element of this vector. + + + + +--- + +## Iterators + +### begin + + + + +This method returns an iterator pointing to the beginning of this vector. + +```cpp +iterator thrust::device_vector< T, Alloc >::begin() +``` + +**Returns:** `iterator` to the beginning. + + + + +const + +This method returns a `const_iterator` pointing to the beginning of this vector. + +```cpp +const_iterator thrust::device_vector< T, Alloc >::begin() const +``` + +**Returns:** `const_iterator` to the beginning. + + + + +### cbegin const + +This method returns a `const_iterator` pointing to the beginning of this vector. + +```cpp +const_iterator thrust::device_vector< T, Alloc >::cbegin() const +``` + +**Returns:** `const_iterator` to the beginning. + +### end + + + + +This method returns an iterator pointing to one element past the last of this vector. + +```cpp +iterator thrust::device_vector< T, Alloc >::end() +``` + +**Returns:** `iterator` past the end. + + + + +const + +This method returns a `const_iterator` pointing to one element past the last of this vector. + +```cpp +const_iterator thrust::device_vector< T, Alloc >::end() const +``` + +**Returns:** `const_iterator` past the end. + + + + +### cend const + +This method returns a `const_iterator` pointing to one element past the last of this vector. + +```cpp +const_iterator thrust::device_vector< T, Alloc >::cend() const +``` + +**Returns:** `const_iterator` past the end. + +### rbegin + + + + +This method returns a `reverse_iterator` pointing to the beginning of this vector's reversed sequence. + +```cpp +reverse_iterator thrust::device_vector< T, Alloc >::rbegin() +``` + +**Returns:** A `reverse_iterator` pointing to the beginning of this vector's reversed sequence. + + + + +const + +This method returns a `const_reverse_iterator` pointing to the beginning of this vector's reversed sequence. + +```cpp +const_reverse_iterator thrust::device_vector< T, Alloc >::rbegin() const +``` + +**Returns:** A `const_reverse_iterator` pointing to the beginning of this vector's reversed sequence. + + + + +### crbegin const + +This method returns a `const_reverse_iterator` pointing to the beginning of this vector's reversed sequence. + +```cpp +const_reverse_iterator thrust::device_vector< T, Alloc >::crbegin() const +``` + +**Returns:** A `const_reverse_iterator` pointing to the beginning of this vector's reversed sequence. + +### rend + + + + +This method returns a `reverse_iterator` pointing to one element past the last of this vector's reversed sequence. + +```cpp +reverse_iterator thrust::device_vector< T, Alloc >::rend() +``` + +**Returns:** A `reverse_iterator` past the end of the reversed sequence. + + + + +const + +This method returns a `const_reverse_iterator` pointing to one element past the last of this vector's reversed sequence. + +```cpp +const_reverse_iterator thrust::device_vector< T, Alloc >::rend() const +``` + +**Returns:** A `const_reverse_iterator` past the end of the reversed sequence. + + + + +### crend const + +This method returns a `const_reverse_iterator` pointing to one element past the last of this vector's reversed sequence. + +```cpp +const_reverse_iterator thrust::device_vector< T, Alloc >::crend() const +``` + +**Returns:** A `const_reverse_iterator` past the end of the reversed sequence. + +--- + +## Capacity + +### size const + +Returns the number of elements in this vector. + +```cpp +size_type thrust::device_vector< T, Alloc >::size() const +``` + +**Returns:** `size_type` -- the number of elements. + +### max_size const + +Returns the [size()](/library/api/thrust::device_vector::size) of the largest possible vector. + +```cpp +size_type thrust::device_vector< T, Alloc >::max_size() const +``` + +**Returns:** The largest possible return value of [size()](/library/api/thrust::device_vector::size). + +### capacity const + +Returns the number of elements which have been reserved in this vector. + +```cpp +size_type thrust::device_vector< T, Alloc >::capacity() const +``` + +**Returns:** `size_type` -- the number of elements reserved. + +### empty const + +This method returns true iff [size()](/library/api/thrust::device_vector::size) == 0. + +```cpp +bool thrust::device_vector< T, Alloc >::empty() const +``` + +**Returns:** `true` if [size()](/library/api/thrust::device_vector::size) == 0; `false`, otherwise. + +### reserve + +If `n` is less than or equal to [capacity()](/library/api/thrust::device_vector::capacity), this call has no effect. Otherwise, this method is a request for allocation of additional memory. If the request is successful, then [capacity()](/library/api/thrust::device_vector::capacity) is greater than or equal to `n`; otherwise, [capacity()](/library/api/thrust::device_vector::capacity) is unchanged. In either case, [size()](/library/api/thrust::device_vector::size) is unchanged. + +```cpp +void thrust::device_vector< T, Alloc >::reserve(size_type n) +``` + +**Throws:** `std::length_error` if `n` exceeds [max_size()](/library/api/thrust::device_vector::max_size). + +### shrink_to_fit + +This method shrinks the capacity of this vector to exactly fit its elements. + +```cpp +void thrust::device_vector< T, Alloc >::shrink_to_fit() +``` + +### resize + + + + +Resizes this vector to the specified number of elements. + +```cpp +void thrust::device_vector< T, Alloc >::resize(size_type new_size, const value_type &x=value_type()) +``` + +**Throws:** `std::length_error` if `n` exceeds [max_size()](/library/api/thrust::device_vector::max_size). + +#### Parameters + + +Number of elements this vector should contain. + + + +Data with which new elements should be populated. + + + + + +Resizes this vector to the specified number of elements, performing default-initialization instead of value-initialization. + +```cpp +void thrust::device_vector< T, Alloc >::resize(size_type new_size, default_init_t) +``` + +**Throws:** `std::length_error` if `n` exceeds [max_size()](/library/api/thrust::device_vector::max_size). + +#### Parameters + + +Number of elements this vector should contain. + + + + + +Resizes this vector to the specified number of elements, without initializing elements. + +```cpp +void thrust::device_vector< T, Alloc >::resize(size_type new_size, no_init_t) +``` + +**Throws:** `std::length_error` if `n` exceeds [max_size()](/library/api/thrust::device_vector::max_size). + +#### Parameters + + +Number of elements this vector should contain. + + + + + +--- + +## Modifiers + +### push_back + +This method appends the given element to the end of this vector. + +```cpp +void thrust::device_vector< T, Alloc >::push_back(const value_type &x) +``` + +#### Parameters + + +The element to append. + + +### pop_back + +This method erases the last element of this vector, invalidating all iterators and references to it. + +```cpp +void thrust::device_vector< T, Alloc >::pop_back() +``` + +### clear + +This method resizes this vector to 0. + +```cpp +void thrust::device_vector< T, Alloc >::clear() +``` + +### swap + +This method swaps the contents of this `device_vector` with another vector. + +```cpp +void thrust::device_vector< T, Alloc >::swap(device_vector &v) +``` + +#### Parameters + + +The vector with which to swap. + + +### insert + + + + +This method inserts a single copy of a given exemplar value at the specified position in this vector. + +```cpp +iterator thrust::device_vector< T, Alloc >::insert(iterator position, const T &x) +``` + +**Returns:** An iterator pointing to the newly inserted element. + +#### Parameters + + +The insertion position. + + + +The exemplar element to copy & insert. + + + + + +This method inserts a copy of an exemplar value to a range at the specified position in this vector. + +```cpp +void thrust::device_vector< T, Alloc >::insert(iterator position, size_type n, const T &x) +``` + +#### Parameters + + +The insertion position. + + + +The number of insertions to perform. + + + +The value to replicate and insert. + + + + + +This method inserts a copy of an input range at the specified position in this vector. + +```cpp +template +void thrust::device_vector< T, Alloc >::insert(iterator position, InputIterator first, InputIterator last) +``` + +#### Template parameters + + +A model of [Input Iterator](https://en.cppreference.com/w/cpp/iterator/input_iterator) whose `value_type` is a model of [Assignable](https://en.cppreference.com/w/cpp/named_req/CopyAssignable). + + +#### Parameters + + +The insertion position. + + + +The beginning of the range to copy. + + + +The end of the range to copy. + + + + + +### erase + + + + +This method removes the element at position pos. + +```cpp +iterator thrust::device_vector< T, Alloc >::erase(iterator pos) +``` + +**Returns:** An iterator pointing to the new location of the element that followed the element at position pos. + +#### Parameters + + +The position of the element of interest. + + + + + +This method removes the range of elements [first,last) from this vector. + +```cpp +iterator thrust::device_vector< T, Alloc >::erase(iterator first, iterator last) +``` + +**Returns:** An iterator pointing to the new location of the element that followed the last element in the sequence [first,last). + +#### Parameters + + +The beginning of the range of elements to remove. + + + +The end of the range of elements to remove. + + + + + +### assign + + + + +This version of `assign` replicates a given exemplar `n` times into this vector. + +```cpp +void thrust::device_vector< T, Alloc >::assign(size_type n, const T &x) +``` + +#### Parameters + + +The number of times to copy `x`. + + + +The exemplar element to replicate. + + + + + +This version of `assign` makes this vector a copy of a given input range. + +```cpp +template +void thrust::device_vector< T, Alloc >::assign(InputIterator first, InputIterator last) +``` + +#### Template parameters + + +A model of [Input Iterator](https://en.cppreference.com/w/cpp/iterator/input_iterator). + + +#### Parameters + + +The beginning of the range to copy. + + + +The end of the range to copy. + + + + + +--- + +## Allocator + +### get_allocator const + +This method returns a copy of this vector's allocator. + +```cpp +allocator_type thrust::device_vector< T, Alloc >::get_allocator() const +``` + +**Returns:** A copy of the allocator used by this vector. + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `Parent` | `detail::vector_base< T, Alloc >` | diff --git a/fern/pages/thrust/device_vector_v3.mdx b/fern/pages/thrust/device_vector_v3.mdx new file mode 100644 index 0000000..cc89a73 --- /dev/null +++ b/fern/pages/thrust/device_vector_v3.mdx @@ -0,0 +1,1424 @@ +--- +title: thrust::device_vector +description: "A dynamically-sized container for device memory with automatic memory management." +--- + +A `device_vector` is a container that supports random access to elements, constant time removal of elements at the end, and linear time insertion and removal of elements at the beginning or in the middle. + +The number of elements in a `device_vector` may vary dynamically; memory management is automatic. The memory associated with a `device_vector` resides in the memory accessible to devices. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +[https://en.cppreference.com/w/cpp/container/vector](https://en.cppreference.com/w/cpp/container/vector), +[device_allocator](/library/api/thrust::device_allocator), +[host_vector](/library/api/thrust::host_vector), +[universal_vector](/library/api/thrust::universal_vector) + + + + + +The element type of the vector. + + + +**[optional]** The allocator type used for memory management (default: [thrust::device_allocator](/library/api/thrust::device_allocator)``). + + + + + +**Inherits from:** `detail::vector_base< T, thrust::device_allocator< T > >` (public) + +--- + +## Constructors + +### Default and allocator constructors + + + + +This constructor creates an empty `device_vector`. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector() +``` + + + + + +This constructor creates an empty `device_vector`. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + const Alloc &alloc +) +``` + + +**Parameters** + + +The allocator to use by this `device_vector`. + + + + + +### Size constructors + + + + +explicit + +This constructor creates a `device_vector` with the given size. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + size_type n +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + + + +This constructor creates a `device_vector` with the given size, performing only default-initialization instead of value-initialization. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + size_type n, + default_init_t +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + + + +This constructor creates a `device_vector` with the given size, without initializing elements. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + size_type n, + no_init_t +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + + + +explicit + +This constructor creates a `device_vector` with the given size. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + size_type n, + const Alloc &alloc +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + +The allocator to use by this `device_vector`. + + + + + +### Fill constructors + + + + +explicit + +This constructor creates a `device_vector` with copies of an exemplar element. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + size_type n, + const value_type &value +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + +An element to copy. + + + + + +explicit + +This constructor creates a `device_vector` with copies of an exemplar element. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + size_type n, + const value_type &value, + const Alloc &alloc +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + +An element to copy. + + + +The allocator to use by this `device_vector`. + + + + + +### Copy constructors + + + + +Copy constructor copies from an exemplar `device_vector`. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + const device_vector &v +) +``` + + +**Parameters** + + +The `device_vector` to copy. + + + + + +Copy constructor copies from an exemplar `device_vector`. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + const device_vector &v, + const Alloc &alloc +) +``` + + +**Parameters** + + +The `device_vector` to copy. + + + +The allocator to use by this `device_vector`. + + + + + +explicit + +Copy constructor copies from an exemplar `device_vector` with different type. + + +```cpp showLineNumbers={false} +template +thrust::device_vector::device_vector( + const device_vector &v +) +``` + + +**Parameters** + + +The `device_vector` to copy. + + + + + +Copy constructor copies from an exemplar `std::vector`. + + +```cpp showLineNumbers={false} +template +thrust::device_vector::device_vector( + const std::vector &v +) +``` + + +**Parameters** + + +The `std::vector` to copy. + + + + + +Copy construct from a `vector_base` whose element type is convertible to `T`. + + +```cpp showLineNumbers={false} +template +thrust::device_vector::device_vector( + const detail::vector_base &v +) +``` + + +**Parameters** + + +The `vector_base` to copy. + + + + + +### Move constructors + + + + +Move constructor moves from another `device_vector`. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + device_vector &&v +) +``` + + +**Parameters** + + +The `device_vector` to move. + + + + + +Move constructor moves from another `device_vector`. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + device_vector &&v, + const Alloc &alloc +) +``` + + +**Parameters** + + +The `device_vector` to move. + + + +The allocator to use by this `device_vector`. + + + + + +### Initializer list constructors + + + + +This constructor builds a `device_vector` from an initializer_list. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + ::cuda::std::initializer_list il +) +``` + + +**Parameters** + + +The initializer_list. + + + + + +This constructor builds a `device_vector` from an initializer_list. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + ::cuda::std::initializer_list il, + const Alloc &alloc +) +``` + + +**Parameters** + + +The initializer_list. + + + +The allocator to use by this `device_vector`. + + + + + +### Range constructors + + + + +This constructor builds a `device_vector` from a range. + + +```cpp showLineNumbers={false} +template +thrust::device_vector::device_vector( + InputIterator first, + InputIterator last +) +``` + + +**Parameters** + + +The beginning of the range. + + + +The end of the range. + + + + + +This constructor builds a `device_vector` from a range. + + +```cpp showLineNumbers={false} +template +thrust::device_vector::device_vector( + InputIterator first, + InputIterator last, + const Alloc &alloc +) +``` + + +**Parameters** + + +The beginning of the range. + + + +The end of the range. + + + +The allocator to use by this `device_vector`. + + + + + +### Destructor + +### ~device_vector inline + +The destructor erases the elements. + + +```cpp showLineNumbers={false} +thrust::device_vector::~device_vector() +``` + + +--- + +## Assignment operators + +### operator= inline + + + + +Copy assign operator copies another `device_vector` with the same type. + + +```cpp showLineNumbers={false} +device_vector& thrust::device_vector::operator=( + const device_vector &v +) +``` + + +**Returns:** `device_vector &` + +**Parameters** + + +The `device_vector` to copy. + + + + + +Move assign operator moves from another `device_vector`. + + +```cpp showLineNumbers={false} +device_vector& thrust::device_vector::operator=( + device_vector &&v +) +``` + + +**Returns:** `device_vector &` + +**Parameters** + + +The `device_vector` to move. + + + + + +Assign operator copies from an exemplar `device_vector` with different type. + + +```cpp showLineNumbers={false} +template +device_vector& thrust::device_vector::operator=( + const device_vector &v +) +``` + + +**Returns:** `device_vector &` + +**Parameters** + + +The `device_vector` to copy. + + + + + +Assign operator copies from an exemplar `std::vector`. + + +```cpp showLineNumbers={false} +template +device_vector& thrust::device_vector::operator=( + const std::vector &v +) +``` + + +**Returns:** `device_vector &` + +**Parameters** + + +The `std::vector` to copy. + + + + + +Assign a `vector_base` whose element type is convertible to `T`. + + +```cpp showLineNumbers={false} +template +device_vector& thrust::device_vector::operator=( + const detail::vector_base &v +) +``` + + +**Returns:** `device_vector &` + +**Parameters** + + +The `vector_base` to copy. + + + + + +Assign an `initializer_list` with a matching element type. + + +```cpp showLineNumbers={false} +device_vector& thrust::device_vector::operator=( + ::cuda::std::initializer_list il +) +``` + + +**Returns:** `device_vector &` + +**Parameters** + + +The initializer_list. + + + + + +--- + +## Element access + +### operator[] + + + + +Subscript access to the data contained in this vector. + + +```cpp showLineNumbers={false} +reference thrust::device_vector::operator[]( + size_type n +) +``` + + +**Returns:** Read/write reference to data. + +**Parameters** + + +The index of the element for which data should be accessed. + + + + + +const + +Subscript read access to the data contained in this vector. + + +```cpp showLineNumbers={false} +const_reference thrust::device_vector::operator[]( + size_type n +) const +``` + + +**Returns:** Read reference to data. + +**Parameters** + + +The index of the element for which data should be accessed. + + + + + +### front + + + + +This method returns a reference pointing to the first element of this vector. + + +```cpp showLineNumbers={false} +reference thrust::device_vector::front() +``` + + +**Returns:** The first element of this vector. + + + + +const + +This method returns a `const_reference` referring to the first element of this vector. + + +```cpp showLineNumbers={false} +const_reference thrust::device_vector::front() const +``` + + +**Returns:** The first element of this vector. + + + + +### back + + + + +This method returns a reference referring to the last element of this vector. + + +```cpp showLineNumbers={false} +reference thrust::device_vector::back() +``` + + +**Returns:** The last element of this vector. + + + + +const + +This method returns a const reference pointing to the last element of this vector. + + +```cpp showLineNumbers={false} +const_reference thrust::device_vector::back() const +``` + + +**Returns:** The last element of this vector. + + + + +### data + + + + +This method returns a pointer to this vector's first element. + + +```cpp showLineNumbers={false} +pointer thrust::device_vector::data() +``` + + +**Returns:** A pointer to the first element of this vector. + + + + +const + +This method returns a `const_pointer` to this vector's first element. + + +```cpp showLineNumbers={false} +const_pointer thrust::device_vector::data() const +``` + + +**Returns:** A `const_pointer` to the first element of this vector. + + + + +--- + +## Iterators + +### begin + + + + +This method returns an iterator pointing to the beginning of this vector. + + +```cpp showLineNumbers={false} +iterator thrust::device_vector::begin() +``` + + +**Returns:** `iterator` to the beginning. + + + + +const + +This method returns a `const_iterator` pointing to the beginning of this vector. + + +```cpp showLineNumbers={false} +const_iterator thrust::device_vector::begin() const +``` + + +**Returns:** `const_iterator` to the beginning. + + + + +### cbegin const + +This method returns a `const_iterator` pointing to the beginning of this vector. + + +```cpp showLineNumbers={false} +const_iterator thrust::device_vector::cbegin() const +``` + + +**Returns:** `const_iterator` to the beginning. + +### end + + + + +This method returns an iterator pointing to one element past the last of this vector. + + +```cpp showLineNumbers={false} +iterator thrust::device_vector::end() +``` + + +**Returns:** `iterator` past the end. + + + + +const + +This method returns a `const_iterator` pointing to one element past the last of this vector. + + +```cpp showLineNumbers={false} +const_iterator thrust::device_vector::end() const +``` + + +**Returns:** `const_iterator` past the end. + + + + +### cend const + +This method returns a `const_iterator` pointing to one element past the last of this vector. + + +```cpp showLineNumbers={false} +const_iterator thrust::device_vector::cend() const +``` + + +**Returns:** `const_iterator` past the end. + +### rbegin + + + + +This method returns a `reverse_iterator` pointing to the beginning of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +reverse_iterator thrust::device_vector::rbegin() +``` + + +**Returns:** A `reverse_iterator` pointing to the beginning of this vector's reversed sequence. + + + + +const + +This method returns a `const_reverse_iterator` pointing to the beginning of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +const_reverse_iterator thrust::device_vector::rbegin() const +``` + + +**Returns:** A `const_reverse_iterator` pointing to the beginning of this vector's reversed sequence. + + + + +### crbegin const + +This method returns a `const_reverse_iterator` pointing to the beginning of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +const_reverse_iterator thrust::device_vector::crbegin() const +``` + + +**Returns:** A `const_reverse_iterator` pointing to the beginning of this vector's reversed sequence. + +### rend + + + + +This method returns a `reverse_iterator` pointing to one element past the last of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +reverse_iterator thrust::device_vector::rend() +``` + + +**Returns:** A `reverse_iterator` past the end of the reversed sequence. + + + + +const + +This method returns a `const_reverse_iterator` pointing to one element past the last of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +const_reverse_iterator thrust::device_vector::rend() const +``` + + +**Returns:** A `const_reverse_iterator` past the end of the reversed sequence. + + + + +### crend const + +This method returns a `const_reverse_iterator` pointing to one element past the last of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +const_reverse_iterator thrust::device_vector::crend() const +``` + + +**Returns:** A `const_reverse_iterator` past the end of the reversed sequence. + +--- + +## Capacity + +### size const + +Returns the number of elements in this vector. + + +```cpp showLineNumbers={false} +size_type thrust::device_vector::size() const +``` + + +**Returns:** `size_type` -- the number of elements. + +### max_size const + +Returns the [size()](/library/api/thrust::device_vector::size) of the largest possible vector. + + +```cpp showLineNumbers={false} +size_type thrust::device_vector::max_size() const +``` + + +**Returns:** The largest possible return value of [size()](/library/api/thrust::device_vector::size). + +### capacity const + +Returns the number of elements which have been reserved in this vector. + + +```cpp showLineNumbers={false} +size_type thrust::device_vector::capacity() const +``` + + +**Returns:** `size_type` -- the number of elements reserved. + +### empty const + +This method returns true iff [size()](/library/api/thrust::device_vector::size) == 0. + + +```cpp showLineNumbers={false} +bool thrust::device_vector::empty() const +``` + + +**Returns:** `true` if [size()](/library/api/thrust::device_vector::size) == 0; `false`, otherwise. + +### reserve + +If `n` is less than or equal to [capacity()](/library/api/thrust::device_vector::capacity), this call has no effect. Otherwise, this method is a request for allocation of additional memory. If the request is successful, then [capacity()](/library/api/thrust::device_vector::capacity) is greater than or equal to `n`; otherwise, [capacity()](/library/api/thrust::device_vector::capacity) is unchanged. In either case, [size()](/library/api/thrust::device_vector::size) is unchanged. + + +```cpp showLineNumbers={false} +void thrust::device_vector::reserve( + size_type n +) +``` + + +**Throws:** `std::length_error` if `n` exceeds [max_size()](/library/api/thrust::device_vector::max_size). + +### shrink_to_fit + +This method shrinks the capacity of this vector to exactly fit its elements. + + +```cpp showLineNumbers={false} +void thrust::device_vector::shrink_to_fit() +``` + + +### resize + + + + +Resizes this vector to the specified number of elements. + + +```cpp showLineNumbers={false} +void thrust::device_vector::resize( + size_type new_size, + const value_type &x = value_type() +) +``` + + +**Throws:** `std::length_error` if `n` exceeds [max_size()](/library/api/thrust::device_vector::max_size). + +**Parameters** + + +Number of elements this vector should contain. + + + +Data with which new elements should be populated. + + + + + +Resizes this vector to the specified number of elements, performing default-initialization instead of value-initialization. + + +```cpp showLineNumbers={false} +void thrust::device_vector::resize( + size_type new_size, + default_init_t +) +``` + + +**Throws:** `std::length_error` if `n` exceeds [max_size()](/library/api/thrust::device_vector::max_size). + +**Parameters** + + +Number of elements this vector should contain. + + + + + +Resizes this vector to the specified number of elements, without initializing elements. + + +```cpp showLineNumbers={false} +void thrust::device_vector::resize( + size_type new_size, + no_init_t +) +``` + + +**Throws:** `std::length_error` if `n` exceeds [max_size()](/library/api/thrust::device_vector::max_size). + +**Parameters** + + +Number of elements this vector should contain. + + + + + +--- + +## Modifiers + +### push_back + +This method appends the given element to the end of this vector. + + +```cpp showLineNumbers={false} +void thrust::device_vector::push_back( + const value_type &x +) +``` + + +#### Parameters + + +The element to append. + + +### pop_back + +This method erases the last element of this vector, invalidating all iterators and references to it. + + +```cpp showLineNumbers={false} +void thrust::device_vector::pop_back() +``` + + +### clear + +This method resizes this vector to 0. + + +```cpp showLineNumbers={false} +void thrust::device_vector::clear() +``` + + +### swap + +This method swaps the contents of this `device_vector` with another vector. + + +```cpp showLineNumbers={false} +void thrust::device_vector::swap( + device_vector &v +) +``` + + +#### Parameters + + +The vector with which to swap. + + +### insert + + + + +This method inserts a single copy of a given exemplar value at the specified position in this vector. + + +```cpp showLineNumbers={false} +iterator thrust::device_vector::insert( + iterator position, + const T &x +) +``` + + +**Returns:** An iterator pointing to the newly inserted element. + +**Parameters** + + +The insertion position. + + + +The exemplar element to copy & insert. + + + + + +This method inserts a copy of an exemplar value to a range at the specified position in this vector. + + +```cpp showLineNumbers={false} +void thrust::device_vector::insert( + iterator position, + size_type n, + const T &x +) +``` + + +**Parameters** + + +The insertion position. + + + +The number of insertions to perform. + + + +The value to replicate and insert. + + + + + +This method inserts a copy of an input range at the specified position in this vector. + + +```cpp showLineNumbers={false} +template +void thrust::device_vector::insert( + iterator position, + InputIterator first, + InputIterator last +) +``` + + +**Template parameters** + + +A model of [Input Iterator](https://en.cppreference.com/w/cpp/iterator/input_iterator) whose `value_type` is a model of [Assignable](https://en.cppreference.com/w/cpp/named_req/CopyAssignable). + + +**Parameters** + + +The insertion position. + + + +The beginning of the range to copy. + + + +The end of the range to copy. + + + + + +### erase + + + + +This method removes the element at position pos. + + +```cpp showLineNumbers={false} +iterator thrust::device_vector::erase( + iterator pos +) +``` + + +**Returns:** An iterator pointing to the new location of the element that followed the element at position pos. + +**Parameters** + + +The position of the element of interest. + + + + + +This method removes the range of elements [first,last) from this vector. + + +```cpp showLineNumbers={false} +iterator thrust::device_vector::erase( + iterator first, + iterator last +) +``` + + +**Returns:** An iterator pointing to the new location of the element that followed the last element in the sequence [first,last). + +**Parameters** + + +The beginning of the range of elements to remove. + + + +The end of the range of elements to remove. + + + + + +### assign + + + + +This version of `assign` replicates a given exemplar `n` times into this vector. + + +```cpp showLineNumbers={false} +void thrust::device_vector::assign( + size_type n, + const T &x +) +``` + + +**Parameters** + + +The number of times to copy `x`. + + + +The exemplar element to replicate. + + + + + +This version of `assign` makes this vector a copy of a given input range. + + +```cpp showLineNumbers={false} +template +void thrust::device_vector::assign( + InputIterator first, + InputIterator last +) +``` + + +**Template parameters** + + +A model of [Input Iterator](https://en.cppreference.com/w/cpp/iterator/input_iterator). + + +**Parameters** + + +The beginning of the range to copy. + + + +The end of the range to copy. + + + + + +--- + +## Allocator + +### get_allocator const + +This method returns a copy of this vector's allocator. + + +```cpp showLineNumbers={false} +allocator_type thrust::device_vector::get_allocator() const +``` + + +**Returns:** A copy of the allocator used by this vector. + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `Parent` | `detail::vector_base< T, Alloc >` | \ No newline at end of file diff --git a/fern/pages/thrust/group_member_example.mdx b/fern/pages/thrust/group_member_example.mdx new file mode 100644 index 0000000..bb46084 --- /dev/null +++ b/fern/pages/thrust/group_member_example.mdx @@ -0,0 +1,408 @@ +--- +title: thrust::mr::disjoint_unsynchronized_pool_resource +description: "A memory resource adaptor that pools and caches allocations using a separate bookkeeper for managing memory inaccessible from the host." +--- + +A memory resource adaptor allowing for pooling and caching allocations from `Upstream`, using `Bookkeeper` for management of that cached and pooled memory, allowing to cache portions of memory inaccessible from the host. + +On a typical memory resource, calls to `allocate` and `deallocate` actually allocate and deallocate memory. Pooling memory resources only allocate and deallocate memory from an external resource (the upstream memory resource) when there's no suitable memory currently cached; otherwise, they use memory they have acquired beforehand, to make memory allocation faster and more efficient. + +The disjoint version of the pool resources uses a separate upstream memory resource, `Bookkeeper`, to allocate memory necessary to manage the cached memory. There may be many reasons to do that; the canonical one is that `Upstream` allocates memory that is inaccessible to the code of the pool resource, which means that it cannot embed the necessary information in memory obtained from `Upstream`; for instance, `Upstream` can be a CUDA non-managed memory resource, or a CUDA managed memory resource whose memory we would prefer to not migrate back and forth between host and device when executing bookkeeping code. + +This is not the only case where it makes sense to use a disjoint pool resource, though. In a multi-core environment it may be beneficial to avoid stealing cache lines from other cores by writing over bookkeeping information embedded in an allocated block of memory. In such a case, one can imagine wanting to use a disjoint pool where both the upstream and the bookkeeper are of the same type, to allocate memory consistently, but separately for those two purposes. + +```cpp showLineNumbers={false} +#include +``` + + + + + +The type of memory resources that will be used for allocating memory blocks to be handed off to the user. + + + +The type of memory resources that will be used for allocating bookkeeping memory. + + + + + +**Inherits from:** [`thrust::mr::memory_resource< Upstream::pointer >`](/library/api/thrust::mr::memory_resource) (public), [`thrust::mr::validator2< Upstream, Bookkeeper >`](/library/api/thrust::mr::validator2) (private) + +This class is marked final. + +--- + +## Constructors + +### disjoint_unsynchronized_pool_resource inline + + + + +Constructor. + + +```cpp showLineNumbers={false} +thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::disjoint_unsynchronized_pool_resource( + Upstream *upstream, + Bookkeeper *bookkeeper, + pool_options options = get_default_options() +) +``` + + +**Parameters** + + +The upstream memory resource for allocations. + + + +The upstream memory resource for bookkeeping. + + + +Pool options to use. + + + + + +Constructor. Upstream and bookkeeping resources are obtained by calling `get_global_resource` for their types. + + +```cpp showLineNumbers={false} +thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::disjoint_unsynchronized_pool_resource( + pool_options options = get_default_options() +) +``` + + +**Parameters** + + +Pool options to use. + + + + + +### Destructor + +### ~disjoint_unsynchronized_pool_resource inline + +Destructor. Releases all held memory to upstream. + + +```cpp showLineNumbers={false} +thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::~disjoint_unsynchronized_pool_resource() +``` + + +--- + +## Pool management + +### release inline + +Releases all held memory to upstream. + + +```cpp showLineNumbers={false} +void thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::release() +``` + + +### squeeze inline + + +```cpp showLineNumbers={false} +void thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::squeeze() +``` + + +--- + +## Allocation + +### do_allocate inline nodiscard virtual + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +virtual void_ptr thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::do_allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) override +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Parameters** + + +Size, in bytes, that is requested from this allocation. + + + +Alignment that is requested from this allocation. + + +### do_allocate_impl inline nodiscard + + +```cpp showLineNumbers={false} +void_ptr thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::do_allocate_impl( + std::size_t bytes, + std::size_t alignment +) +``` + + +### do_deallocate inline virtual + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +virtual void thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::do_deallocate( + void_ptr p, + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) override +``` + + +**Parameters** + + +Pointer to be deallocated. + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The size of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +### allocate inline nodiscard + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +pointer thrust::mr::memory_resource< Upstream::pointer >::allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Parameters** + + +Size, in bytes, that is requested from this allocation. + + + +Alignment that is requested from this allocation. + + +### deallocate inline noexcept + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +void thrust::mr::memory_resource< Upstream::pointer >::deallocate( + pointer p, + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) noexcept +``` + + +**Parameters** + + +Pointer to be deallocated. + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The alignment of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +--- + +## Comparison + +### is_equal inline const noexcept + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +bool thrust::mr::memory_resource< Upstream::pointer >::is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** Whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to. + + +### do_is_equal inline const noexcept virtual + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +virtual bool thrust::mr::memory_resource< Upstream::pointer >::do_is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** Whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to. + + +--- + +## Static methods + +### get_default_options inline static + +Get the default options for a disjoint pool. + + +```cpp showLineNumbers={false} +static pool_options thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::get_default_options() +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `void_ptr` | `typename Upstream::pointer` | | +| `char_ptr` | `typename thrust::detail::pointer_traits< void_ptr >::template rebind< char >::other` | | +| `chunk_vector` | `thrust::host_vector< chunk_descriptor, allocator< chunk_descriptor, Bookkeeper > >` | | +| `oversized_block_vector` | `thrust::host_vector< oversized_block_descriptor, allocator< oversized_block_descriptor, Bookkeeper > >` | | +| `pointer_vector` | `thrust::host_vector< void_ptr, allocator< void_ptr, Bookkeeper > >` | | +| `pool_vector` | `thrust::host_vector< pool, allocator< pool, Bookkeeper > >` | | +| `pointer` | `Upstream::pointer` | Alias for the template parameter. | + +--- + +## Member variables + +| Name | Type | +|---|---| +| `m_upstream` | `Upstream *` | +| `m_bookkeeper` | `Bookkeeper *` | +| `m_options` | [`pool_options`](/library/api/thrust::mr::pool_options) | +| `m_smallest_block_log2` | `std::size_t` | +| `m_pools` | `pool_vector` | +| `m_allocated` | `chunk_vector` | +| `m_cached_oversized` | `oversized_block_vector` | +| `m_oversized` | `oversized_block_vector` | + +--- + +## Inner classes + +### chunk_descriptor + + +```cpp showLineNumbers={false} +struct thrust::mr::disjoint_unsynchronized_pool_resource::chunk_descriptor +``` + + +| Member | Type | +|---|---| +| `size` | `std::size_t` | +| `pointer` | `void_ptr` | +| `pool_idx` | `std::size_t` | + +### oversized_block_descriptor + + +```cpp showLineNumbers={false} +struct thrust::mr::disjoint_unsynchronized_pool_resource::oversized_block_descriptor +``` + + +| Member | Type | +|---|---| +| `size` | `std::size_t` | +| `alignment` | `std::size_t` | +| `pointer` | `void_ptr` | + +### equal_pointers + + +```cpp showLineNumbers={false} +struct thrust::mr::disjoint_unsynchronized_pool_resource::equal_pointers +``` + + +| Member | Type | +|---|---| +| `p` | `void_ptr` | + +### matching_alignment + + +```cpp showLineNumbers={false} +struct thrust::mr::disjoint_unsynchronized_pool_resource::matching_alignment +``` + + +| Member | Type | +|---|---| +| `requested` | `std::size_t` | + +### pool + + +```cpp showLineNumbers={false} +struct thrust::mr::disjoint_unsynchronized_pool_resource::pool +``` + + +| Member | Type | +|---|---| +| `free_blocks` | `pointer_vector` | +| `previous_allocated_count` | `std::size_t` | diff --git a/fern/pages/thrust/group_member_example_v4.mdx b/fern/pages/thrust/group_member_example_v4.mdx new file mode 100644 index 0000000..df1192b --- /dev/null +++ b/fern/pages/thrust/group_member_example_v4.mdx @@ -0,0 +1,408 @@ +--- +title: thrust::mr::disjoint_unsynchronized_pool_resource +description: "A memory resource adaptor that pools and caches allocations using a separate bookkeeper for managing memory inaccessible from the host." +--- + +A memory resource adaptor allowing for pooling and caching allocations from `Upstream`, using `Bookkeeper` for management of that cached and pooled memory, allowing to cache portions of memory inaccessible from the host. + +On a typical memory resource, calls to `allocate` and `deallocate` actually allocate and deallocate memory. Pooling memory resources only allocate and deallocate memory from an external resource (the upstream memory resource) when there's no suitable memory currently cached; otherwise, they use memory they have acquired beforehand, to make memory allocation faster and more efficient. + +The disjoint version of the pool resources uses a separate upstream memory resource, `Bookkeeper`, to allocate memory necessary to manage the cached memory. There may be many reasons to do that; the canonical one is that `Upstream` allocates memory that is inaccessible to the code of the pool resource, which means that it cannot embed the necessary information in memory obtained from `Upstream`; for instance, `Upstream` can be a CUDA non-managed memory resource, or a CUDA managed memory resource whose memory we would prefer to not migrate back and forth between host and device when executing bookkeeping code. + +This is not the only case where it makes sense to use a disjoint pool resource, though. In a multi-core environment it may be beneficial to avoid stealing cache lines from other cores by writing over bookkeeping information embedded in an allocated block of memory. In such a case, one can imagine wanting to use a disjoint pool where both the upstream and the bookkeeper are of the same type, to allocate memory consistently, but separately for those two purposes. + +```cpp showLineNumbers={false} +#include +``` + + + + + +The type of memory resources that will be used for allocating memory blocks to be handed off to the user. + + + +The type of memory resources that will be used for allocating bookkeeping memory. + + + + + +**Inherits from:** [`thrust::mr::memory_resource< Upstream::pointer >`](/library/api/thrust::mr::memory_resource) (public), [`thrust::mr::validator2< Upstream, Bookkeeper >`](/library/api/thrust::mr::validator2) (private) + +This class is marked final. + +--- + +## Constructors + +### disjoint_unsynchronized_pool_resource inline + + + + +Constructor. + + +```cpp showLineNumbers={false} +thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::disjoint_unsynchronized_pool_resource( + Upstream *upstream, + Bookkeeper *bookkeeper, + pool_options options = get_default_options() +) +``` + + +**Parameters** + + +The upstream memory resource for allocations. + + + +The upstream memory resource for bookkeeping. + + + +Pool options to use. + + + + + +Constructor. Upstream and bookkeeping resources are obtained by calling `get_global_resource` for their types. + + +```cpp showLineNumbers={false} +thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::disjoint_unsynchronized_pool_resource( + pool_options options = get_default_options() +) +``` + + +**Parameters** + + +Pool options to use. + + + + + +### Destructor + +### ~disjoint_unsynchronized_pool_resource inline + +Destructor. Releases all held memory to upstream. + + +```cpp showLineNumbers={false} +thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::~disjoint_unsynchronized_pool_resource() +``` + + +--- + +## Pool management + +### release inline + +Releases all held memory to upstream. + + +```cpp showLineNumbers={false} +void thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::release() +``` + + +### squeeze inline + + +```cpp showLineNumbers={false} +void thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::squeeze() +``` + + +--- + +## Allocation + +### do_allocate inline nodiscard virtual + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +virtual void_ptr thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::do_allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) override +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Parameters** + + +Size, in bytes, that is requested from this allocation. + + + +Alignment that is requested from this allocation. + + +### do_allocate_impl inline nodiscard + + +```cpp showLineNumbers={false} +void_ptr thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::do_allocate_impl( + std::size_t bytes, + std::size_t alignment +) +``` + + +### do_deallocate inline virtual + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +virtual void thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::do_deallocate( + void_ptr p, + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) override +``` + + +**Parameters** + + +Pointer to be deallocated. + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The size of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +### allocate inline nodiscard + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +pointer thrust::mr::memory_resource< Upstream::pointer >::allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Parameters** + + +Size, in bytes, that is requested from this allocation. + + + +Alignment that is requested from this allocation. + + +### deallocate inline noexcept + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +void thrust::mr::memory_resource< Upstream::pointer >::deallocate( + pointer p, + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) noexcept +``` + + +**Parameters** + + +Pointer to be deallocated. + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The alignment of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +--- + +## Comparison + +### is_equal inline const noexcept + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +bool thrust::mr::memory_resource< Upstream::pointer >::is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** Whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to. + + +### do_is_equal inline const noexcept virtual + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +virtual bool thrust::mr::memory_resource< Upstream::pointer >::do_is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** Whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to. + + +--- + +## Static methods + +### get_default_options inline static + +Get the default options for a disjoint pool. + + +```cpp showLineNumbers={false} +static pool_options thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::get_default_options() +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `void_ptr` | `typename Upstream::pointer` | | +| `char_ptr` | `typename thrust::detail::pointer_traits< void_ptr >::template rebind< char >::other` | | +| `chunk_vector` | `thrust::host_vector< chunk_descriptor, allocator< chunk_descriptor, Bookkeeper > >` | | +| `oversized_block_vector` | `thrust::host_vector< oversized_block_descriptor, allocator< oversized_block_descriptor, Bookkeeper > >` | | +| `pointer_vector` | `thrust::host_vector< void_ptr, allocator< void_ptr, Bookkeeper > >` | | +| `pool_vector` | `thrust::host_vector< pool, allocator< pool, Bookkeeper > >` | | +| `pointer` | `Upstream::pointer` | Alias for the template parameter. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `m_upstream` | `Upstream *` | | +| `m_bookkeeper` | `Bookkeeper *` | | +| `m_options` | [`pool_options`](/library/api/thrust::mr::pool_options) | | +| `m_smallest_block_log2` | `std::size_t` | | +| `m_pools` | `pool_vector` | | +| `m_allocated` | `chunk_vector` | | +| `m_cached_oversized` | `oversized_block_vector` | | +| `m_oversized` | `oversized_block_vector` | | + +--- + +## Inner classes + +### chunk_descriptor + + +```cpp showLineNumbers={false} +struct thrust::mr::disjoint_unsynchronized_pool_resource::chunk_descriptor +``` + + +| Name | Type | Description | +|---|---|---| +| `size` | `std::size_t` | | +| `pointer` | `void_ptr` | | +| `pool_idx` | `std::size_t` | | + +### oversized_block_descriptor + + +```cpp showLineNumbers={false} +struct thrust::mr::disjoint_unsynchronized_pool_resource::oversized_block_descriptor +``` + + +| Name | Type | Description | +|---|---|---| +| `size` | `std::size_t` | | +| `alignment` | `std::size_t` | | +| `pointer` | `void_ptr` | | + +### equal_pointers + + +```cpp showLineNumbers={false} +struct thrust::mr::disjoint_unsynchronized_pool_resource::equal_pointers +``` + + +| Name | Type | Description | +|---|---|---| +| `p` | `void_ptr` | | + +### matching_alignment + + +```cpp showLineNumbers={false} +struct thrust::mr::disjoint_unsynchronized_pool_resource::matching_alignment +``` + + +| Name | Type | Description | +|---|---|---| +| `requested` | `std::size_t` | | + +### pool + + +```cpp showLineNumbers={false} +struct thrust::mr::disjoint_unsynchronized_pool_resource::pool +``` + + +| Name | Type | Description | +|---|---|---| +| `free_blocks` | `pointer_vector` | | +| `previous_allocated_count` | `std::size_t` | | diff --git a/fern/pages/thrust/pointer.mdx b/fern/pages/thrust/pointer.mdx new file mode 100644 index 0000000..362e7e4 --- /dev/null +++ b/fern/pages/thrust/pointer.mdx @@ -0,0 +1,295 @@ +--- +title: thrust::pointer +description: "A tagged pointer type that stores a pointer to an object allocated in memory, generalizing device_ptr with configurable backend systems." +--- + +`pointer` stores a pointer to an object allocated in memory. + +Like [`device_ptr`](/library/api/thrust::device_ptr), this type ensures type safety when dispatching standard algorithms on ranges resident in memory. + +`pointer` generalizes [`device_ptr`](/library/api/thrust::device_ptr) by relaxing the backend system associated with the `pointer`. Instead of the backend system specified by `THRUST_DEVICE_SYSTEM`, `pointer`'s system is given by its second template parameter, `Tag`. For the purpose of Thrust dispatch, [`device_ptr`](/library/api/thrust::device_ptr) and `pointer` are considered equivalent. + +The raw pointer encapsulated by a `pointer` may be obtained through its [`get`](/library/api/thrust::pointer::get) member function or the `raw_pointer_cast` free function. + +```cpp showLineNumbers={false} +#include +``` + + +`pointer` is not a smart pointer; it is the client's responsibility to deallocate memory pointer to by `pointer`. + + +**See also:** +[device_ptr](/library/api/thrust::device_ptr), +reference, +[raw_pointer_cast](/library/api/thrust::raw_pointer_cast) + + + + + +Specifies the type of the pointed-to object. + + + +Specifies the system with which this `pointer` is associated. This may be any Thrust backend system, or a user-defined tag. + + + +Allows the client to specify the reference type returned upon dereference. By default, this type is `reference`. + + + +Allows the client to specify the name of the derived type when `pointer` is used as a base class. This is useful to ensure that arithmetic on values of the derived type return values of the derived type as a result. By default, this type is `pointer`. + + + + + +**Inherits from:** [`thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >`](/library/api/thrust::iterator_adaptor) (public) + +--- + +## Constructors + +### pointer inline + + + + +`pointer`'s default constructor initializes its encapsulated pointer to `0`. + + +```cpp showLineNumbers={false} +thrust::pointer< Element, Tag, Reference, Derived >::pointer() +``` + + + + + + +```cpp showLineNumbers={false} +thrust::pointer< Element, Tag, Reference, Derived >::pointer( + ::cuda::std::nullptr_t +) +``` + + + + + +explicit + +This constructor allows construction of a `pointer` from a `T *`. + + +```cpp showLineNumbers={false} +template +thrust::pointer< Element, Tag, Reference, Derived >::pointer( + OtherElement *ptr +) +``` + + +**Template parameters** + + +`OtherElement` shall be convertible to `Element`. + + +**Parameters** + + +A raw pointer to copy from, presumed to point to a location in `Tag`'s memory. + + + + + +This constructor allows initialization from another pointer-like object. + + +```cpp showLineNumbers={false} +template * = nullptr> +thrust::pointer< Element, Tag, Reference, Derived >::pointer( + const OtherPointer &other +) +``` + + +**Template parameters** + + +The tag associated with `OtherPointer` shall be convertible to `Tag`, and its element type shall be convertible to `Element`. + + +**Parameters** + + +The `OtherPointer` to copy. + + + + + +--- + +## Assignment operators + +### operator= inline + + + + + +```cpp showLineNumbers={false} +derived_type& thrust::pointer< Element, Tag, Reference, Derived >::operator=( + ::cuda::std::nullptr_t +) +``` + + +**Returns:** `derived_type &` + + + + +Assignment operator allows assigning from another pointer-like object whose element type is convertible to `Element`. + + +```cpp showLineNumbers={false} +template +detail::enable_if_pointer_is_convertible_t +thrust::pointer< Element, Tag, Reference, Derived >::operator=( + const OtherPointer &other +) +``` + + +**Returns:** `*this` + +**Template parameters** + + +The tag associated with `OtherPointer` shall be convertible to `Tag`, and its element type shall be convertible to `Element`. + + +**Parameters** + + +The other pointer-like object to assign from. + + + + + +--- + +## Methods + +### get inline const + +`get` returns this `pointer`'s encapsulated raw pointer. + + +```cpp showLineNumbers={false} +Element * thrust::pointer< Element, Tag, Reference, Derived >::get() const +``` + + +**Returns:** This `pointer`'s raw pointer. + +### operator-> inline const + + +```cpp showLineNumbers={false} +Element * thrust::pointer< Element, Tag, Reference, Derived >::operator->() const +``` + + +### operator bool inline explicit const + + +```cpp showLineNumbers={false} +thrust::pointer< Element, Tag, Reference, Derived >::operator bool() const +``` + + +### dereference inline const + + +```cpp showLineNumbers={false} +template +SuperRef thrust::pointer< Element, Tag, Reference, Derived >::dereference() const +``` + + +### base inline const + + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >::base() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + +### base_reference + + + + +inline const + + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >::base_reference() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +inline + + +```cpp showLineNumbers={false} +Base & thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >::base_reference() +``` + + +**Returns:** A mutable reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +--- + +## Static methods + +### pointer_to inline static + + +```cpp showLineNumbers={false} +static derived_type thrust::pointer< Element, Tag, Reference, Derived >::pointer_to( + typename detail::pointer_traits_detail::pointer_to_param< Element >::type r +) +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `super_t` | `typename detail::pointer_base< Element, Tag, Reference, Derived >::type` | | +| `derived_type` | `typename detail::pointer_base< Element, Tag, Reference, Derived >::derived_type` | | +| `raw_pointer` | `typename super_t::base_type` | The type of the raw pointer. | +| `base_type` | `Base` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)'s adapts. | diff --git a/fern/pages/thrust/pointer_v4.mdx b/fern/pages/thrust/pointer_v4.mdx new file mode 100644 index 0000000..3c7b616 --- /dev/null +++ b/fern/pages/thrust/pointer_v4.mdx @@ -0,0 +1,300 @@ +--- +title: thrust::pointer +description: "A tagged pointer type that stores a pointer to an object allocated in memory, generalizing device_ptr with configurable backend systems." +--- + +`pointer` stores a pointer to an object allocated in memory. + +Like [`device_ptr`](/library/api/thrust::device_ptr), this type ensures type safety when dispatching standard algorithms on ranges resident in memory. + +`pointer` generalizes [`device_ptr`](/library/api/thrust::device_ptr) by relaxing the backend system associated with the `pointer`. Instead of the backend system specified by `THRUST_DEVICE_SYSTEM`, `pointer`'s system is given by its second template parameter, `Tag`. For the purpose of Thrust dispatch, [`device_ptr`](/library/api/thrust::device_ptr) and `pointer` are considered equivalent. + +The raw pointer encapsulated by a `pointer` may be obtained through its [`get`](/library/api/thrust::pointer::get) member function or the `raw_pointer_cast` free function. + +```cpp showLineNumbers={false} +#include +``` + + +`pointer` is not a smart pointer; it is the client's responsibility to deallocate memory pointer to by `pointer`. + + +**See also:** +[device_ptr](/library/api/thrust::device_ptr), +reference, +[raw_pointer_cast](/library/api/thrust::raw_pointer_cast) + + + + + +Specifies the type of the pointed-to object. + + + +Specifies the system with which this `pointer` is associated. This may be any Thrust backend system, or a user-defined tag. + + + +Allows the client to specify the reference type returned upon dereference. By default, this type is `reference`. + + + +Allows the client to specify the name of the derived type when `pointer` is used as a base class. This is useful to ensure that arithmetic on values of the derived type return values of the derived type as a result. By default, this type is `pointer`. + + + + + +**Inherits from:** [`thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >`](/library/api/thrust::iterator_adaptor) (public) + +--- + +## Constructors + +### pointer inline + + + + +`pointer`'s default constructor initializes its encapsulated pointer to `0`. + + +```cpp showLineNumbers={false} +thrust::pointer< Element, Tag, Reference, Derived >::pointer() +``` + + + + + + +```cpp showLineNumbers={false} +thrust::pointer< Element, Tag, Reference, Derived >::pointer( + ::cuda::std::nullptr_t +) +``` + + + + + +explicit + +This constructor allows construction of a `pointer` from a `T *`. + + +```cpp showLineNumbers={false} +template +thrust::pointer< Element, Tag, Reference, Derived >::pointer( + OtherElement *ptr +) +``` + + +**Template parameters** + + +`OtherElement` shall be convertible to `Element`. + + +**Parameters** + + +A raw pointer to copy from, presumed to point to a location in `Tag`'s memory. + + + + + +This constructor allows initialization from another pointer-like object. + + +```cpp showLineNumbers={false} +template * = nullptr> +thrust::pointer< Element, Tag, Reference, Derived >::pointer( + const OtherPointer &other +) +``` + + +**Template parameters** + + +The tag associated with `OtherPointer` shall be convertible to `Tag`, and its element type shall be convertible to `Element`. + + +**Parameters** + + +The `OtherPointer` to copy. + + + + + +--- + +## Assignment operators + +### operator= inline + + + + + +```cpp showLineNumbers={false} +derived_type& thrust::pointer< Element, Tag, Reference, Derived >::operator=( + ::cuda::std::nullptr_t +) +``` + + +**Returns:** `derived_type &` + + + + +Assignment operator allows assigning from another pointer-like object whose element type is convertible to `Element`. + + +```cpp showLineNumbers={false} +template +detail::enable_if_pointer_is_convertible_t +thrust::pointer< Element, Tag, Reference, Derived >::operator=( + const OtherPointer &other +) +``` + + +**Returns:** `*this` + +**Template parameters** + + +The tag associated with `OtherPointer` shall be convertible to `Tag`, and its element type shall be convertible to `Element`. + + +**Parameters** + + +The other pointer-like object to assign from. + + + + + +--- + +## Methods + +### get inline const + +`get` returns this `pointer`'s encapsulated raw pointer. + + +```cpp showLineNumbers={false} +Element * thrust::pointer< Element, Tag, Reference, Derived >::get() const +``` + + +**Returns:** This `pointer`'s raw pointer. + +### operator-> inline const + + +```cpp showLineNumbers={false} +Element * thrust::pointer< Element, Tag, Reference, Derived >::operator->() const +``` + + +### operator bool inline explicit const + + +```cpp showLineNumbers={false} +thrust::pointer< Element, Tag, Reference, Derived >::operator bool() const +``` + + +### dereference inline const + + +```cpp showLineNumbers={false} +template +SuperRef thrust::pointer< Element, Tag, Reference, Derived >::dereference() const +``` + + +### base inline const + + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >::base() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + +### base_reference + + + + +inline + + +```cpp showLineNumbers={false} +Base & thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >::base_reference() +``` + + +**Returns:** A mutable reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +inline const + + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >::base_reference() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +--- + +## Static methods + +### pointer_to inline static + + +```cpp showLineNumbers={false} +static derived_type thrust::pointer< Element, Tag, Reference, Derived >::pointer_to( + typename detail::pointer_traits_detail::pointer_to_param< Element >::type r +) +``` + + +**Parameters** + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `super_t` | `typename detail::pointer_base< Element, Tag, Reference, Derived >::type` | | +| `derived_type` | `typename detail::pointer_base< Element, Tag, Reference, Derived >::derived_type` | | +| `raw_pointer` | `typename super_t::base_type` | The type of the raw pointer. | +| `base_type` | `Base` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)'s adapts. | From 9b3a416578a29edd02d5aea60024d60aced1e058 Mon Sep 17 00:00:00 2001 From: Paarth Gupta Date: Wed, 4 Mar 2026 17:13:38 -0500 Subject: [PATCH 6/6] c++ library docs --- .../cub/cub/AgentAdjacentDifferencePolicy.mdx | 38 + .../cub/cub/cub/AgentHistogramPolicy.mdx | 59 + .../cub/cub/cub/AgentMergeSortPolicy.mdx | 38 + .../cub/cub/AgentRadixSortDownsweepPolicy.mdx | 61 + .../cub/AgentRadixSortExclusiveSumPolicy.mdx | 25 + .../cub/cub/AgentRadixSortHistogramPolicy.mdx | 50 + .../cub/cub/AgentRadixSortOnesweepPolicy.mdx | 51 + .../cub/cub/AgentRadixSortUpsweepPolicy.mdx | 46 + .../cub/cub/cub/AgentReduceByKeyPolicy.mdx | 48 + .../cub/cub/cub/AgentReducePolicy.mdx | 51 + fern/cudapages/cub/cub/cub/AgentRlePolicy.mdx | 53 + .../cub/cub/cub/AgentScanByKeyPolicy.mdx | 47 + .../cudapages/cub/cub/cub/AgentScanPolicy.mdx | 60 + .../cub/cub/cub/AgentSelectIfPolicy.mdx | 48 + .../cub/cub/AgentSubWarpMergeSortPolicy.mdx | 43 + .../cub/cub/AgentThreeWayPartitionPolicy.mdx | 40 + .../cub/cub/cub/AgentUniqueByKeyPolicy.mdx | 43 + .../cub/cub/cub/AgentWarpReducePolicy.mdx | 50 + .../cub/cub/cub/ArgIndexInputIterator.mdx | 290 + fern/cudapages/cub/cub/cub/ArgMax.mdx | 24 + fern/cudapages/cub/cub/cub/ArgMin.mdx | 24 + .../cub/cub/cub/BFEDigitExtractor.mdx | 82 + .../cub/cub/cub/BaseDigitExtractor.mdx | 49 + .../cub/cub/BaseDigitExtractor_KeyT_true.mdx | 38 + .../cub/cub/cub/BlockAdjacentDifference.mdx | 762 +++ .../cub/cub/cub/BlockDiscontinuity.mdx | 1003 +++ fern/cudapages/cub/cub/cub/BlockExchange.mdx | 972 +++ fern/cudapages/cub/cub/cub/BlockHistogram.mdx | 360 ++ fern/cudapages/cub/cub/cub/BlockLoad.mdx | 419 ++ fern/cudapages/cub/cub/cub/BlockLoadType.mdx | 29 + fern/cudapages/cub/cub/cub/BlockMergeSort.mdx | 224 + .../cub/cub/cub/BlockMergeSortStrategy.mdx | 507 ++ fern/cudapages/cub/cub/cub/BlockRadixRank.mdx | 284 + .../cub/cub/BlockRadixRankEmptyCallback.mdx | 29 + .../cub/cub/cub/BlockRadixRankMatch.mdx | 246 + .../cub/BlockRadixRankMatchEarlyCounts.mdx | 173 + fern/cudapages/cub/cub/cub/BlockRadixSort.mdx | 1753 ++++++ .../cub/cub/cub/BlockRakingLayout.mdx | 96 + fern/cudapages/cub/cub/cub/BlockReduce.mdx | 565 ++ .../cub/cub/cub/BlockRunLengthDecode.mdx | 296 + fern/cudapages/cub/cub/cub/BlockScan.mdx | 1797 ++++++ .../cub/cub/cub/BlockScanRunningPrefixOp.mdx | 91 + fern/cudapages/cub/cub/cub/BlockShuffle.mdx | 344 ++ fern/cudapages/cub/cub/cub/BlockStore.mdx | 388 ++ .../cub/cub/CacheModifiedInputIterator.mdx | 265 + .../cub/cub/CacheModifiedOutputIterator.mdx | 279 + .../cub/cub/cub/CachingDeviceAllocator.mdx | 253 + fern/cudapages/cub/cub/cub/CastOp.mdx | 32 + fern/cudapages/cub/cub/cub/ChainedPolicy.mdx | 84 + .../cub/cub/cub/DeviceAdjacentDifference.mdx | 524 ++ fern/cudapages/cub/cub/cub/DeviceCopy.mdx | 235 + fern/cudapages/cub/cub/cub/DeviceFind.mdx | 270 + fern/cudapages/cub/cub/cub/DeviceFor.mdx | 728 +++ .../cudapages/cub/cub/cub/DeviceHistogram.mdx | 1313 ++++ fern/cudapages/cub/cub/cub/DeviceMemcpy.mdx | 146 + fern/cudapages/cub/cub/cub/DeviceMerge.mdx | 202 + .../cudapages/cub/cub/cub/DeviceMergeSort.mdx | 633 ++ .../cudapages/cub/cub/cub/DevicePartition.mdx | 730 +++ .../cudapages/cub/cub/cub/DeviceRadixSort.mdx | 2874 +++++++++ fern/cudapages/cub/cub/cub/DeviceReduce.mdx | 1750 ++++++ .../cub/cub/cub/DeviceRleDispatch.mdx | 212 + .../cub/cub/cub/DeviceRunLengthEncode.mdx | 270 + fern/cudapages/cub/cub/cub/DeviceScan.mdx | 2197 +++++++ .../cub/cub/cub/DeviceSegmentedRadixSort.mdx | 1260 ++++ .../cub/cub/cub/DeviceSegmentedReduce.mdx | 1181 ++++ .../cub/cub/cub/DeviceSegmentedScan.mdx | 1186 ++++ .../cub/cub/cub/DeviceSegmentedSort.mdx | 2522 ++++++++ fern/cudapages/cub/cub/cub/DeviceSelect.mdx | 1326 ++++ fern/cudapages/cub/cub/cub/DeviceTopK.mdx | 338 + .../cudapages/cub/cub/cub/DeviceTransform.mdx | 476 ++ .../cub/cub/DispatchAdjacentDifference.mdx | 110 + .../cub/cub/cub/DispatchHistogram.mdx | 661 ++ .../cub/cub/cub/DispatchMergeSort.mdx | 132 + .../cub/cub/cub/DispatchRadixSort.mdx | 377 ++ fern/cudapages/cub/cub/cub/DispatchReduce.mdx | 241 + .../cub/cub/cub/DispatchReduceByKey.mdx | 210 + fern/cudapages/cub/cub/cub/DispatchScan.mdx | 241 + .../cub/cub/cub/DispatchScanByKey.mdx | 242 + .../cub/cub/DispatchSegmentedRadixSort.mdx | 274 + .../cub/cub/cub/DispatchSegmentedReduce.mdx | 218 + .../cub/cub/cub/DispatchSegmentedSort.mdx | 190 + .../cub/cub/cub/DispatchSelectIf.mdx | 248 + .../cub/cub/DispatchThreeWayPartitionIf.mdx | 142 + .../cub/cub/cub/DispatchUniqueByKey.mdx | 247 + fern/cudapages/cub/cub/cub/GridEvenShare.mdx | 164 + fern/cudapages/cub/cub/cub/GridQueue.mdx | 179 + .../cub/cub/cub/InequalityWrapper.mdx | 57 + .../cub/cub/cub/PtxVersionCacheTag.mdx | 4 + .../cub/cub/cub/RadixSortTwiddle.mdx | 70 + fern/cudapages/cub/cub/cub/ReduceByKeyOp.mdx | 83 + .../cub/cub/cub/ReduceByKeyScanTileState.mdx | 21 + ...ceByKeyScanTileState_ValueT_KeyT_false.mdx | 44 + .../cub/cub/cub/ReduceBySegmentOp.mdx | 95 + fern/cudapages/cub/cub/cub/ScanTileState.mdx | 18 + .../cub/cub/cub/ScanTileState_T_false.mdx | 180 + .../cub/cub/cub/ShiftDigitExtractor.mdx | 82 + .../cub/cub/cub/SmVersionCacheTag.mdx | 4 + fern/cudapages/cub/cub/cub/SwizzleScanOp.mdx | 57 + .../cub/cub/cub/TilePrefixCallbackOp.mdx | 192 + fern/cudapages/cub/cub/cub/WarpExchange.mdx | 244 + fern/cudapages/cub/cub/cub/WarpLoad.mdx | 421 ++ fern/cudapages/cub/cub/cub/WarpMergeSort.mdx | 145 + fern/cudapages/cub/cub/cub/WarpReduce.mdx | 749 +++ fern/cudapages/cub/cub/cub/WarpScan.mdx | 1184 ++++ fern/cudapages/cub/cub/cub/WarpStore.mdx | 353 ++ .../cuda/cuda/cuda/arch_traits_t.mdx | 52 + .../cuda/cuda/cuda/buffer.mdx} | 453 +- .../cuda/cuda/cuda/compute_capability.mdx | 207 + .../cuda/cuda/cuda/constant_iterator.mdx | 255 + .../cuda/cuda/cuda/copy_configuration.mdx | 20 + .../cuda/cuda/cuda/counting_iterator.mdx | 232 + .../compute_capability_t.mdx | 32 + .../cuda/cuda/cuda/device_memory_pool.mdx | 140 + .../cuda/cuda/cuda/device_memory_pool_ref.mdx | 68 + fern/cudapages/cuda/cuda/cuda/device_ref.mdx | 155 + .../cuda/cuda/cuda/discard_iterator.mdx | 264 + fern/cudapages/cuda/cuda/cuda/event.mdx | 342 + fern/cudapages/cuda/cuda/cuda/event_ref.mdx | 132 + .../cudapages/cuda/cuda/cuda/get_stream_t.mdx | 92 + .../cudapages/cuda/cuda/cuda/has_property.mdx | 33 + .../cuda/cuda/cuda/has_property_with.mdx | 36 + .../cuda/cuda/cuda/heterogeneous_iterator.mdx | 272 + .../cuda/cuda/cuda/managed_memory_pool.mdx | 133 + .../cuda/cuda/managed_memory_pool_ref.mdx | 50 + .../cuda/cuda/cuda/memory_pool_properties.mdx | 23 + .../cuda/cuda/cuda/mr/basic_any_resource.mdx | 591 ++ .../cuda/cuda/cuda/mr/basic_resource_ref.mdx | 375 ++ .../cuda/cuda/cuda/mr/device_accessible.mdx | 10 + .../cuda/cuda/cuda/mr/host_accessible.mdx | 10 + .../mr/legacy_managed_memory_resource.mdx | 120 + .../cuda/mr/legacy_pinned_memory_resource.mdx | 117 + .../cuda/cuda/cuda/mr/properties_list.mdx | 45 + fern/cudapages/cuda/cuda/cuda/mr/resource.mdx | 39 + .../cuda/cuda/cuda/mr/resource_with.mdx | 27 + .../cuda/cuda/cuda/mr/shared_resource.mdx | 435 ++ .../cuda/cuda/mr/synchronous_resource.mdx | 36 + .../cuda/mr/synchronous_resource_adapter.mdx | 146 + .../cuda/mr/synchronous_resource_with.mdx | 27 + .../cuda/cuda/cuda/permutation_iterator.mdx | 367 ++ .../cuda/cuda/cuda/pinned_memory_pool.mdx | 167 + .../cuda/cuda/cuda/pinned_memory_pool_ref.mdx | 50 + .../cuda/cuda/cuda/property_with_value.mdx | 24 + .../cuda/cuda/cuda/shuffle_iterator.mdx | 259 + .../cuda/cuda/cuda/std/pointer_traits.mdx | 56 + fern/cudapages/cuda/cuda/cuda/stream.mdx | 445 ++ fern/cudapages/cuda/cuda/cuda/stream_ref.mdx | 306 + .../cuda/cuda/cuda/strided_iterator.mdx | 341 + .../cuda/cuda/tabulate_output_iterator.mdx | 295 + fern/cudapages/cuda/cuda/cuda/timed_event.mdx | 286 + .../cuda/transform_input_output_iterator.mdx | 345 ++ .../cuda/cuda/cuda/transform_iterator.mdx | 358 ++ .../cuda/cuda/transform_output_iterator.mdx | 331 + .../cudapages/cuda/cuda/cuda/zip_function.mdx | 67 + .../cudapages/cuda/cuda/cuda/zip_iterator.mdx | 332 + .../cuda/cuda/cuda/zip_transform_iterator.mdx | 307 + .../thrust/thrust/thrust/allocator_delete.mdx | 171 + .../thrust/thrust/array_allocator_delete.mdx | 173 + .../bidirectional_device_iterator_tag.mdx | 17 + .../thrust/bidirectional_traversal_tag.mdx | 12 + .../thrust/thrust/compile_time_value.mdx | 27 + .../thrust/thrust/thrust/complex.mdx | 710 +++ .../thrust/thrust/constant_iterator.mdx | 217 + .../thrust/thrust/counting_iterator.mdx | 223 + .../thrust/thrust/thrust/device_allocator.mdx | 211 + .../thrust/thrust/device_execution_policy.mdx | 72 + .../thrust/thrust/device_malloc_allocator.mdx | 253 + .../thrust/thrust/device_new_allocator.mdx | 236 + .../thrust/thrust/thrust/device_ptr.mdx | 324 + .../thrust/device_ptr_memory_resource.mdx | 239 + .../thrust/thrust/thrust/device_reference.mdx | 1189 ++++ .../thrust/thrust/thrust/device_vector.mdx | 1368 ++++ .../thrust/thrust/discard_block_engine.mdx | 223 + .../thrust/thrust/thrust/discard_iterator.mdx | 126 + .../thrust/thrust/thrust/error_category.mdx | 139 + .../thrust/thrust/thrust/error_code.mdx | 175 + .../thrust/thrust/thrust/error_condition.mdx | 211 + .../thrust/forward_device_iterator_tag.mdx | 17 + .../thrust/thrust/forward_traversal_tag.mdx | 12 + .../thrust/thrust/host_execution_policy.mdx | 72 + .../thrust/thrust/thrust/host_vector.mdx | 1365 ++++ .../thrust/incrementable_traversal_tag.mdx | 12 + .../thrust/input_device_iterator_tag.mdx | 17 + .../thrust/thrust/is_error_code_enum.mdx | 21 + .../thrust/thrust/is_error_condition_enum.mdx | 21 + .../thrust/thrust/thrust/iterator_adaptor.mdx | 182 + .../thrust/thrust/iterator_core_access.mdx | 10 + .../thrust/thrust/iterator_difference.mdx | 29 + .../thrust/thrust/thrust/iterator_facade.mdx | 216 + .../thrust/thrust/thrust/iterator_pointer.mdx | 29 + .../thrust/thrust/iterator_reference.mdx | 29 + .../thrust/thrust/thrust/iterator_system.mdx | 21 + .../thrust/iterator_system_const_void_ptr.mdx | 10 + ...r_system_cudaconstant_iterator_T_Index.mdx | 32 + ...tor_system_cudacounting_iterator_Start.mdx | 29 + .../iterator_system_cudadiscard_iterator.mdx | 20 + ...m_cudapermutation_iterator_Iter_Offset.mdx | 32 + ...dashuffle_iterator_IndexType_Bijection.mdx | 32 + ...or_system_cudastdreverse_iterator_Iter.mdx | 19 + ...ystem_cudastrided_iterator_Iter_Stride.mdx | 22 + ..._cudatabulate_output_iterator_Fn_Index.mdx | 32 + ..._output_iterator_InputFn_OutputFn_Iter.mdx | 25 + ..._system_cudatransform_iterator_Fn_Iter.mdx | 22 + ..._cudatransform_output_iterator_Fn_Iter.mdx | 22 + ...ator_system_cudazip_iterator_Iterators.mdx | 29 + ...udazip_transform_iterator_Fn_Iterators.mdx | 32 + .../thrust/iterator_system_void_ptr.mdx | 10 + .../thrust/thrust/iterator_traversal.mdx | 21 + ...raversal_cudaconstant_iterator_T_Index.mdx | 32 + ..._traversal_cudacounting_iterator_Start.mdx | 29 + ...terator_traversal_cudadiscard_iterator.mdx | 20 + ...l_cudapermutation_iterator_Iter_Offset.mdx | 32 + ...dashuffle_iterator_IndexType_Bijection.mdx | 32 + ...traversal_cudastdreverse_iterator_Iter.mdx | 19 + ...ersal_cudastrided_iterator_Iter_Stride.mdx | 22 + ..._cudatabulate_output_iterator_Fn_Index.mdx | 32 + ..._output_iterator_InputFn_OutputFn_Iter.mdx | 25 + ...aversal_cudatransform_iterator_Fn_Iter.mdx | 22 + ..._cudatransform_output_iterator_Fn_Iter.mdx | 22 + ...r_traversal_cudazip_iterator_Iterators.mdx | 29 + ...udazip_transform_iterator_Fn_Iterators.mdx | 32 + .../thrust/thrust/thrust/iterator_value.mdx | 29 + .../thrust/linear_congruential_engine.mdx | 200 + .../thrust/linear_feedback_shift_engine.mdx | 144 + .../thrust/thrust/thrust/mr/allocator.mdx | 185 + .../disjoint_synchronized_pool_resource.mdx | 293 + .../disjoint_unsynchronized_pool_resource.mdx | 410 ++ .../thrust/mr/fancy_pointer_resource.mdx | 229 + .../thrust/thrust/mr/memory_resource.mdx | 204 + .../thrust/mr/memory_resource_void_ptr.mdx | 98 + .../thrust/thrust/mr/new_delete_resource.mdx | 207 + .../thrust/mr/new_delete_resource_base.mdx | 201 + .../mr/polymorphic_adaptor_resource.mdx | 211 + .../thrust/thrust/thrust/mr/pool_options.mdx | 43 + .../mr/stateless_resource_allocator.mdx | 221 + .../thrust/mr/synchronized_pool_resource.mdx | 284 + .../mr/unsynchronized_pool_resource.mdx | 373 ++ .../thrust/thrust/thrust/mr/validator.mdx | 17 + .../thrust/thrust/thrust/mr/validator2.mdx | 22 + .../thrust/thrust/mr/validator2_T_T.mdx | 19 + .../thrust/thrust/thrust/no_traversal_tag.mdx | 10 + .../thrust/thrust/normal_distribution.mdx | 271 + .../thrust/thrust/thrust/offset_iterator.mdx | 166 + .../thrust/output_device_iterator_tag.mdx | 17 + .../thrust/thrust/per_device_allocator.mdx | 208 + .../thrust/thrust/permutation_iterator.mdx | 190 + .../thrust/thrust/thrust/pointer.mdx | 289 + .../thrust/proclaim_contiguous_iterator.mdx | 25 + .../thrust/thrust/thrust/project1st.mdx | 57 + .../thrust/thrust/project1st_void_void.mdx | 33 + .../thrust/thrust/thrust/project2nd.mdx | 57 + .../thrust/thrust/project2nd_void_void.mdx | 33 + .../thrust/random/discard_block_engine.mdx | 223 + .../random/linear_congruential_engine.mdx | 200 + .../random/linear_feedback_shift_engine.mdx | 144 + .../thrust/random/normal_distribution.mdx | 271 + .../random/subtract_with_carry_engine.mdx | 152 + .../random/uniform_int_distribution.mdx | 277 + .../random/uniform_real_distribution.mdx | 277 + .../thrust/random/xor_combine_engine.mdx | 239 + .../random_access_device_iterator_tag.mdx | 17 + .../thrust/random_access_traversal_tag.mdx | 12 + .../thrust/thrust/thrust/runtime_value.mdx | 27 + .../thrust/thrust/thrust/shuffle_iterator.mdx | 127 + .../thrust/single_pass_traversal_tag.mdx | 12 + .../cudapages/thrust/thrust/thrust/square.mdx | 57 + .../thrust/thrust/thrust/square_void.mdx | 33 + .../thrust/thrust/thrust/strided_iterator.mdx | 126 + .../thrust/subtract_with_carry_engine.mdx | 152 + .../thrust/thrust/system/error_category.mdx | 139 + .../thrust/thrust/system/error_code.mdx | 175 + .../thrust/thrust/system/error_condition.mdx | 211 + .../thrust/system/is_error_code_enum.mdx | 21 + .../is_error_code_enum_cudaerrcerrc_t.mdx | 12 + .../thrust/system/is_error_condition_enum.mdx | 21 + .../is_error_condition_enum_errcerrc_t.mdx | 12 + .../thrust/thrust/system/system_error.mdx | 290 + .../thrust/thrust/thrust/system_error.mdx | 290 + .../thrust/tabulate_output_iterator.mdx | 149 + .../thrust/thrust/thrust/tagged_deleter.mdx | 46 + .../transform_input_output_iterator.mdx | 172 + .../thrust/thrust/transform_iterator.mdx | 317 + .../thrust/transform_output_iterator.mdx | 164 + .../thrust/uniform_int_distribution.mdx | 277 + .../thrust/uniform_real_distribution.mdx | 277 + .../thrust/thrust/xor_combine_engine.mdx | 239 + .../thrust/thrust/thrust/zip_function.mdx | 140 + .../thrust/thrust/thrust/zip_iterator.mdx | 200 + fern/docs.yml | 658 +- fern/docs/pages/nominal-data-model.mdx | 150 + fern/docs/pages/steps-toc-test.mdx | 92 + .../langchain-core-docs/_navigation.yml | 1305 ++++ .../langchain-core/langchain_core.mdx | 75 + .../langchain-core/langchain_core/_api.mdx | 121 + .../langchain_core/_api/beta_decorator.mdx | 208 + .../langchain_core/_api/deprecation.mdx | 458 ++ .../langchain_core/_api/internal.mdx | 35 + .../langchain_core/_api/path.mdx | 135 + .../langchain_core/_import_utils.mdx | 65 + .../langchain_core/_security.mdx | 9 + .../_security/_ssrf_protection.mdx | 499 ++ .../langchain-core/langchain_core/agents.mdx | 415 ++ .../langchain-core/langchain_core/caches.mdx | 541 ++ .../langchain_core/callbacks.mdx | 91 + .../langchain_core/callbacks/base.mdx | 2445 ++++++++ .../langchain_core/callbacks/file.mdx | 477 ++ .../langchain_core/callbacks/manager.mdx | 2935 +++++++++ .../langchain_core/callbacks/stdout.mdx | 259 + .../callbacks/streaming_stdout.mdx | 479 ++ .../langchain_core/callbacks/usage.mdx | 116 + .../langchain_core/chat_history.mdx | 442 ++ .../langchain_core/chat_loaders.mdx | 81 + .../langchain_core/chat_sessions.mdx | 46 + .../langchain_core/document_loaders.mdx | 88 + .../langchain_core/document_loaders/base.mdx | 273 + .../document_loaders/blob_loaders.mdx | 78 + .../document_loaders/langsmith.mdx | 118 + .../langchain_core/documents.mdx | 113 + .../langchain_core/documents/base.mdx | 495 ++ .../langchain_core/documents/compressor.mdx | 133 + .../langchain_core/documents/transformers.mdx | 104 + .../langchain_core/embeddings.mdx | 87 + .../langchain_core/embeddings/embeddings.mdx | 185 + .../langchain_core/embeddings/fake.mdx | 196 + .../langchain-core/langchain_core/env.mdx | 39 + .../langchain_core/example_selectors.mdx | 91 + .../langchain_core/example_selectors/base.mdx | 165 + .../example_selectors/length_based.mdx | 213 + .../example_selectors/semantic_similarity.mdx | 656 ++ .../langchain_core/exceptions.mdx | 204 + .../langchain-core/langchain_core/globals.mdx | 210 + .../langchain_core/indexing.mdx | 92 + .../langchain_core/indexing/api.mdx | 789 +++ .../langchain_core/indexing/base.mdx | 1269 ++++ .../langchain_core/indexing/in_memory.mdx | 151 + .../langchain_core/language_models.mdx | 120 + .../langchain_core/language_models/_utils.mdx | 327 + .../langchain_core/language_models/base.mdx | 608 ++ .../language_models/chat_models.mdx | 1297 ++++ .../langchain_core/language_models/fake.mdx | 199 + .../language_models/fake_chat_models.mdx | 437 ++ .../langchain_core/language_models/llms.mdx | 1335 ++++ .../language_models/model_profile.mdx | 144 + .../langchain-core/langchain_core/load.mdx | 90 + .../langchain_core/load/_validation.mdx | 252 + .../langchain_core/load/dump.mdx | 177 + .../langchain_core/load/load.mdx | 692 +++ .../langchain_core/load/mapping.mdx | 85 + .../langchain_core/load/serializable.mdx | 524 ++ .../langchain_core/messages.mdx | 99 + .../langchain_core/messages/ai.mdx | 486 ++ .../langchain_core/messages/base.mdx | 537 ++ .../messages/block_translators.mdx | 159 + .../messages/block_translators/anthropic.mdx | 200 + .../messages/block_translators/bedrock.mdx | 141 + .../block_translators/bedrock_converse.mdx | 219 + .../block_translators/google_genai.mdx | 244 + .../block_translators/google_vertexai.mdx | 37 + .../messages/block_translators/groq.mdx | 176 + .../block_translators/langchain_v0.mdx | 79 + .../messages/block_translators/openai.mdx | 408 ++ .../langchain_core/messages/chat.mdx | 85 + .../langchain_core/messages/content.mdx | 1941 ++++++ .../langchain_core/messages/function.mdx | 92 + .../langchain_core/messages/human.mdx | 70 + .../langchain_core/messages/modifier.mdx | 43 + .../langchain_core/messages/system.mdx | 71 + .../langchain_core/messages/tool.mdx | 495 ++ .../langchain_core/messages/utils.mdx | 1401 +++++ .../langchain_core/output_parsers.mdx | 110 + .../langchain_core/output_parsers/base.mdx | 516 ++ .../output_parsers/format_instructions.mdx | 27 + .../langchain_core/output_parsers/json.mdx | 242 + .../langchain_core/output_parsers/list.mdx | 473 ++ .../output_parsers/openai_functions.mdx | 422 ++ .../output_parsers/openai_tools.mdx | 437 ++ .../output_parsers/pydantic.mdx | 204 + .../langchain_core/output_parsers/string.mdx | 111 + .../output_parsers/transform.mdx | 242 + .../langchain_core/output_parsers/xml.mdx | 356 ++ .../langchain-core/langchain_core/outputs.mdx | 109 + .../outputs/chat_generation.mdx | 192 + .../langchain_core/outputs/chat_result.mdx | 62 + .../langchain_core/outputs/generation.mdx | 158 + .../langchain_core/outputs/llm_result.mdx | 131 + .../langchain_core/outputs/run_info.mdx | 47 + .../langchain_core/prompt_values.mdx | 416 ++ .../langchain-core/langchain_core/prompts.mdx | 99 + .../langchain_core/prompts/base.mdx | 669 ++ .../langchain_core/prompts/chat.mdx | 1931 ++++++ .../langchain_core/prompts/dict.mdx | 240 + .../langchain_core/prompts/few_shot.mdx | 671 ++ .../prompts/few_shot_with_templates.mdx | 267 + .../langchain_core/prompts/image.mdx | 230 + .../langchain_core/prompts/loading.mdx | 284 + .../langchain_core/prompts/message.mdx | 223 + .../langchain_core/prompts/prompt.mdx | 374 ++ .../langchain_core/prompts/string.mdx | 597 ++ .../langchain_core/prompts/structured.mdx | 211 + .../langchain_core/rate_limiters.mdx | 300 + .../langchain_core/retrievers.mdx | 433 ++ .../langchain_core/runnables.mdx | 117 + .../langchain_core/runnables/base.mdx | 5475 +++++++++++++++++ .../langchain_core/runnables/branch.mdx | 287 + .../langchain_core/runnables/config.mdx | 777 +++ .../langchain_core/runnables/configurable.mdx | 796 +++ .../langchain_core/runnables/fallbacks.mdx | 352 ++ .../langchain_core/runnables/graph.mdx | 1150 ++++ .../langchain_core/runnables/graph_ascii.mdx | 391 ++ .../runnables/graph_mermaid.mdx | 325 + .../langchain_core/runnables/graph_png.mdx | 361 ++ .../langchain_core/runnables/history.mdx | 497 ++ .../langchain_core/runnables/passthrough.mdx | 1080 ++++ .../langchain_core/runnables/retry.mdx | 414 ++ .../langchain_core/runnables/router.mdx | 241 + .../langchain_core/runnables/schema.mdx | 243 + .../langchain_core/runnables/utils.mdx | 1426 +++++ .../langchain-core/langchain_core/stores.mdx | 657 ++ .../langchain_core/structured_query.mdx | 426 ++ .../langchain_core/sys_info.mdx | 64 + .../langchain-core/langchain_core/tools.mdx | 94 + .../langchain_core/tools/base.mdx | 1764 ++++++ .../langchain_core/tools/convert.mdx | 348 ++ .../langchain_core/tools/render.mdx | 116 + .../langchain_core/tools/retriever.mdx | 110 + .../langchain_core/tools/simple.mdx | 270 + .../langchain_core/tools/structured.mdx | 304 + .../langchain-core/langchain_core/tracers.mdx | 99 + .../langchain_core/tracers/_compat.mdx | 229 + .../langchain_core/tracers/_streaming.mdx | 115 + .../langchain_core/tracers/base.mdx | 1678 +++++ .../langchain_core/tracers/context.mdx | 239 + .../langchain_core/tracers/core.mdx | 1017 +++ .../langchain_core/tracers/evaluation.mdx | 233 + .../langchain_core/tracers/event_stream.mdx | 777 +++ .../langchain_core/tracers/langchain.mdx | 704 +++ .../langchain_core/tracers/log_stream.mdx | 759 +++ .../langchain_core/tracers/memory_stream.mdx | 282 + .../langchain_core/tracers/root_listeners.mdx | 218 + .../langchain_core/tracers/run_collector.mdx | 75 + .../langchain_core/tracers/schemas.mdx | 41 + .../langchain_core/tracers/stdout.mdx | 387 ++ .../langchain-core/langchain_core/utils.mdx | 105 + .../langchain_core/utils/_merge.mdx | 140 + .../langchain_core/utils/aiter.mdx | 541 ++ .../langchain_core/utils/env.mdx | 145 + .../langchain_core/utils/formatting.mdx | 145 + .../langchain_core/utils/function_calling.mdx | 789 +++ .../langchain_core/utils/html.mdx | 201 + .../langchain_core/utils/image.mdx | 35 + .../langchain_core/utils/input.mdx | 185 + .../langchain_core/utils/interactive_env.mdx | 39 + .../langchain_core/utils/iter.mdx | 370 ++ .../langchain_core/utils/json.mdx | 260 + .../langchain_core/utils/json_schema.mdx | 283 + .../langchain_core/utils/mustache.mdx | 532 ++ .../langchain_core/utils/pydantic.mdx | 679 ++ .../langchain_core/utils/strings.mdx | 149 + .../langchain_core/utils/usage.mdx | 81 + .../langchain_core/utils/utils.mdx | 567 ++ .../langchain_core/utils/uuid.mdx | 103 + .../langchain_core/vectorstores.mdx | 113 + .../langchain_core/vectorstores/base.mdx | 1712 ++++++ .../langchain_core/vectorstores/in_memory.mdx | 560 ++ .../langchain_core/vectorstores/utils.mdx | 171 + .../langchain-core/langchain_core/version.mdx | 27 + .../library-docs/nemo-rl-docs/_navigation.yml | 1033 ++++ .../nemo-rl-docs/nemo-rl/nemo_rl.mdx | 149 + .../nemo-rl/nemo_rl/algorithms.mdx | 19 + .../algorithms/advantage_estimator.mdx | 196 + .../nemo_rl/algorithms/async_utils.mdx | 572 ++ .../nemo_rl/algorithms/distillation.mdx | 326 + .../nemo-rl/nemo_rl/algorithms/dpo.mdx | 378 ++ .../nemo-rl/nemo_rl/algorithms/grpo.mdx | 916 +++ .../nemo-rl/nemo_rl/algorithms/interfaces.mdx | 123 + .../nemo_rl/algorithms/loss_functions.mdx | 875 +++ .../nemo_rl/algorithms/reward_functions.mdx | 102 + .../nemo-rl/nemo_rl/algorithms/rm.mdx | 320 + .../nemo-rl/nemo_rl/algorithms/sft.mdx | 258 + .../nemo-rl/nemo_rl/algorithms/utils.mdx | 379 ++ .../nemo-rl-docs/nemo-rl/nemo_rl/data.mdx | 466 ++ .../nemo-rl/nemo_rl/data/chat_templates.mdx | 35 + .../nemo-rl/nemo_rl/data/collate_fn.mdx | 166 + .../nemo-rl/nemo_rl/data/datasets.mdx | 37 + .../nemo_rl/data/datasets/eval_datasets.mdx | 60 + .../data/datasets/eval_datasets/aime.mdx | 64 + .../data/datasets/eval_datasets/gpqa.mdx | 64 + .../eval_datasets/local_math_dataset.mdx | 65 + .../data/datasets/eval_datasets/math.mdx | 61 + .../data/datasets/eval_datasets/mmlu.mdx | 61 + .../data/datasets/eval_datasets/mmlu_pro.mdx | 60 + .../data/datasets/preference_datasets.mdx | 72 + .../binary_preference_dataset.mdx | 102 + .../preference_datasets/helpsteer3.mdx | 66 + .../preference_dataset.mdx | 77 + .../datasets/preference_datasets/tulu3.mdx | 59 + .../data/datasets/processed_dataset.mdx | 135 + .../nemo_rl/data/datasets/raw_dataset.mdx | 94 + .../data/datasets/response_datasets.mdx | 82 + .../datasets/response_datasets/aime24.mdx | 66 + .../data/datasets/response_datasets/clevr.mdx | 97 + .../datasets/response_datasets/dapo_math.mdx | 84 + .../datasets/response_datasets/deepscaler.mdx | 59 + .../datasets/response_datasets/geometry3k.mdx | 76 + .../datasets/response_datasets/helpsteer3.mdx | 66 + .../response_datasets/nemogym_dataset.mdx | 54 + .../response_datasets/oai_format_dataset.mdx | 214 + .../data/datasets/response_datasets/oasst.mdx | 127 + .../response_datasets/openmathinstruct2.mdx | 84 + .../datasets/response_datasets/refcoco.mdx | 160 + .../response_datasets/response_dataset.mdx | 104 + .../data/datasets/response_datasets/squad.mdx | 66 + .../data/datasets/response_datasets/tulu3.mdx | 76 + .../nemo-rl/nemo_rl/data/datasets/utils.mdx | 191 + .../nemo-rl/nemo_rl/data/interfaces.mdx | 284 + .../nemo_rl/data/llm_message_utils.mdx | 548 ++ .../nemo-rl/nemo_rl/data/multimodal_utils.mdx | 298 + .../nemo-rl/nemo_rl/data/packing.mdx | 30 + .../nemo_rl/data/packing/algorithms.mdx | 791 +++ .../nemo-rl/nemo_rl/data/packing/metrics.mdx | 177 + .../nemo-rl/nemo_rl/data/processors.mdx | 353 ++ .../nemo-rl/nemo_rl/data/utils.mdx | 104 + .../nemo-rl/nemo_rl/distributed.mdx | 17 + .../nemo_rl/distributed/batched_data_dict.mdx | 671 ++ .../nemo_rl/distributed/collectives.mdx | 108 + .../nemo_rl/distributed/model_utils.mdx | 851 +++ .../nemo_rl/distributed/named_sharding.mdx | 236 + .../ray_actor_environment_registry.mdx | 105 + .../distributed/stateless_process_group.mdx | 73 + .../nemo_rl/distributed/virtual_cluster.mdx | 514 ++ .../distributed/worker_group_utils.mdx | 81 + .../nemo_rl/distributed/worker_groups.mdx | 603 ++ .../nemo-rl/nemo_rl/environments.mdx | 19 + .../nemo_rl/environments/code_environment.mdx | 290 + .../environments/code_jaccard_environment.mdx | 268 + .../environments/dapo_math_verifier.mdx | 316 + .../nemo_rl/environments/interfaces.mdx | 151 + .../nemo_rl/environments/math_environment.mdx | 356 ++ .../nemo-rl/nemo_rl/environments/metrics.mdx | 42 + .../nemo-rl/nemo_rl/environments/nemo_gym.mdx | 213 + .../environments/reward_model_environment.mdx | 276 + .../nemo-rl/nemo_rl/environments/rewards.mdx | 180 + .../nemo-rl/nemo_rl/environments/utils.mdx | 152 + .../nemo_rl/environments/vlm_environment.mdx | 243 + .../nemo-rl-docs/nemo-rl/nemo_rl/evals.mdx | 10 + .../nemo-rl/nemo_rl/evals/answer_parsing.mdx | 86 + .../nemo-rl/nemo_rl/evals/eval.mdx | 399 ++ .../nemo-rl/nemo_rl/experience.mdx | 9 + .../nemo-rl/nemo_rl/experience/rollouts.mdx | 489 ++ .../nemo-rl-docs/nemo-rl/nemo_rl/models.mdx | 14 + .../nemo-rl/nemo_rl/models/automodel.mdx | 12 + .../nemo_rl/models/automodel/config.mdx | 125 + .../nemo-rl/nemo_rl/models/automodel/data.mdx | 374 ++ .../nemo_rl/models/automodel/setup.mdx | 229 + .../nemo_rl/models/automodel/train.mdx | 841 +++ .../nemo-rl/nemo_rl/models/dtensor.mdx | 9 + .../nemo_rl/models/dtensor/parallelize.mdx | 454 ++ .../nemo-rl/nemo_rl/models/generation.mdx | 62 + .../nemo_rl/models/generation/interfaces.mdx | 569 ++ .../nemo_rl/models/generation/sglang.mdx | 33 + .../models/generation/sglang/config.mdx | 299 + .../generation/sglang/sglang_copied_utils.mdx | 307 + .../generation/sglang/sglang_generation.mdx | 369 ++ .../generation/sglang/sglang_worker.mdx | 529 ++ .../models/generation/sglang/utils.mdx | 109 + .../nemo_rl/models/generation/vllm.mdx | 34 + .../nemo_rl/models/generation/vllm/config.mdx | 111 + .../nemo_rl/models/generation/vllm/utils.mdx | 113 + .../models/generation/vllm/vllm_backend.mdx | 236 + .../generation/vllm/vllm_generation.mdx | 656 ++ .../models/generation/vllm/vllm_worker.mdx | 545 ++ .../generation/vllm/vllm_worker_async.mdx | 485 ++ .../nemo-rl/nemo_rl/models/huggingface.mdx | 9 + .../nemo_rl/models/huggingface/common.mdx | 303 + .../nemo-rl/nemo_rl/models/megatron.mdx | 15 + .../nemo_rl/models/megatron/common.mdx | 133 + .../models/megatron/community_import.mdx | 76 + .../nemo_rl/models/megatron/config.mdx | 142 + .../nemo-rl/nemo_rl/models/megatron/data.mdx | 471 ++ .../models/megatron/pipeline_parallel.mdx | 124 + .../nemo-rl/nemo_rl/models/megatron/setup.mdx | 535 ++ .../nemo-rl/nemo_rl/models/megatron/train.mdx | 538 ++ .../nemo-rl/nemo_rl/models/policy.mdx | 948 +++ .../nemo_rl/models/policy/interfaces.mdx | 574 ++ .../nemo_rl/models/policy/lm_policy.mdx | 609 ++ .../nemo-rl/nemo_rl/models/policy/utils.mdx | 624 ++ .../nemo-rl/nemo_rl/models/policy/workers.mdx | 13 + .../policy/workers/base_policy_worker.mdx | 309 + .../policy/workers/dtensor_policy_worker.mdx | 693 +++ .../workers/dtensor_policy_worker_v2.mdx | 714 +++ .../policy/workers/megatron_policy_worker.mdx | 638 ++ .../nemo_rl/models/policy/workers/patches.mdx | 85 + .../nemo-rl/nemo_rl/package_info.mdx | 235 + .../nemo-rl-docs/nemo-rl/nemo_rl/utils.mdx | 22 + .../nemo_rl/utils/automodel_checkpoint.mdx | 436 ++ .../nemo-rl/nemo_rl/utils/checkpoint.mdx | 411 ++ .../nemo-rl/nemo_rl/utils/config.mdx | 266 + .../nemo-rl/nemo_rl/utils/flops_formulas.mdx | 501 ++ .../nemo-rl/nemo_rl/utils/flops_tracker.mdx | 215 + .../nemo-rl/nemo_rl/utils/logger.mdx | 1856 ++++++ .../nemo-rl/nemo_rl/utils/memory_tracker.mdx | 122 + .../nemo_rl/utils/native_checkpoint.mdx | 351 ++ .../nemo-rl/nemo_rl/utils/nsys.mdx | 138 + .../nemo-rl/nemo_rl/utils/nvml.mdx | 100 + .../nemo-rl/nemo_rl/utils/packed_tensor.mdx | 140 + .../nemo-rl/nemo_rl/utils/prefetch_venvs.mdx | 108 + .../nemo-rl/nemo_rl/utils/timer.mdx | 441 ++ .../nemo-rl/nemo_rl/utils/venvs.mdx | 177 + fern/library-docs/ttl-docs/_navigation.yml | 149 + fern/library-docs/ttl-docs/ttl/ttl.mdx | 60 + .../ttl-docs/ttl/ttl/_mlir_libs.mdx | 9 + .../ttl/ttl/_mlir_libs/_site_initialize_1.mdx | 35 + fern/library-docs/ttl-docs/ttl/ttl/_src.mdx | 11 + .../ttl-docs/ttl/ttl/_src/auto_profile.mdx | 479 ++ .../ttl-docs/ttl/ttl/_src/tensor_registry.mdx | 169 + .../ttl-docs/ttl/ttl/_src/ttl_ast.mdx | 731 +++ .../ttl-docs/ttl/ttl/circular_buffer.mdx | 239 + .../ttl-docs/ttl/ttl/constants.mdx | 41 + .../ttl-docs/ttl/ttl/diagnostics.mdx | 466 ++ .../ttl-docs/ttl/ttl/dialects.mdx | 12 + .../ttl-docs/ttl/ttl/dialects/_ods_common.mdx | 39 + .../ttl-docs/ttl/ttl/dialects/ttl.mdx | 81 + .../ttl-docs/ttl/ttl/dtype_utils.mdx | 212 + .../ttl-docs/ttl/ttl/kernel_runner.mdx | 274 + .../library-docs/ttl-docs/ttl/ttl/layouts.mdx | 126 + .../ttl-docs/ttl/ttl/operators.mdx | 714 +++ fern/library-docs/ttl-docs/ttl/ttl/ttl.mdx | 27 + .../library-docs/ttl-docs/ttl/ttl/ttl_api.mdx | 907 +++ .../ttl-docs/ttl/ttl/ttl_math.mdx | 29 + .../ttl-docs/ttl/ttl/ttl_utils.mdx | 70 + ...lock_reduce_v3.mdx => block_reduce_v5.mdx} | 70 +- .../{block_reduce.mdx => block_reduce_v6.mdx} | 234 +- .../{block_scan_v4.mdx => block_scan_v5.mdx} | 289 +- .../cub/{block_scan.mdx => block_scan_v6.mdx} | 829 ++- fern/pages/cub/simple_struct_v4.mdx | 48 - ...simple_struct.mdx => simple_struct_v5.mdx} | 0 fern/pages/cub/simple_struct_v6.mdx | 24 + ...{warp_reduce_v4.mdx => warp_reduce_v5.mdx} | 54 +- .../{warp_reduce.mdx => warp_reduce_v6.mdx} | 462 +- fern/pages/libcudacxx/concept_example.mdx | 43 - ..._example_v3.mdx => concept_example_v5.mdx} | 2 +- fern/pages/libcudacxx/concept_example_v6.mdx | 27 + ...lass_v4.mdx => deep_template_class_v5.mdx} | 16 +- ...e_class.mdx => deep_template_class_v6.mdx} | 131 +- ...ss_v4.mdx => empty_docstring_class_v5.mdx} | 53 +- .../libcudacxx/empty_docstring_class_v6.mdx | 723 +++ ...s_example_v4.mdx => raises_example_v5.mdx} | 6 +- ...ises_example.mdx => raises_example_v6.mdx} | 299 +- fern/pages/rendering_rules.md | 1113 ++++ fern/pages/style_reference.md | 874 +++ ...ample_v4.mdx => deprecated_example_v5.mdx} | 10 +- ..._example.mdx => deprecated_example_v6.mdx} | 44 +- fern/pages/thrust/device_vector.mdx | 1188 ---- ...ice_vector_v3.mdx => device_vector_v5.mdx} | 64 +- fern/pages/thrust/device_vector_v6.mdx | 1368 ++++ ...ple_v4.mdx => group_member_example_v5.mdx} | 2 +- ...xample.mdx => group_member_example_v6.mdx} | 158 +- .../thrust/{pointer.mdx => pointer_v5.mdx} | 26 +- .../thrust/{pointer_v4.mdx => pointer_v6.mdx} | 81 +- 658 files changed, 192478 insertions(+), 2892 deletions(-) create mode 100644 fern/cudapages/cub/cub/cub/AgentAdjacentDifferencePolicy.mdx create mode 100644 fern/cudapages/cub/cub/cub/AgentHistogramPolicy.mdx create mode 100644 fern/cudapages/cub/cub/cub/AgentMergeSortPolicy.mdx create mode 100644 fern/cudapages/cub/cub/cub/AgentRadixSortDownsweepPolicy.mdx create mode 100644 fern/cudapages/cub/cub/cub/AgentRadixSortExclusiveSumPolicy.mdx create mode 100644 fern/cudapages/cub/cub/cub/AgentRadixSortHistogramPolicy.mdx create mode 100644 fern/cudapages/cub/cub/cub/AgentRadixSortOnesweepPolicy.mdx create mode 100644 fern/cudapages/cub/cub/cub/AgentRadixSortUpsweepPolicy.mdx create mode 100644 fern/cudapages/cub/cub/cub/AgentReduceByKeyPolicy.mdx create mode 100644 fern/cudapages/cub/cub/cub/AgentReducePolicy.mdx create mode 100644 fern/cudapages/cub/cub/cub/AgentRlePolicy.mdx create mode 100644 fern/cudapages/cub/cub/cub/AgentScanByKeyPolicy.mdx create mode 100644 fern/cudapages/cub/cub/cub/AgentScanPolicy.mdx create mode 100644 fern/cudapages/cub/cub/cub/AgentSelectIfPolicy.mdx create mode 100644 fern/cudapages/cub/cub/cub/AgentSubWarpMergeSortPolicy.mdx create mode 100644 fern/cudapages/cub/cub/cub/AgentThreeWayPartitionPolicy.mdx create mode 100644 fern/cudapages/cub/cub/cub/AgentUniqueByKeyPolicy.mdx create mode 100644 fern/cudapages/cub/cub/cub/AgentWarpReducePolicy.mdx create mode 100644 fern/cudapages/cub/cub/cub/ArgIndexInputIterator.mdx create mode 100644 fern/cudapages/cub/cub/cub/ArgMax.mdx create mode 100644 fern/cudapages/cub/cub/cub/ArgMin.mdx create mode 100644 fern/cudapages/cub/cub/cub/BFEDigitExtractor.mdx create mode 100644 fern/cudapages/cub/cub/cub/BaseDigitExtractor.mdx create mode 100644 fern/cudapages/cub/cub/cub/BaseDigitExtractor_KeyT_true.mdx create mode 100644 fern/cudapages/cub/cub/cub/BlockAdjacentDifference.mdx create mode 100644 fern/cudapages/cub/cub/cub/BlockDiscontinuity.mdx create mode 100644 fern/cudapages/cub/cub/cub/BlockExchange.mdx create mode 100644 fern/cudapages/cub/cub/cub/BlockHistogram.mdx create mode 100644 fern/cudapages/cub/cub/cub/BlockLoad.mdx create mode 100644 fern/cudapages/cub/cub/cub/BlockLoadType.mdx create mode 100644 fern/cudapages/cub/cub/cub/BlockMergeSort.mdx create mode 100644 fern/cudapages/cub/cub/cub/BlockMergeSortStrategy.mdx create mode 100644 fern/cudapages/cub/cub/cub/BlockRadixRank.mdx create mode 100644 fern/cudapages/cub/cub/cub/BlockRadixRankEmptyCallback.mdx create mode 100644 fern/cudapages/cub/cub/cub/BlockRadixRankMatch.mdx create mode 100644 fern/cudapages/cub/cub/cub/BlockRadixRankMatchEarlyCounts.mdx create mode 100644 fern/cudapages/cub/cub/cub/BlockRadixSort.mdx create mode 100644 fern/cudapages/cub/cub/cub/BlockRakingLayout.mdx create mode 100644 fern/cudapages/cub/cub/cub/BlockReduce.mdx create mode 100644 fern/cudapages/cub/cub/cub/BlockRunLengthDecode.mdx create mode 100644 fern/cudapages/cub/cub/cub/BlockScan.mdx create mode 100644 fern/cudapages/cub/cub/cub/BlockScanRunningPrefixOp.mdx create mode 100644 fern/cudapages/cub/cub/cub/BlockShuffle.mdx create mode 100644 fern/cudapages/cub/cub/cub/BlockStore.mdx create mode 100644 fern/cudapages/cub/cub/cub/CacheModifiedInputIterator.mdx create mode 100644 fern/cudapages/cub/cub/cub/CacheModifiedOutputIterator.mdx create mode 100644 fern/cudapages/cub/cub/cub/CachingDeviceAllocator.mdx create mode 100644 fern/cudapages/cub/cub/cub/CastOp.mdx create mode 100644 fern/cudapages/cub/cub/cub/ChainedPolicy.mdx create mode 100644 fern/cudapages/cub/cub/cub/DeviceAdjacentDifference.mdx create mode 100644 fern/cudapages/cub/cub/cub/DeviceCopy.mdx create mode 100644 fern/cudapages/cub/cub/cub/DeviceFind.mdx create mode 100644 fern/cudapages/cub/cub/cub/DeviceFor.mdx create mode 100644 fern/cudapages/cub/cub/cub/DeviceHistogram.mdx create mode 100644 fern/cudapages/cub/cub/cub/DeviceMemcpy.mdx create mode 100644 fern/cudapages/cub/cub/cub/DeviceMerge.mdx create mode 100644 fern/cudapages/cub/cub/cub/DeviceMergeSort.mdx create mode 100644 fern/cudapages/cub/cub/cub/DevicePartition.mdx create mode 100644 fern/cudapages/cub/cub/cub/DeviceRadixSort.mdx create mode 100644 fern/cudapages/cub/cub/cub/DeviceReduce.mdx create mode 100644 fern/cudapages/cub/cub/cub/DeviceRleDispatch.mdx create mode 100644 fern/cudapages/cub/cub/cub/DeviceRunLengthEncode.mdx create mode 100644 fern/cudapages/cub/cub/cub/DeviceScan.mdx create mode 100644 fern/cudapages/cub/cub/cub/DeviceSegmentedRadixSort.mdx create mode 100644 fern/cudapages/cub/cub/cub/DeviceSegmentedReduce.mdx create mode 100644 fern/cudapages/cub/cub/cub/DeviceSegmentedScan.mdx create mode 100644 fern/cudapages/cub/cub/cub/DeviceSegmentedSort.mdx create mode 100644 fern/cudapages/cub/cub/cub/DeviceSelect.mdx create mode 100644 fern/cudapages/cub/cub/cub/DeviceTopK.mdx create mode 100644 fern/cudapages/cub/cub/cub/DeviceTransform.mdx create mode 100644 fern/cudapages/cub/cub/cub/DispatchAdjacentDifference.mdx create mode 100644 fern/cudapages/cub/cub/cub/DispatchHistogram.mdx create mode 100644 fern/cudapages/cub/cub/cub/DispatchMergeSort.mdx create mode 100644 fern/cudapages/cub/cub/cub/DispatchRadixSort.mdx create mode 100644 fern/cudapages/cub/cub/cub/DispatchReduce.mdx create mode 100644 fern/cudapages/cub/cub/cub/DispatchReduceByKey.mdx create mode 100644 fern/cudapages/cub/cub/cub/DispatchScan.mdx create mode 100644 fern/cudapages/cub/cub/cub/DispatchScanByKey.mdx create mode 100644 fern/cudapages/cub/cub/cub/DispatchSegmentedRadixSort.mdx create mode 100644 fern/cudapages/cub/cub/cub/DispatchSegmentedReduce.mdx create mode 100644 fern/cudapages/cub/cub/cub/DispatchSegmentedSort.mdx create mode 100644 fern/cudapages/cub/cub/cub/DispatchSelectIf.mdx create mode 100644 fern/cudapages/cub/cub/cub/DispatchThreeWayPartitionIf.mdx create mode 100644 fern/cudapages/cub/cub/cub/DispatchUniqueByKey.mdx create mode 100644 fern/cudapages/cub/cub/cub/GridEvenShare.mdx create mode 100644 fern/cudapages/cub/cub/cub/GridQueue.mdx create mode 100644 fern/cudapages/cub/cub/cub/InequalityWrapper.mdx create mode 100644 fern/cudapages/cub/cub/cub/PtxVersionCacheTag.mdx create mode 100644 fern/cudapages/cub/cub/cub/RadixSortTwiddle.mdx create mode 100644 fern/cudapages/cub/cub/cub/ReduceByKeyOp.mdx create mode 100644 fern/cudapages/cub/cub/cub/ReduceByKeyScanTileState.mdx create mode 100644 fern/cudapages/cub/cub/cub/ReduceByKeyScanTileState_ValueT_KeyT_false.mdx create mode 100644 fern/cudapages/cub/cub/cub/ReduceBySegmentOp.mdx create mode 100644 fern/cudapages/cub/cub/cub/ScanTileState.mdx create mode 100644 fern/cudapages/cub/cub/cub/ScanTileState_T_false.mdx create mode 100644 fern/cudapages/cub/cub/cub/ShiftDigitExtractor.mdx create mode 100644 fern/cudapages/cub/cub/cub/SmVersionCacheTag.mdx create mode 100644 fern/cudapages/cub/cub/cub/SwizzleScanOp.mdx create mode 100644 fern/cudapages/cub/cub/cub/TilePrefixCallbackOp.mdx create mode 100644 fern/cudapages/cub/cub/cub/WarpExchange.mdx create mode 100644 fern/cudapages/cub/cub/cub/WarpLoad.mdx create mode 100644 fern/cudapages/cub/cub/cub/WarpMergeSort.mdx create mode 100644 fern/cudapages/cub/cub/cub/WarpReduce.mdx create mode 100644 fern/cudapages/cub/cub/cub/WarpScan.mdx create mode 100644 fern/cudapages/cub/cub/cub/WarpStore.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/arch_traits_t.mdx rename fern/{pages/libcudacxx/empty_docstring_class.mdx => cudapages/cuda/cuda/cuda/buffer.mdx} (69%) create mode 100644 fern/cudapages/cuda/cuda/cuda/compute_capability.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/constant_iterator.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/copy_configuration.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/counting_iterator.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/device_attributes/compute_capability_t.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/device_memory_pool.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/device_memory_pool_ref.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/device_ref.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/discard_iterator.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/event.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/event_ref.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/get_stream_t.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/has_property.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/has_property_with.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/heterogeneous_iterator.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/managed_memory_pool.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/managed_memory_pool_ref.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/memory_pool_properties.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/mr/basic_any_resource.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/mr/basic_resource_ref.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/mr/device_accessible.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/mr/host_accessible.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/mr/legacy_managed_memory_resource.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/mr/legacy_pinned_memory_resource.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/mr/properties_list.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/mr/resource.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/mr/resource_with.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/mr/shared_resource.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/mr/synchronous_resource.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/mr/synchronous_resource_adapter.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/mr/synchronous_resource_with.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/permutation_iterator.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/pinned_memory_pool.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/pinned_memory_pool_ref.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/property_with_value.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/shuffle_iterator.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/std/pointer_traits.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/stream.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/stream_ref.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/strided_iterator.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/tabulate_output_iterator.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/timed_event.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/transform_input_output_iterator.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/transform_iterator.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/transform_output_iterator.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/zip_function.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/zip_iterator.mdx create mode 100644 fern/cudapages/cuda/cuda/cuda/zip_transform_iterator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/allocator_delete.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/array_allocator_delete.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/bidirectional_device_iterator_tag.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/bidirectional_traversal_tag.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/compile_time_value.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/complex.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/constant_iterator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/counting_iterator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/device_allocator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/device_execution_policy.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/device_malloc_allocator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/device_new_allocator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/device_ptr.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/device_ptr_memory_resource.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/device_reference.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/device_vector.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/discard_block_engine.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/discard_iterator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/error_category.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/error_code.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/error_condition.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/forward_device_iterator_tag.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/forward_traversal_tag.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/host_execution_policy.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/host_vector.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/incrementable_traversal_tag.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/input_device_iterator_tag.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/is_error_code_enum.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/is_error_condition_enum.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_adaptor.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_core_access.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_difference.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_facade.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_pointer.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_reference.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_system.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_system_const_void_ptr.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_system_cudaconstant_iterator_T_Index.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_system_cudacounting_iterator_Start.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_system_cudadiscard_iterator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_system_cudapermutation_iterator_Iter_Offset.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_system_cudashuffle_iterator_IndexType_Bijection.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_system_cudastdreverse_iterator_Iter.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_system_cudastrided_iterator_Iter_Stride.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_system_cudatabulate_output_iterator_Fn_Index.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_system_cudatransform_input_output_iterator_InputFn_OutputFn_Iter.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_system_cudatransform_iterator_Fn_Iter.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_system_cudatransform_output_iterator_Fn_Iter.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_system_cudazip_iterator_Iterators.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_system_cudazip_transform_iterator_Fn_Iterators.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_system_void_ptr.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_traversal.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudaconstant_iterator_T_Index.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudacounting_iterator_Start.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudadiscard_iterator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudapermutation_iterator_Iter_Offset.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudashuffle_iterator_IndexType_Bijection.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudastdreverse_iterator_Iter.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudastrided_iterator_Iter_Stride.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudatabulate_output_iterator_Fn_Index.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudatransform_input_output_iterator_InputFn_OutputFn_Iter.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudatransform_iterator_Fn_Iter.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudatransform_output_iterator_Fn_Iter.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudazip_iterator_Iterators.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudazip_transform_iterator_Fn_Iterators.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/iterator_value.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/linear_congruential_engine.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/linear_feedback_shift_engine.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/mr/allocator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/mr/disjoint_synchronized_pool_resource.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/mr/disjoint_unsynchronized_pool_resource.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/mr/fancy_pointer_resource.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/mr/memory_resource.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/mr/memory_resource_void_ptr.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/mr/new_delete_resource.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/mr/new_delete_resource_base.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/mr/polymorphic_adaptor_resource.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/mr/pool_options.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/mr/stateless_resource_allocator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/mr/synchronized_pool_resource.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/mr/unsynchronized_pool_resource.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/mr/validator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/mr/validator2.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/mr/validator2_T_T.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/no_traversal_tag.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/normal_distribution.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/offset_iterator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/output_device_iterator_tag.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/per_device_allocator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/permutation_iterator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/pointer.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/proclaim_contiguous_iterator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/project1st.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/project1st_void_void.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/project2nd.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/project2nd_void_void.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/random/discard_block_engine.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/random/linear_congruential_engine.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/random/linear_feedback_shift_engine.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/random/normal_distribution.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/random/subtract_with_carry_engine.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/random/uniform_int_distribution.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/random/uniform_real_distribution.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/random/xor_combine_engine.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/random_access_device_iterator_tag.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/random_access_traversal_tag.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/runtime_value.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/shuffle_iterator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/single_pass_traversal_tag.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/square.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/square_void.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/strided_iterator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/subtract_with_carry_engine.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/system/error_category.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/system/error_code.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/system/error_condition.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/system/is_error_code_enum.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/system/is_error_code_enum_cudaerrcerrc_t.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/system/is_error_condition_enum.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/system/is_error_condition_enum_errcerrc_t.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/system/system_error.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/system_error.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/tabulate_output_iterator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/tagged_deleter.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/transform_input_output_iterator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/transform_iterator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/transform_output_iterator.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/uniform_int_distribution.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/uniform_real_distribution.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/xor_combine_engine.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/zip_function.mdx create mode 100644 fern/cudapages/thrust/thrust/thrust/zip_iterator.mdx create mode 100644 fern/docs/pages/nominal-data-model.mdx create mode 100644 fern/docs/pages/steps-toc-test.mdx create mode 100644 fern/library-docs/langchain-core-docs/_navigation.yml create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_api.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_api/beta_decorator.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_api/deprecation.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_api/internal.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_api/path.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_import_utils.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_security.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_security/_ssrf_protection.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/agents.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/caches.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/base.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/file.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/manager.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/stdout.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/streaming_stdout.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/usage.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/chat_history.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/chat_loaders.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/chat_sessions.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/document_loaders.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/document_loaders/base.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/document_loaders/blob_loaders.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/document_loaders/langsmith.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/documents.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/documents/base.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/documents/compressor.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/documents/transformers.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/embeddings.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/embeddings/embeddings.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/embeddings/fake.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/env.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/example_selectors.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/example_selectors/base.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/example_selectors/length_based.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/example_selectors/semantic_similarity.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/exceptions.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/globals.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/indexing.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/indexing/api.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/indexing/base.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/indexing/in_memory.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/_utils.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/base.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/chat_models.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/fake.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/fake_chat_models.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/llms.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/model_profile.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load/_validation.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load/dump.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load/load.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load/mapping.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load/serializable.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/ai.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/base.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/anthropic.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/bedrock.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/bedrock_converse.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/google_genai.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/google_vertexai.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/groq.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/langchain_v0.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/openai.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/chat.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/content.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/function.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/human.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/modifier.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/system.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/tool.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/utils.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/base.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/format_instructions.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/json.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/list.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/openai_functions.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/openai_tools.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/pydantic.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/string.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/transform.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/xml.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs/chat_generation.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs/chat_result.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs/generation.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs/llm_result.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs/run_info.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompt_values.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/base.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/chat.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/dict.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/few_shot.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/few_shot_with_templates.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/image.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/loading.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/message.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/prompt.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/string.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/structured.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/rate_limiters.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/retrievers.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/base.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/branch.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/config.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/configurable.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/fallbacks.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/graph.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/graph_ascii.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/graph_mermaid.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/graph_png.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/history.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/passthrough.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/retry.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/router.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/schema.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/utils.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/stores.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/structured_query.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/sys_info.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/base.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/convert.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/render.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/retriever.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/simple.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/structured.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/_compat.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/_streaming.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/base.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/context.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/core.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/evaluation.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/event_stream.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/langchain.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/log_stream.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/memory_stream.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/root_listeners.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/run_collector.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/schemas.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/stdout.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/_merge.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/aiter.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/env.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/formatting.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/function_calling.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/html.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/image.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/input.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/interactive_env.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/iter.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/json.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/json_schema.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/mustache.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/pydantic.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/strings.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/usage.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/utils.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/uuid.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/vectorstores.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/vectorstores/base.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/vectorstores/in_memory.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/vectorstores/utils.mdx create mode 100644 fern/library-docs/langchain-core-docs/langchain-core/langchain_core/version.mdx create mode 100644 fern/library-docs/nemo-rl-docs/_navigation.yml create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/advantage_estimator.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/async_utils.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/distillation.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/dpo.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/grpo.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/interfaces.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/loss_functions.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/reward_functions.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/rm.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/sft.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/utils.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/chat_templates.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/collate_fn.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/aime.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/math.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/processed_dataset.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/raw_dataset.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/aime24.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/clevr.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oasst.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/squad.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/utils.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/interfaces.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/llm_message_utils.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/multimodal_utils.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/packing.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/algorithms.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/metrics.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/processors.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/utils.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/batched_data_dict.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/collectives.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/model_utils.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/named_sharding.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/ray_actor_environment_registry.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/stateless_process_group.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/virtual_cluster.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_group_utils.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_groups.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_environment.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_jaccard_environment.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/dapo_math_verifier.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/interfaces.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/math_environment.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/metrics.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/nemo_gym.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/reward_model_environment.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/rewards.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/utils.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/vlm_environment.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/evals.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/evals/answer_parsing.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/evals/eval.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/experience.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/experience/rollouts.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/config.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/data.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/setup.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/train.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor/parallelize.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/interfaces.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/config.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_generation.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_worker.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/utils.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/config.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/utils.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_backend.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_generation.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface/common.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/common.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/community_import.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/config.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/data.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/pipeline_parallel.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/setup.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/train.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/interfaces.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/lm_policy.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/utils.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/patches.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/package_info.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/automodel_checkpoint.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/checkpoint.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/config.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_formulas.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_tracker.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/logger.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/memory_tracker.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/native_checkpoint.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/nsys.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/nvml.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/packed_tensor.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/prefetch_venvs.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/timer.mdx create mode 100644 fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/venvs.mdx create mode 100644 fern/library-docs/ttl-docs/_navigation.yml create mode 100644 fern/library-docs/ttl-docs/ttl/ttl.mdx create mode 100644 fern/library-docs/ttl-docs/ttl/ttl/_mlir_libs.mdx create mode 100644 fern/library-docs/ttl-docs/ttl/ttl/_mlir_libs/_site_initialize_1.mdx create mode 100644 fern/library-docs/ttl-docs/ttl/ttl/_src.mdx create mode 100644 fern/library-docs/ttl-docs/ttl/ttl/_src/auto_profile.mdx create mode 100644 fern/library-docs/ttl-docs/ttl/ttl/_src/tensor_registry.mdx create mode 100644 fern/library-docs/ttl-docs/ttl/ttl/_src/ttl_ast.mdx create mode 100644 fern/library-docs/ttl-docs/ttl/ttl/circular_buffer.mdx create mode 100644 fern/library-docs/ttl-docs/ttl/ttl/constants.mdx create mode 100644 fern/library-docs/ttl-docs/ttl/ttl/diagnostics.mdx create mode 100644 fern/library-docs/ttl-docs/ttl/ttl/dialects.mdx create mode 100644 fern/library-docs/ttl-docs/ttl/ttl/dialects/_ods_common.mdx create mode 100644 fern/library-docs/ttl-docs/ttl/ttl/dialects/ttl.mdx create mode 100644 fern/library-docs/ttl-docs/ttl/ttl/dtype_utils.mdx create mode 100644 fern/library-docs/ttl-docs/ttl/ttl/kernel_runner.mdx create mode 100644 fern/library-docs/ttl-docs/ttl/ttl/layouts.mdx create mode 100644 fern/library-docs/ttl-docs/ttl/ttl/operators.mdx create mode 100644 fern/library-docs/ttl-docs/ttl/ttl/ttl.mdx create mode 100644 fern/library-docs/ttl-docs/ttl/ttl/ttl_api.mdx create mode 100644 fern/library-docs/ttl-docs/ttl/ttl/ttl_math.mdx create mode 100644 fern/library-docs/ttl-docs/ttl/ttl/ttl_utils.mdx rename fern/pages/cub/{block_reduce_v3.mdx => block_reduce_v5.mdx} (89%) rename fern/pages/cub/{block_reduce.mdx => block_reduce_v6.mdx} (64%) rename fern/pages/cub/{block_scan_v4.mdx => block_scan_v5.mdx} (82%) rename fern/pages/cub/{block_scan.mdx => block_scan_v6.mdx} (62%) delete mode 100644 fern/pages/cub/simple_struct_v4.mdx rename fern/pages/cub/{simple_struct.mdx => simple_struct_v5.mdx} (100%) create mode 100644 fern/pages/cub/simple_struct_v6.mdx rename fern/pages/cub/{warp_reduce_v4.mdx => warp_reduce_v5.mdx} (89%) rename fern/pages/cub/{warp_reduce.mdx => warp_reduce_v6.mdx} (68%) delete mode 100644 fern/pages/libcudacxx/concept_example.mdx rename fern/pages/libcudacxx/{concept_example_v3.mdx => concept_example_v5.mdx} (99%) create mode 100644 fern/pages/libcudacxx/concept_example_v6.mdx rename fern/pages/libcudacxx/{deep_template_class_v4.mdx => deep_template_class_v5.mdx} (89%) rename fern/pages/libcudacxx/{deep_template_class.mdx => deep_template_class_v6.mdx} (63%) rename fern/pages/libcudacxx/{empty_docstring_class_v4.mdx => empty_docstring_class_v5.mdx} (93%) create mode 100644 fern/pages/libcudacxx/empty_docstring_class_v6.mdx rename fern/pages/libcudacxx/{raises_example_v4.mdx => raises_example_v5.mdx} (98%) rename fern/pages/libcudacxx/{raises_example.mdx => raises_example_v6.mdx} (52%) create mode 100644 fern/pages/rendering_rules.md create mode 100644 fern/pages/style_reference.md rename fern/pages/thrust/{deprecated_example_v4.mdx => deprecated_example_v5.mdx} (95%) rename fern/pages/thrust/{deprecated_example.mdx => deprecated_example_v6.mdx} (62%) delete mode 100644 fern/pages/thrust/device_vector.mdx rename fern/pages/thrust/{device_vector_v3.mdx => device_vector_v5.mdx} (97%) create mode 100644 fern/pages/thrust/device_vector_v6.mdx rename fern/pages/thrust/{group_member_example_v4.mdx => group_member_example_v5.mdx} (99%) rename fern/pages/thrust/{group_member_example.mdx => group_member_example_v6.mdx} (75%) rename fern/pages/thrust/{pointer.mdx => pointer_v5.mdx} (98%) rename fern/pages/thrust/{pointer_v4.mdx => pointer_v6.mdx} (73%) diff --git a/fern/cudapages/cub/cub/cub/AgentAdjacentDifferencePolicy.mdx b/fern/cudapages/cub/cub/cub/AgentAdjacentDifferencePolicy.mdx new file mode 100644 index 0000000..a8d066d --- /dev/null +++ b/fern/cudapages/cub/cub/cub/AgentAdjacentDifferencePolicy.mdx @@ -0,0 +1,38 @@ +--- +title: cub::AgentAdjacentDifferencePolicy +description: "" +--- + + + + + + + + + + + + + + + + + + + + + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | | +| `ITEMS_PER_THREAD` static constexpr | `int` | | +| `ITEMS_PER_TILE` static constexpr | `int` | | +| `LOAD_ALGORITHM` static constexpr | `cub::BlockLoadAlgorithm` | | +| `LOAD_MODIFIER` static constexpr | `cub::CacheLoadModifier` | | +| `STORE_ALGORITHM` static constexpr | `cub::BlockStoreAlgorithm` | | diff --git a/fern/cudapages/cub/cub/cub/AgentHistogramPolicy.mdx b/fern/cudapages/cub/cub/cub/AgentHistogramPolicy.mdx new file mode 100644 index 0000000..38781a0 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/AgentHistogramPolicy.mdx @@ -0,0 +1,59 @@ +--- +title: cub::AgentHistogramPolicy +description: "Parameterizable tuning policy type for AgentHistogram." +--- + +Parameterizable tuning policy type for AgentHistogram. + + + + + +Threads per thread block + + + +Pixels per thread (per tile of input) + + + +The [BlockLoad](/library/api/cub::_block_load) algorithm to use + + + +Cache load modifier for reading input elements + + + +Whether to perform localized RLE to compress samples before histogramming + + + +Whether to prefer privatized shared-memory bins (versus privatized global-memory bins) + + + +Whether to dequeue tiles from a global work queue + + + +Vector size for samples loading (1, 2, 4) + + + + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | Threads per thread block. | +| `PIXELS_PER_THREAD` static constexpr | `int` | Pixels per thread (per tile of input). | +| `IS_RLE_COMPRESS` static constexpr | `bool` | Whether to perform localized RLE to compress samples before histogramming. | +| `MEM_PREFERENCE` static constexpr | `BlockHistogramMemoryPreference` | Whether to prefer privatized shared-memory bins (versus privatized global-memory bins). | +| `IS_WORK_STEALING` static constexpr | `bool` | Whether to dequeue tiles from a global work queue. | +| `VEC_SIZE` static constexpr | `int` | Vector size for samples loading (1, 2, 4). | +| `LOAD_ALGORITHM` static constexpr | `BlockLoadAlgorithm` | Cache load modifier for reading input elements. | +| `LOAD_MODIFIER` static constexpr | `CacheLoadModifier` | | diff --git a/fern/cudapages/cub/cub/cub/AgentMergeSortPolicy.mdx b/fern/cudapages/cub/cub/cub/AgentMergeSortPolicy.mdx new file mode 100644 index 0000000..ae2d584 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/AgentMergeSortPolicy.mdx @@ -0,0 +1,38 @@ +--- +title: cub::AgentMergeSortPolicy +description: "" +--- + + + + + + + + + + + + + + + + + + + + + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | | +| `ITEMS_PER_THREAD` static constexpr | `int` | | +| `ITEMS_PER_TILE` static constexpr | `int` | | +| `LOAD_ALGORITHM` static constexpr | `cub::BlockLoadAlgorithm` | | +| `LOAD_MODIFIER` static constexpr | `cub::CacheLoadModifier` | | +| `STORE_ALGORITHM` static constexpr | `cub::BlockStoreAlgorithm` | | diff --git a/fern/cudapages/cub/cub/cub/AgentRadixSortDownsweepPolicy.mdx b/fern/cudapages/cub/cub/cub/AgentRadixSortDownsweepPolicy.mdx new file mode 100644 index 0000000..c7330ff --- /dev/null +++ b/fern/cudapages/cub/cub/cub/AgentRadixSortDownsweepPolicy.mdx @@ -0,0 +1,61 @@ +--- +title: cub::AgentRadixSortDownsweepPolicy +description: "Parameterizable tuning policy type for AgentRadixSortDownsweep." +--- + +Parameterizable tuning policy type for AgentRadixSortDownsweep. + + + + + +Threads per thread block + + + +Items per thread (per tile of input) + + + +Dominant compute type + + + +The [BlockLoad](/library/api/cub::_block_load) algorithm to use + + + +Cache load modifier for reading keys (and values) + + + +The radix ranking algorithm to use + + + +The block scan algorithm to use + + + +The number of radix bits, i.e., log2(bins) + + + + + + + + +**Inherits from:** `detail::RegBoundScaling< NominalBlockThreads4B, NominalItemsPerThread4B, ComputeT >` (public) + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `RADIX_BITS` static constexpr | `int` | The number of radix bits, i.e., log2(bins). | +| `LOAD_ALGORITHM` static constexpr | `BlockLoadAlgorithm` | The [BlockLoad](/library/api/cub::_block_load) algorithm to use. | +| `LOAD_MODIFIER` static constexpr | `CacheLoadModifier` | Cache load modifier for reading keys (and values). | +| `RANK_ALGORITHM` static constexpr | `RadixRankAlgorithm` | The radix ranking algorithm to use. | +| `SCAN_ALGORITHM` static constexpr | `BlockScanAlgorithm` | The [BlockScan](/library/api/cub::_block_scan) algorithm to use. | diff --git a/fern/cudapages/cub/cub/cub/AgentRadixSortExclusiveSumPolicy.mdx b/fern/cudapages/cub/cub/cub/AgentRadixSortExclusiveSumPolicy.mdx new file mode 100644 index 0000000..bc1d0a5 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/AgentRadixSortExclusiveSumPolicy.mdx @@ -0,0 +1,25 @@ +--- +title: cub::AgentRadixSortExclusiveSumPolicy +description: "" +--- + + + + + + + + + + + + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | | +| `RADIX_BITS` static constexpr | `int` | | diff --git a/fern/cudapages/cub/cub/cub/AgentRadixSortHistogramPolicy.mdx b/fern/cudapages/cub/cub/cub/AgentRadixSortHistogramPolicy.mdx new file mode 100644 index 0000000..4a74491 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/AgentRadixSortHistogramPolicy.mdx @@ -0,0 +1,50 @@ +--- +title: cub::AgentRadixSortHistogramPolicy +description: "" +--- + + + + + + + + + + + + + + +If void, use NOMINAL_4B_NUM_PARTS directly for NUM_PARTS. Otherwise, perform scaling. + + + + + + + + +--- + +## Static methods + +### num_parts_helper inline static constexpr + + +```cpp showLineNumbers={false} +template +static constexpr int cub::AgentRadixSortHistogramPolicy::num_parts_helper() +``` + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | | +| `ITEMS_PER_THREAD` static constexpr | `int` | | +| `NUM_PARTS` static constexpr | `int` | NUM_PARTS is the number of private histograms (parts) each histogram is split into. | +| `RADIX_BITS` static constexpr | `int` | | diff --git a/fern/cudapages/cub/cub/cub/AgentRadixSortOnesweepPolicy.mdx b/fern/cudapages/cub/cub/cub/AgentRadixSortOnesweepPolicy.mdx new file mode 100644 index 0000000..759c8c5 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/AgentRadixSortOnesweepPolicy.mdx @@ -0,0 +1,51 @@ +--- +title: cub::AgentRadixSortOnesweepPolicy +description: "" +--- + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +**Inherits from:** `detail::RegBoundScaling< NominalBlockThreads4B, NominalItemsPerThread4B, ComputeT >` (public) + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `RANK_NUM_PARTS` static constexpr | `int` | | +| `RADIX_BITS` static constexpr | `int` | | +| `RANK_ALGORITHM` static constexpr | `RadixRankAlgorithm` | | +| `SCAN_ALGORITHM` static constexpr | `BlockScanAlgorithm` | | +| `STORE_ALGORITHM` static constexpr | `RadixSortStoreAlgorithm` | | diff --git a/fern/cudapages/cub/cub/cub/AgentRadixSortUpsweepPolicy.mdx b/fern/cudapages/cub/cub/cub/AgentRadixSortUpsweepPolicy.mdx new file mode 100644 index 0000000..794a204 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/AgentRadixSortUpsweepPolicy.mdx @@ -0,0 +1,46 @@ +--- +title: cub::AgentRadixSortUpsweepPolicy +description: "Parameterizable tuning policy type for AgentRadixSortUpsweep." +--- + +Parameterizable tuning policy type for AgentRadixSortUpsweep. + + + + + +Threads per thread block + + + +Items per thread (per tile of input) + + + +Dominant compute type + + + +Cache load modifier for reading keys + + + +The number of radix bits, i.e., log2(bins) + + + + + + + + +**Inherits from:** `detail::RegBoundScaling< NominalBlockThreads4B, NominalItemsPerThread4B, ComputeT >` (public) + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `RADIX_BITS` static constexpr | `int` | The number of radix bits, i.e., log2(bins). | +| `LOAD_MODIFIER` static constexpr | `CacheLoadModifier` | Cache load modifier for reading keys. | diff --git a/fern/cudapages/cub/cub/cub/AgentReduceByKeyPolicy.mdx b/fern/cudapages/cub/cub/cub/AgentReduceByKeyPolicy.mdx new file mode 100644 index 0000000..502fd02 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/AgentReduceByKeyPolicy.mdx @@ -0,0 +1,48 @@ +--- +title: cub::AgentReduceByKeyPolicy +description: "Parameterizable tuning policy type for AgentReduceByKey." +--- + +Parameterizable tuning policy type for AgentReduceByKey. + + + + + +Threads per thread block + + + +Items per thread (per tile of input) + + + +The [BlockLoad](/library/api/cub::_block_load) algorithm to use + + + +Cache load modifier for reading input elements + + + +The [BlockScan](/library/api/cub::_block_scan) algorithm to use + + + +Implementation detail, do not specify directly, requirements on the content of this type are subject to breaking change. + + + + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | < Threads per thread block | +| `ITEMS_PER_THREAD` static constexpr | `int` | The [BlockLoad](/library/api/cub::_block_load) algorithm to use. | +| `LOAD_ALGORITHM` static constexpr | `BlockLoadAlgorithm` | Cache load modifier for reading input elements. | +| `LOAD_MODIFIER` static constexpr | `CacheLoadModifier` | The [BlockScan](/library/api/cub::_block_scan) algorithm to use. | +| `SCAN_ALGORITHM` static constexpr | `BlockScanAlgorithm` | | diff --git a/fern/cudapages/cub/cub/cub/AgentReducePolicy.mdx b/fern/cudapages/cub/cub/cub/AgentReducePolicy.mdx new file mode 100644 index 0000000..2a57060 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/AgentReducePolicy.mdx @@ -0,0 +1,51 @@ +--- +title: cub::AgentReducePolicy +description: "Parameterizable tuning policy type for AgentReduce." +--- + +Parameterizable tuning policy type for AgentReduce. + + + + + +Threads per thread block + + + +Items per thread (per tile of input) + + + +Dominant compute type + + + +Number of items per vectorized load + + + +Cooperative block-wide reduction algorithm to use + + + +Cache load modifier for reading input elements + + + + + + + + +**Inherits from:** `detail::MemBoundScaling< NominalBlockThreads4B, NominalItemsPerThread4B, ComputeT >` (public) + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `VECTOR_LOAD_LENGTH` static constexpr | `int` | Number of items per vectorized load. | +| `BLOCK_ALGORITHM` static constexpr | `BlockReduceAlgorithm` | Cooperative block-wide reduction algorithm to use. | +| `LOAD_MODIFIER` static constexpr | `CacheLoadModifier` | Cache load modifier for reading input elements. | diff --git a/fern/cudapages/cub/cub/cub/AgentRlePolicy.mdx b/fern/cudapages/cub/cub/cub/AgentRlePolicy.mdx new file mode 100644 index 0000000..05ee349 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/AgentRlePolicy.mdx @@ -0,0 +1,53 @@ +--- +title: cub::AgentRlePolicy +description: "Parameterizable tuning policy type for AgentRle." +--- + +Parameterizable tuning policy type for AgentRle. + + + + + +Threads per thread block + + + +Items per thread (per tile of input) + + + +The [BlockLoad](/library/api/cub::_block_load) algorithm to use + + + +Cache load modifier for reading input elements + + + +Whether or not only one warp's worth of shared memory should be allocated and time-sliced among block-warps during any store-related data transpositions (versus each warp having its own storage) + + + +The [BlockScan](/library/api/cub::_block_scan) algorithm to use + + + +Implementation detail, do not specify directly, requirements on the content of this type are subject to breaking change. + + + + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | Threads per thread block. | +| `ITEMS_PER_THREAD` static constexpr | `int` | Items per thread (per tile of input). | +| `STORE_WARP_TIME_SLICING` static constexpr | `bool` | Whether or not only one warp's worth of shared memory should be allocated and time-sliced among block-warps during any store-related data transpositions (versus each warp having its own storage). | +| `LOAD_ALGORITHM` static constexpr | `BlockLoadAlgorithm` | The [BlockLoad](/library/api/cub::_block_load) algorithm to use. | +| `LOAD_MODIFIER` static constexpr | `CacheLoadModifier` | Cache load modifier for reading input elements. | +| `SCAN_ALGORITHM` static constexpr | `BlockScanAlgorithm` | The [BlockScan](/library/api/cub::_block_scan) algorithm to use. | diff --git a/fern/cudapages/cub/cub/cub/AgentScanByKeyPolicy.mdx b/fern/cudapages/cub/cub/cub/AgentScanByKeyPolicy.mdx new file mode 100644 index 0000000..a23df29 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/AgentScanByKeyPolicy.mdx @@ -0,0 +1,47 @@ +--- +title: cub::AgentScanByKeyPolicy +description: "Parameterizable tuning policy type for AgentScanByKey." +--- + +Parameterizable tuning policy type for AgentScanByKey. + + + + + + + + + + + + + + + + + + + + + + + +Implementation detail, do not specify directly, requirements on the content of this type are subject to breaking change. + + + + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | | +| `ITEMS_PER_THREAD` static constexpr | `int` | | +| `LOAD_ALGORITHM` static constexpr | `BlockLoadAlgorithm` | | +| `LOAD_MODIFIER` static constexpr | `CacheLoadModifier` | | +| `SCAN_ALGORITHM` static constexpr | `BlockScanAlgorithm` | | +| `STORE_ALGORITHM` static constexpr | `BlockStoreAlgorithm` | | diff --git a/fern/cudapages/cub/cub/cub/AgentScanPolicy.mdx b/fern/cudapages/cub/cub/cub/AgentScanPolicy.mdx new file mode 100644 index 0000000..6d049df --- /dev/null +++ b/fern/cudapages/cub/cub/cub/AgentScanPolicy.mdx @@ -0,0 +1,60 @@ +--- +title: cub::AgentScanPolicy +description: "Parameterizable tuning policy type for AgentScan." +--- + +Parameterizable tuning policy type for AgentScan. + + + + + +Threads per thread block + + + +Items per thread (per tile of input) + + + +Dominant compute type + + + +The [BlockLoad](/library/api/cub::_block_load) algorithm to use + + + +Cache load modifier for reading input elements + + + +The [BlockStore](/library/api/cub::_block_store) algorithm to use + + + +The [BlockScan](/library/api/cub::_block_scan) algorithm to use + + + + + + +Implementation detail, do not specify directly, requirements on the content of this type are subject to breaking change. + + + + + +**Inherits from:** `detail::MemBoundScaling< NominalBlockThreads4B, NominalItemsPerThread4B, ComputeT >` (public) + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `LOAD_ALGORITHM` static constexpr | `BlockLoadAlgorithm` | | +| `LOAD_MODIFIER` static constexpr | `CacheLoadModifier` | | +| `STORE_ALGORITHM` static constexpr | `BlockStoreAlgorithm` | | +| `SCAN_ALGORITHM` static constexpr | `BlockScanAlgorithm` | | diff --git a/fern/cudapages/cub/cub/cub/AgentSelectIfPolicy.mdx b/fern/cudapages/cub/cub/cub/AgentSelectIfPolicy.mdx new file mode 100644 index 0000000..7d7e5ca --- /dev/null +++ b/fern/cudapages/cub/cub/cub/AgentSelectIfPolicy.mdx @@ -0,0 +1,48 @@ +--- +title: cub::AgentSelectIfPolicy +description: "Parameterizable tuning policy type for AgentSelectIf." +--- + +Parameterizable tuning policy type for AgentSelectIf. + + + + + +Threads per thread block + + + +Items per thread (per tile of input) + + + +The [BlockLoad](/library/api/cub::_block_load) algorithm to use + + + +Cache load modifier for reading input elements + + + +The [BlockScan](/library/api/cub::_block_scan) algorithm to use + + + +Implementation detail, do not specify directly, requirements on the content of this type are subject to breaking change. + + + + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | Threads per thread block. | +| `ITEMS_PER_THREAD` static constexpr | `int` | Items per thread (per tile of input). | +| `LOAD_ALGORITHM` static constexpr | `BlockLoadAlgorithm` | The [BlockLoad](/library/api/cub::_block_load) algorithm to use. | +| `LOAD_MODIFIER` static constexpr | `CacheLoadModifier` | Cache load modifier for reading input elements. | +| `SCAN_ALGORITHM` static constexpr | `BlockScanAlgorithm` | The [BlockScan](/library/api/cub::_block_scan) algorithm to use. | diff --git a/fern/cudapages/cub/cub/cub/AgentSubWarpMergeSortPolicy.mdx b/fern/cudapages/cub/cub/cub/AgentSubWarpMergeSortPolicy.mdx new file mode 100644 index 0000000..651fba1 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/AgentSubWarpMergeSortPolicy.mdx @@ -0,0 +1,43 @@ +--- +title: cub::AgentSubWarpMergeSortPolicy +description: "" +--- + + + + + + + + + + + + + + + + + + + + + + + + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | | +| `WARP_THREADS` static constexpr | `int` | | +| `ITEMS_PER_THREAD` static constexpr | `int` | | +| `ITEMS_PER_TILE` static constexpr | `int` | | +| `SEGMENTS_PER_BLOCK` static constexpr | `int` | | +| `LOAD_ALGORITHM` static constexpr | `cub::WarpLoadAlgorithm` | | +| `LOAD_MODIFIER` static constexpr | `cub::CacheLoadModifier` | | +| `STORE_ALGORITHM` static constexpr | `cub::WarpStoreAlgorithm` | | diff --git a/fern/cudapages/cub/cub/cub/AgentThreeWayPartitionPolicy.mdx b/fern/cudapages/cub/cub/cub/AgentThreeWayPartitionPolicy.mdx new file mode 100644 index 0000000..385a268 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/AgentThreeWayPartitionPolicy.mdx @@ -0,0 +1,40 @@ +--- +title: cub::AgentThreeWayPartitionPolicy +description: "" +--- + + + + + + + + + + + + + + + + + + + + + + + + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | | +| `ITEMS_PER_THREAD` static constexpr | `int` | | +| `LOAD_ALGORITHM` static constexpr | `BlockLoadAlgorithm` | | +| `LOAD_MODIFIER` static constexpr | `CacheLoadModifier` | | +| `SCAN_ALGORITHM` static constexpr | `BlockScanAlgorithm` | | diff --git a/fern/cudapages/cub/cub/cub/AgentUniqueByKeyPolicy.mdx b/fern/cudapages/cub/cub/cub/AgentUniqueByKeyPolicy.mdx new file mode 100644 index 0000000..07baee4 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/AgentUniqueByKeyPolicy.mdx @@ -0,0 +1,43 @@ +--- +title: cub::AgentUniqueByKeyPolicy +description: "Parameterizable tuning policy type for AgentUniqueByKey." +--- + +Parameterizable tuning policy type for AgentUniqueByKey. + + + + + + + + + + + + + + + + + + + + +Implementation detail, do not specify directly, requirements on the content of this type are subject to breaking change. + + + + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | | +| `ITEMS_PER_THREAD` static constexpr | `int` | | +| `LOAD_ALGORITHM` static constexpr | `cub::BlockLoadAlgorithm` | | +| `LOAD_MODIFIER` static constexpr | `cub::CacheLoadModifier` | | +| `SCAN_ALGORITHM` static constexpr | `cub::BlockScanAlgorithm` | | diff --git a/fern/cudapages/cub/cub/cub/AgentWarpReducePolicy.mdx b/fern/cudapages/cub/cub/cub/AgentWarpReducePolicy.mdx new file mode 100644 index 0000000..3154401 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/AgentWarpReducePolicy.mdx @@ -0,0 +1,50 @@ +--- +title: cub::AgentWarpReducePolicy +description: "Parameterizable tuning policy type for AgentReduce." +--- + +Parameterizable tuning policy type for AgentReduce. + + + + + +Threads per thread block + + + +Threads per warp + + + +Items per thread (per tile of input) + + + +Dominant compute type + + + +Number of items per vectorized load + + + +Cache load modifier for reading input elements + + + + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `WARP_THREADS` static constexpr | `int` | Number of threads per warp. | +| `VECTOR_LOAD_LENGTH` static constexpr | `int` | Number of items per vectorized load. | +| `BLOCK_THREADS` static constexpr | `int` | Number of threads per block. | +| `ITEMS_PER_THREAD` static constexpr | `int` | Number of items per thread. | +| `LOAD_MODIFIER` static constexpr | `CacheLoadModifier` | Cache load modifier for reading input elements. | +| `ITEMS_PER_TILE` static constexpr | `int` | Number of items per tile. | +| `SEGMENTS_PER_BLOCK` static constexpr | `int` | Number of segments per block. | diff --git a/fern/cudapages/cub/cub/cub/ArgIndexInputIterator.mdx b/fern/cudapages/cub/cub/cub/ArgIndexInputIterator.mdx new file mode 100644 index 0000000..b2812f1 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/ArgIndexInputIterator.mdx @@ -0,0 +1,290 @@ +--- +title: cub::ArgIndexInputIterator +description: "A random-access input wrapper for pairing dereferenced values with their corresponding indices (forming `KeyValuePair` tuples)." +--- + +A random-access input wrapper for pairing dereferenced values with their corresponding indices (forming `KeyValuePair` tuples). + +**Overview** + +- `ArgIndexInputIterator` wraps a random access input iterator `itr` of type `InputIteratorT`. Dereferencing an `ArgIndexInputIterator` at offset `i` produces a `KeyValuePair` value whose `key` field is `i` and whose `value` field is [`itr`](/library/api/cub::_arg_index_input_iterator::itr)`[i]`. +- Can be used with any data type. +- Can be constructed, manipulated, and exchanged within and between host and device functions. Wrapped host memory can only be dereferenced on the host, and wrapped device memory can only be dereferenced on the device. +- Compatible with Thrust API v1.7 or newer. + +**Snippet** + +The code snippet below illustrates the use of `ArgIndexInputIterator` to dereference an array of doubles + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize a device array +double *d_in; // e.g., [8.0, 6.0, 7.0, 5.0, 3.0, 0.0, 9.0] + +// Create an iterator wrapper +cub::ArgIndexInputIterator itr(d_in); + +// Within device code: +cub::ArgIndexInputIterator::value_type tup = *itr; +printf("%f @ %ld\n", + tup.value, + tup.key); // 8.0 @ 0 + +itr = itr + 6; +tup = *itr; +printf("%f @ %ld\n", + tup.value, + tup.key); // 9.0 @ 6 +``` + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + + +The value type of the wrapped input iterator + + + +The difference type of this iterator (Default: `ptrdiff_t`) + + + +The paired value type of the <offset,value> tuple (Default: value type of input iterator) + + + + + +--- + +## Constructors + +### ArgIndexInputIterator inline + + +```cpp showLineNumbers={false} +cub::ArgIndexInputIterator::ArgIndexInputIterator( + InputIteratorT itr, + difference_type offset = 0 +) +``` + + +**Parameters** + + +Input iterator to wrap + + + +OffsetT (in items) from `itr` denoting the position of the iterator + + +--- + +## Methods + +### operator++ inline + + + + +Postfix increment. + + +```cpp showLineNumbers={false} +self_type cub::ArgIndexInputIterator::operator++( + int +) +``` + + + + + +Prefix increment. + + +```cpp showLineNumbers={false} +self_type cub::ArgIndexInputIterator::operator++() +``` + + + + + +### operator* inline const + +Indirection. + + +```cpp showLineNumbers={false} +reference cub::ArgIndexInputIterator::operator*() const +``` + + +### operator+ inline const + +Addition. + + +```cpp showLineNumbers={false} +template +self_type cub::ArgIndexInputIterator::operator+( + Distance n +) const +``` + + +### operator+= inline + +Addition assignment. + + +```cpp showLineNumbers={false} +template +self_type & cub::ArgIndexInputIterator::operator+=( + Distance n +) +``` + + +### operator- inline const + + + + +Subtraction. + + +```cpp showLineNumbers={false} +template +self_type cub::ArgIndexInputIterator::operator-( + Distance n +) const +``` + + + + + +Distance. + + +```cpp showLineNumbers={false} +difference_type cub::ArgIndexInputIterator::operator-( + self_type other +) const +``` + + + + + +### operator-= inline + +Subtraction assignment. + + +```cpp showLineNumbers={false} +template +self_type & cub::ArgIndexInputIterator::operator-=( + Distance n +) +``` + + +### operator[] inline const + +Array subscript. + + +```cpp showLineNumbers={false} +template +reference cub::ArgIndexInputIterator::operator[]( + Distance n +) const +``` + + +### operator-> inline + +Structure dereference. + + +```cpp showLineNumbers={false} +pointer cub::ArgIndexInputIterator::operator->() +``` + + +### operator== inline + +Equal to. + + +```cpp showLineNumbers={false} +bool cub::ArgIndexInputIterator::operator==( + const self_type &rhs +) +``` + + +### operator!= inline + +Not equal to. + + +```cpp showLineNumbers={false} +bool cub::ArgIndexInputIterator::operator!=( + const self_type &rhs +) +``` + + +### normalize inline + +Normalize. + + +```cpp showLineNumbers={false} +void cub::ArgIndexInputIterator::normalize() +``` + + +### operator<< inline + + +```cpp showLineNumbers={false} +friend::std::ostream & cub::ArgIndexInputIterator::operator<<( + ::std::ostream &os, + const self_type & +) +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `self_type` | `ArgIndexInputIterator` | My own type. | +| `difference_type` | `OffsetT` | Type to express the result of subtracting one iterator from another. | +| `value_type` | `KeyValuePair< difference_type, OutputValueT >` | The type of the element the iterator can point to. | +| `pointer` | `value_type *` | The type of a pointer to an element the iterator can point to. | +| `reference` | `value_type` | The type of a reference to an element the iterator can point to. | +| `iterator_category` | `THRUST_NS_QUALIFIER::detail::iterator_facade_category_t< THRUST_NS_QUALIFIER::any_system_tag, THRUST_NS_QUALIFIER::random_access_traversal_tag >` | The iterator category. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `itr` | `InputIteratorT` | | +| `offset` | `difference_type` | | diff --git a/fern/cudapages/cub/cub/cub/ArgMax.mdx b/fern/cudapages/cub/cub/cub/ArgMax.mdx new file mode 100644 index 0000000..0aadeef --- /dev/null +++ b/fern/cudapages/cub/cub/cub/ArgMax.mdx @@ -0,0 +1,24 @@ +--- +title: cub::ArgMax +description: "Arg max functor (keeps the value and offset of the first occurrence of the larger item)." +--- + +Arg max functor (keeps the value and offset of the first occurrence of the larger item). + +--- + +## Methods + +### operator() inline const + +Boolean max operator, preferring the item having the smaller offset in case of ties. + + +```cpp showLineNumbers={false} +template +KeyValuePair cub::ArgMax::operator()( + const KeyValuePair &a, + const KeyValuePair &b +) const +``` + diff --git a/fern/cudapages/cub/cub/cub/ArgMin.mdx b/fern/cudapages/cub/cub/cub/ArgMin.mdx new file mode 100644 index 0000000..9b17384 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/ArgMin.mdx @@ -0,0 +1,24 @@ +--- +title: cub::ArgMin +description: "Arg min functor (keeps the value and offset of the first occurrence of the smallest item)." +--- + +Arg min functor (keeps the value and offset of the first occurrence of the smallest item). + +--- + +## Methods + +### operator() inline const + +Boolean min operator, preferring the item having the smaller offset in case of ties. + + +```cpp showLineNumbers={false} +template +KeyValuePair cub::ArgMin::operator()( + const KeyValuePair &a, + const KeyValuePair &b +) const +``` + diff --git a/fern/cudapages/cub/cub/cub/BFEDigitExtractor.mdx b/fern/cudapages/cub/cub/cub/BFEDigitExtractor.mdx new file mode 100644 index 0000000..5d7be9b --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BFEDigitExtractor.mdx @@ -0,0 +1,82 @@ +--- +title: cub::BFEDigitExtractor +description: "A wrapper type to extract digits." +--- + +A wrapper type to extract digits. + +Uses the BFE intrinsic to extract a key from a digit. + + + + + + + + + + +**Inherits from:** `cub::BaseDigitExtractor< KeyT >` (public) + +--- + +## Constructors + +### BFEDigitExtractor inline explicit + + +```cpp showLineNumbers={false} +cub::BFEDigitExtractor::BFEDigitExtractor( + ::cuda::std::uint32_t bit_start = 0, + ::cuda::std::uint32_t num_bits = 0 +) +``` + + +--- + +## Methods + +### Digit inline const + + +```cpp showLineNumbers={false} +::cuda::std::uint32_t cub::BFEDigitExtractor::Digit( + UnsignedBits key +) const +``` + + +--- + +## Static methods + +### ProcessFloatMinusZero inline static + + +```cpp showLineNumbers={false} +static UnsignedBits cub::BaseDigitExtractor::ProcessFloatMinusZero( + UnsignedBits key +) +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `TraitsT` | `Traits< KeyT >` | +| `UnsignedBits` | `typename TraitsT::UnsignedBits` | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `bit_start` | `::cuda::std::uint32_t` | | +| `num_bits` | `::cuda::std::uint32_t` | | diff --git a/fern/cudapages/cub/cub/cub/BaseDigitExtractor.mdx b/fern/cudapages/cub/cub/cub/BaseDigitExtractor.mdx new file mode 100644 index 0000000..22db328 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BaseDigitExtractor.mdx @@ -0,0 +1,49 @@ +--- +title: cub::BaseDigitExtractor +description: "Base struct for digit extractor." +--- + +Base struct for digit extractor. + +Contains common code to provide special handling for floating-point -0.0. + + +This handles correctly both the case when the keys are bitwise-complemented after twiddling for descending sort (in onesweep) as well as when the keys are not bit-negated, but the implementation handles descending sort separately (in other implementations in CUB). Twiddling alone maps -0.0f to 0x7fffffff and +0.0f to 0x80000000 for float, which are subsequent bit patterns and bitwise complements of each other. For onesweep, both -0.0f and +0.0f are mapped to the bit pattern of +0.0f (0x80000000) for ascending sort, and to the pattern of -0.0f (0x7fffffff) for descending sort. For all other sorting implementations in CUB, both are always mapped to +0.0f. Since bit patterns for both -0.0f and +0.0f are next to each other and only one of them is used, the sorting works correctly. For double, the same applies, but with 64-bit patterns. + + + + + + + + + + + + + + +--- + +## Static methods + +### ProcessFloatMinusZero inline static + + +```cpp showLineNumbers={false} +static UnsignedBits cub::BaseDigitExtractor::ProcessFloatMinusZero( + UnsignedBits key +) +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `TraitsT` | `Traits< KeyT >` | +| `UnsignedBits` | `typename TraitsT::UnsignedBits` | diff --git a/fern/cudapages/cub/cub/cub/BaseDigitExtractor_KeyT_true.mdx b/fern/cudapages/cub/cub/cub/BaseDigitExtractor_KeyT_true.mdx new file mode 100644 index 0000000..86098fe --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BaseDigitExtractor_KeyT_true.mdx @@ -0,0 +1,38 @@ +--- +title: "cub::BaseDigitExtractor< KeyT, true >" +description: "" +--- + + + + + + + + + + +--- + +## Static methods + +### ProcessFloatMinusZero inline static + +": "/library/api/cub::BaseDigitExtractor%3C KeyT, true %3E"}}> +```cpp showLineNumbers={false} +static UnsignedBits cub::BaseDigitExtractor::ProcessFloatMinusZero( + UnsignedBits key +) +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `TraitsT` | `Traits< KeyT >` | +| `UnsignedBits` | `typename TraitsT::UnsignedBits` | diff --git a/fern/cudapages/cub/cub/cub/BlockAdjacentDifference.mdx b/fern/cudapages/cub/cub/cub/BlockAdjacentDifference.mdx new file mode 100644 index 0000000..c3586f8 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BlockAdjacentDifference.mdx @@ -0,0 +1,762 @@ +--- +title: cub::BlockAdjacentDifference +description: "" +--- + +BlockAdjacentDifference provides collective methods for computing the differences of adjacent elements partitioned across a CUDA thread block. + +## Example + +The code snippet below illustrates how to use BlockAdjacentDifference to compute the left difference between adjacent elements. + +`{ [4,2,1,1], [1,1,1,1], [2,3,3,3], [3,4,1,4], ... }`. The corresponding output `result` in those threads will be `{ [4,-2,-1,0], [0,0,0,0], [1,1,0,0], [0,1,-3,3], ... }`. + +```cpp showLineNumbers={false} +#include +// or equivalently + +struct CustomDifference +{ + template + __host__ DataType operator()(DataType &lhs, DataType &rhs) + { + return lhs - rhs; + } +}; + +__global__ void ExampleKernel(...) +{ + // Specialize BlockAdjacentDifference for a 1D block of + // 128 threads of type int + using BlockAdjacentDifferenceT = + cub::BlockAdjacentDifference; + + // Allocate shared memory for BlockAdjacentDifference + __shared__ typename BlockAdjacentDifferenceT::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Collectively compute adjacent_difference + int result[4]; + + BlockAdjacentDifferenceT(temp_storage).SubtractLeft(thread_data, result, + CustomDifference()); +} +``` + + + + + + + + + + + + + + + + + + + +--- + +## Collective constructors + +### BlockAdjacentDifference inline + + + + +Collective constructor using a private static allocation of shared memory as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockAdjacentDifference::BlockAdjacentDifference() +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + +Collective constructor using the specified memory allocation as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockAdjacentDifference::BlockAdjacentDifference( + TempStorage &temp_storage +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Parameters** + + +Reference to memory allocation having layout type [TempStorage](/library/api/cub::BlockAdjacentDifference::TempStorage) + + + + + +--- + +## Read left operations + +### SubtractLeft inline + + + + +Subtracts the left element of each adjacent pair of elements partitioned across a CUDA thread block. + + +```cpp showLineNumbers={false} +template +void cub::BlockAdjacentDifference::SubtractLeft( + T (&input)[ITEMS_PER_THREAD], + OutputType (&output)[ITEMS_PER_THREAD], + DifferenceOpT difference_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Calling thread's input items (may be aliased to `output`) + + + +Calling thread's adjacent difference result + + + +Binary difference operator + + +**Example** + +The code snippet below illustrates how to use BlockAdjacentDifference to compute the left difference between adjacent elements. + +`{ [4,2,1,1], [1,1,1,1], [2,3,3,3], [3,4,1,4], ... }`. The corresponding output `result` in those threads will be `{ [4,-2,-1,0], [0,0,0,0], [1,1,0,0], [0,1,-3,3], ... }`. + +```cpp showLineNumbers={false} +#include +// or equivalently + +struct CustomDifference +{ + template + __host__ DataType operator()(DataType &lhs, DataType &rhs) + { + return lhs - rhs; + } +}; + +__global__ void ExampleKernel(...) +{ + // Specialize BlockAdjacentDifference for a 1D block + // of 128 threads of type int + using BlockAdjacentDifferenceT = + cub::BlockAdjacentDifference; + + // Allocate shared memory for BlockAdjacentDifference + __shared__ typename BlockAdjacentDifferenceT::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Collectively compute adjacent_difference + BlockAdjacentDifferenceT(temp_storage).SubtractLeft(thread_data, thread_data, + CustomDifference()); +} +``` + + + + +Subtracts the left element of each adjacent pair of elements partitioned across a CUDA thread block. + + +```cpp showLineNumbers={false} +template +void cub::BlockAdjacentDifference::SubtractLeft( + T (&input)[ITEMS_PER_THREAD], + OutputT (&output)[ITEMS_PER_THREAD], + DifferenceOpT difference_op, + T tile_predecessor_item +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Calling thread's input items (may be aliased to `output`) + + + +Calling thread's adjacent difference result + + + +Binary difference operator + + + +Embed:rst:leading-asterisk +//! *thread*\ :sub:`0` only item which is going to be subtracted from the first tile item +//! (*input*\ :sub:`0` from *thread*\ :sub:`0`). +//! + + +**Example** + +The code snippet below illustrates how to use BlockAdjacentDifference to compute the left difference between adjacent elements. + +`{ [4,2,1,1], [1,1,1,1], [2,3,3,3], [3,4,1,4], ... }`. and that `tile_predecessor_item` is `3`. The corresponding output `result` in those threads will be `{ [1,-2,-1,0], [0,0,0,0], [1,1,0,0], [0,1,-3,3], ... }`. + +```cpp showLineNumbers={false} +#include +// or equivalently + +struct CustomDifference +{ + template + __host__ DataType operator()(DataType &lhs, DataType &rhs) + { + return lhs - rhs; + } +}; + +__global__ void ExampleKernel(...) +{ + // Specialize BlockAdjacentDifference for a 1D block of + // 128 threads of type int + using BlockAdjacentDifferenceT = + cub::BlockAdjacentDifference; + + // Allocate shared memory for BlockAdjacentDifference + __shared__ typename BlockAdjacentDifferenceT::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // The last item in the previous tile: + int tile_predecessor_item = ...; + + // Collectively compute adjacent_difference + BlockAdjacentDifferenceT(temp_storage).SubtractLeft( + thread_data, + thread_data, + CustomDifference(), + tile_predecessor_item); +``` + + + + +### SubtractLeftPartialTile inline + + + + +Subtracts the left element of each adjacent pair of elements partitioned across a CUDA thread block. + + +```cpp showLineNumbers={false} +template +void cub::BlockAdjacentDifference::SubtractLeftPartialTile( + T (&input)[ITEMS_PER_THREAD], + OutputType (&output)[ITEMS_PER_THREAD], + DifferenceOpT difference_op, + int valid_items +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Calling thread's input items (may be aliased to `output`) + + + +Calling thread's adjacent difference result + + + +Binary difference operator + + + +Number of valid items in thread block + + +**Example** + +The code snippet below illustrates how to use BlockAdjacentDifference to compute the left difference between adjacent elements. + +`{ [4,2,1,1], [1,1,1,1], [2,3,3,3], [3,4,1,4], ... }`. The corresponding output `result` in those threads will be `{ [4,-2,-1,0], [0,0,0,0], [1,3,3,3], [3,4,1,4], ... }`. + +```cpp showLineNumbers={false} +#include +// or equivalently + +struct CustomDifference +{ + template + __host__ DataType operator()(DataType &lhs, DataType &rhs) + { + return lhs - rhs; + } +}; + +__global__ void ExampleKernel(...) +{ + // Specialize BlockAdjacentDifference for a 1D block of + // 128 threads of type int + using BlockAdjacentDifferenceT = + cub::BlockAdjacentDifference; + + // Allocate shared memory for BlockAdjacentDifference + __shared__ typename BlockAdjacentDifferenceT::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + int valid_items = 9; + + // Collectively compute adjacent_difference + BlockAdjacentDifferenceT(temp_storage).SubtractLeftPartialTile( + thread_data, + thread_data, + CustomDifference(), + valid_items); +``` + + + + +Subtracts the left element of each adjacent pair of elements partitioned across a CUDA thread block. + + +```cpp showLineNumbers={false} +template +void cub::BlockAdjacentDifference::SubtractLeftPartialTile( + T (&input)[ITEMS_PER_THREAD], + OutputType (&output)[ITEMS_PER_THREAD], + DifferenceOpT difference_op, + int valid_items, + T tile_predecessor_item +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Calling thread's input items (may be aliased to `output`) + + + +Calling thread's adjacent difference result + + + +Binary difference operator + + + +Number of valid items in thread block + + + +Embed:rst:leading-asterisk +//! *thread*\ :sub:`0` only item which is going to be subtracted from the first tile item +//! (*input*\ :sub:`0` from *thread*\ :sub:`0`). +//! + + +**Example** + +The code snippet below illustrates how to use BlockAdjacentDifference to compute the left difference between adjacent elements. + +`{ [4,2,1,1], [1,1,1,1], [2,3,3,3], [3,4,1,4], ... }`. The corresponding output `result` in those threads will be `{ [0,-2,-1,0], [0,0,0,0], [1,3,3,3], [3,4,1,4], ... }`. + +```cpp showLineNumbers={false} +#include +// or equivalently + +struct CustomDifference +{ + template + __host__ DataType operator()(DataType &lhs, DataType &rhs) + { + return lhs - rhs; + } +}; + +__global__ void ExampleKernel(...) +{ + // Specialize BlockAdjacentDifference for a 1D block of + // 128 threads of type int + using BlockAdjacentDifferenceT = + cub::BlockAdjacentDifference; + + // Allocate shared memory for BlockAdjacentDifference + __shared__ typename BlockAdjacentDifferenceT::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + int valid_items = 9; + int tile_predecessor_item = 4; + + // Collectively compute adjacent_difference + BlockAdjacentDifferenceT(temp_storage).SubtractLeftPartialTile( + thread_data, + thread_data, + CustomDifference(), + valid_items, + tile_predecessor_item); +``` + + + + +--- + +## Read right operations + +### SubtractRight inline + + + + + +```cpp showLineNumbers={false} +template +void cub::BlockAdjacentDifference::SubtractRight( + T (&input)[ITEMS_PER_THREAD], + OutputT (&output)[ITEMS_PER_THREAD], + DifferenceOpT difference_op +) +``` + + + + + +Subtracts the right element of each adjacent pair of elements partitioned across a CUDA thread block. + + +```cpp showLineNumbers={false} +template +void cub::BlockAdjacentDifference::SubtractRight( + T (&input)[ITEMS_PER_THREAD], + OutputT (&output)[ITEMS_PER_THREAD], + DifferenceOpT difference_op, + T tile_successor_item +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Calling thread's input items (may be aliased to `output`) + + + +Calling thread's adjacent difference result + + + +Binary difference operator + + + +Embed:rst:leading-asterisk +//! *thread*\ :sub:`BLOCK_THREADS` only item which is going to be subtracted from the last tile item +//! (*input*\ :sub:`ITEMS_PER_THREAD` from *thread*\ :sub:`BLOCK_THREADS`). +//! + + +**Example** + +The code snippet below illustrates how to use BlockAdjacentDifference to compute the right difference between adjacent elements. + +`{ ...3], [4,2,1,1], [1,1,1,1], [2,3,3,3], [3,4,1,4] }`, and that `tile_successor_item` is `3`. The corresponding output `result` in those threads will be `{ ...-1, [2,1,0,0], [0,0,0,-1], [-1,0,0,0], [-1,3,-3,1] }`. + +```cpp showLineNumbers={false} +#include +// or equivalently + +struct CustomDifference +{ + template + __host__ DataType operator()(DataType &lhs, DataType &rhs) + { + return lhs - rhs; + } +}; + +__global__ void ExampleKernel(...) +{ + // Specialize BlockAdjacentDifference for a 1D block of + // 128 threads of type int + using BlockAdjacentDifferenceT = + cub::BlockAdjacentDifference; + + // Allocate shared memory for BlockAdjacentDifference + __shared__ typename BlockAdjacentDifferenceT::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // The first item in the next tile: + int tile_successor_item = ...; + + // Collectively compute adjacent_difference + BlockAdjacentDifferenceT(temp_storage).SubtractRight( + thread_data, + thread_data, + CustomDifference(), + tile_successor_item); +``` + + + + +### SubtractRightPartialTile inline + +Subtracts the right element of each adjacent pair in range of elements partitioned across a CUDA thread block. + + +```cpp showLineNumbers={false} +template +void cub::BlockAdjacentDifference::SubtractRightPartialTile( + T (&input)[ITEMS_PER_THREAD], + OutputT (&output)[ITEMS_PER_THREAD], + DifferenceOpT difference_op, + int valid_items +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Calling thread's input items (may be aliased to `output`) + + + +Calling thread's adjacent difference result + + + +Binary difference operator + + + +Number of valid items in thread block + + +**Example** + +The code snippet below illustrates how to use BlockAdjacentDifference to compute the right difference between adjacent elements. + +`{ ...3], [4,2,1,1], [1,1,1,1], [2,3,3,3], [3,4,1,4] }`. and that `valid_items` is `507`. The corresponding output `result` in those threads will be `{ ...-1, [2,1,0,0], [0,0,0,-1], [-1,0,3,3], [3,4,1,4] }`. + +```cpp showLineNumbers={false} +#include +// or equivalently + +struct CustomDifference +{ + template + __host__ DataType operator()(DataType &lhs, DataType &rhs) + { + return lhs - rhs; + } +}; + +__global__ void ExampleKernel(...) +{ + // Specialize BlockAdjacentDifference for a 1D block of + // 128 threads of type int + using BlockAdjacentDifferenceT = + cub::BlockAdjacentDifference; + + // Allocate shared memory for BlockAdjacentDifference + __shared__ typename BlockAdjacentDifferenceT::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Collectively compute adjacent_difference + BlockAdjacentDifferenceT(temp_storage).SubtractRightPartialTile( + thread_data, + thread_data, + CustomDifference(), + valid_items); +``` + +--- + +## Utility methods + +### PrivateStorage inline + +Internal storage allocator. + + +```cpp showLineNumbers={false} +_TempStorage & cub::BlockAdjacentDifference::PrivateStorage() +``` + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | The thread block size in threads. | +| `temp_storage` | `_TempStorage &` | Shared storage reference. | +| `linear_tid` | `unsigned int` | Linear thread-id. | + +--- + +## Inner classes + +### _TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockAdjacentDifference::_TempStorage +``` + + +Shared memory storage layout type (last element from each thread's input). + +| Name | Type | Description | +|---|---|---| +| `first_items` | `T` | | +| `last_items` | `T` | | + +### ApplyOp + + +```cpp showLineNumbers={false} +struct cub::BlockAdjacentDifference::ApplyOp +``` + + +Specialization for when FlagOp has third index param. + +### ApplyOp< FlagOp, false > + + +```cpp showLineNumbers={false} +struct cub::BlockAdjacentDifference::ApplyOp< FlagOp, false > +``` + + +Specialization for when FlagOp does not have a third index param. + +### Iterate + + +```cpp showLineNumbers={false} +struct cub::BlockAdjacentDifference::Iterate +``` + + +Templated unrolling of item comparison (inductive case). + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockAdjacentDifference::TempStorage +``` + + +The operations exposed by `BlockAdjacentDifference` require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/cudapages/cub/cub/cub/BlockDiscontinuity.mdx b/fern/cudapages/cub/cub/cub/BlockDiscontinuity.mdx new file mode 100644 index 0000000..f34d82a --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BlockDiscontinuity.mdx @@ -0,0 +1,1003 @@ +--- +title: cub::BlockDiscontinuity +description: "" +--- + +The BlockDiscontinuity class provides collective methods for flagging discontinuities within an ordered set of items partitioned across a CUDA thread block. + +## Performance considerations + +- Performance is sensitive to the degree of data movement across the block. +- Incurs zero bank conflicts for most types + +## Example + +The code snippet below illustrates the head flagging of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +`{ [0,0,1,1], [1,1,1,1], [2,3,3,3], [3,4,4,4], ... }`. The corresponding output `head_flags` in those threads will be `{ [1,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockDiscontinuity for a 1D block of 128 threads of type int + using BlockDiscontinuity = cub::BlockDiscontinuity; + + // Allocate shared memory for BlockDiscontinuity + __shared__ typename BlockDiscontinuity::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Collectively compute head flags for discontinuities in the segment + int head_flags[4]; + BlockDiscontinuity(temp_storage).FlagHeads(head_flags, thread_data, cub::Inequality()); +} +``` + + + + + +The data type to be flagged. + + + +The thread block length in threads along the X dimension + + + +**[optional]** The thread block length in threads along the Y dimension (default: 1) + + + +**[optional]** The thread block length in threads along the Z dimension (default: 1) + + + + + +--- + +## Collective constructors + +### BlockDiscontinuity inline + + + + +Collective constructor using a private static allocation of shared memory as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockDiscontinuity::BlockDiscontinuity() +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + +Collective constructor using the specified memory allocation as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockDiscontinuity::BlockDiscontinuity( + TempStorage &temp_storage +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Parameters** + + +Reference to memory allocation having layout type [TempStorage](/library/api/cub::BlockDiscontinuity::TempStorage) + + + + + +--- + +## Head flag operations + +### FlagHeads inline + + + + +Sets head flags indicating discontinuities between items partitioned across the thread block, for which the first item has no reference and is always flagged. + + +```cpp showLineNumbers={false} +template +void cub::BlockDiscontinuity::FlagHeads( + FlagT (&head_flags)[ITEMS_PER_THREAD], + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The flag `head_flags[i]` is set for item `input[i]` when `flag_op(previous-item, input[i])` returns `true` (where `previous-item` is either the preceding item in the same thread or the last item in the previous thread). +For *thread*0, item `input[0]` is always flagged. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread + + + +**[inferred]** The flag type (must be an integer type) + + + +**[inferred]** Binary predicate functor type having member `T operator()(const T &a, const T &b)` or member `T operator()(const T &a, const T &b, unsigned int b_index)`, and returning `true` if a discontinuity exists between `a` and `b`, otherwise `false`. `b_index` is the rank of b in the aggregate tile of data. + + +**Parameters** + + +Calling thread's discontinuity head_flags + + + +Calling thread's input items + + + +Binary boolean flag predicate + + +**Example** + +The code snippet below illustrates the head-flagging of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +`{ [0,0,1,1], [1,1,1,1], [2,3,3,3], [3,4,4,4], ... }`. The corresponding output `head_flags` in those threads will be `{ [1,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockDiscontinuity for a 1D block of 128 threads of type int + using BlockDiscontinuity = cub::BlockDiscontinuity; + + // Allocate shared memory for BlockDiscontinuity + __shared__ typename BlockDiscontinuity::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Collectively compute head flags for discontinuities in the segment + int head_flags[4]; + BlockDiscontinuity(temp_storage).FlagHeads(head_flags, thread_data, cub::Inequality()); +} +``` + + + + +Sets head flags indicating discontinuities between items partitioned across the thread block. + + +```cpp showLineNumbers={false} +template +void cub::BlockDiscontinuity::FlagHeads( + FlagT (&head_flags)[ITEMS_PER_THREAD], + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op, + T tile_predecessor_item +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The flag `head_flags[i]` is set for item `input[i]` when `flag_op(previous-item, input[i])` returns `true` (where `previous-item` is either the preceding item in the same thread or the last item in the previous thread). +For *thread*0, item `input[0]` is compared against `tile_predecessor_item`. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** The flag type (must be an integer type) + + + +**[inferred]** Binary predicate functor type having member `T operator()(const T &a, const T &b)` or member `T operator()(const T &a, const T &b, unsigned int b_index)`, and returning `true` if a discontinuity exists between `a` and `b`, otherwise `false`. `b_index` is the rank of b in the aggregate tile of data. + + +**Parameters** + + +Calling thread's discontinuity `head_flags` + + + +Calling thread's input items + + + +Binary boolean flag predicate + + + +Embed:rst:leading-asterisk +//! *thread*\ :sub:`0` only item with which to compare the first tile item (``input[0]`` from *thread*\ :sub:`0`). +//! + + +**Example** + +The code snippet below illustrates the head-flagging of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +`{ [0,0,1,1], [1,1,1,1], [2,3,3,3], [3,4,4,4], ... }`, and that `tile_predecessor_item` is `0`. The corresponding output `head_flags` in those threads will be `{ [0,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockDiscontinuity for a 1D block of 128 threads of type int + using BlockDiscontinuity = cub::BlockDiscontinuity; + + // Allocate shared memory for BlockDiscontinuity + __shared__ typename BlockDiscontinuity::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Have thread0 obtain the predecessor item for the entire tile + int tile_predecessor_item; + if (threadIdx.x == 0) tile_predecessor_item == ... + + // Collectively compute head flags for discontinuities in the segment + int head_flags[4]; + BlockDiscontinuity(temp_storage).FlagHeads(head_flags, thread_data, + cub::Inequality(), tile_predecessor_item); +} +``` + + + + +--- + +## Tail flag operations + +### FlagTails inline + + + + +Sets tail flags indicating discontinuities between items partitioned across the thread block, for which the last item has no reference and is always flagged. + + +```cpp showLineNumbers={false} +template +void cub::BlockDiscontinuity::FlagTails( + FlagT (&tail_flags)[ITEMS_PER_THREAD], + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The flag `tail_flags[i]` is set for item `input[i]` when `flag_op(input[i], next-item)` returns `true` (where `next-item` is either the next item in the same thread or the first item in the next thread). +For *thread*BLOCK_THREADS - 1, item `input[ITEMS_PER_THREAD - 1]` is always flagged. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** The flag type (must be an integer type) + + + +**[inferred]** Binary predicate functor type having member `T operator()(const T &a, const T &b)` or member `T operator()(const T &a, const T &b, unsigned int b_index)`, and returning `true` if a discontinuity exists between `a` and `b`, otherwise `false`. `b_index` is the rank of `b` in the aggregate tile of data. + + +**Parameters** + + +Calling thread's discontinuity tail_flags + + + +Calling thread's input items + + + +Binary boolean flag predicate + + +**Example** + +The code snippet below illustrates the tail-flagging of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +`{ [0,0,1,1], [1,1,1,1], [2,3,3,3], ..., [124,125,125,125] }`. The corresponding output `tail_flags` in those threads will be `{ [0,1,0,0], [0,0,0,1], [1,0,0,...], ..., [1,0,0,1] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockDiscontinuity for a 1D block of 128 threads of type int + using BlockDiscontinuity = cub::BlockDiscontinuity; + + // Allocate shared memory for BlockDiscontinuity + __shared__ typename BlockDiscontinuity::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Collectively compute tail flags for discontinuities in the segment + int tail_flags[4]; + BlockDiscontinuity(temp_storage).FlagTails(tail_flags, thread_data, cub::Inequality()); +} +``` + + + + +Sets tail flags indicating discontinuities between items partitioned across the thread block. + + +```cpp showLineNumbers={false} +template +void cub::BlockDiscontinuity::FlagTails( + FlagT (&tail_flags)[ITEMS_PER_THREAD], + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op, + T tile_successor_item +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The flag `tail_flags[i]` is set for item `input[i]` when `flag_op(input[i], next-item)` returns `true` (where `next-item` is either the next item in the same thread or the first item in the next thread). +For *thread*BLOCK_THREADS - 1, item `input[ITEMS_PER_THREAD - 1]` is compared against `tile_successor_item`. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** The flag type (must be an integer type) + + + +**[inferred]** Binary predicate functor type having member `T operator()(const T &a, const T &b)` or member `T operator()(const T &a, const T &b, unsigned int b_index)`, and returning `true` if a discontinuity exists between `a` and `b`, otherwise `false`. `b_index` is the rank of `b` in the aggregate tile of data. + + +**Parameters** + + +Calling thread's discontinuity tail_flags + + + +Calling thread's input items + + + +Binary boolean flag predicate + + + +Embed:rst:leading-asterisk +//! *thread*\ :sub:`BLOCK_THREADS - 1` only item with which to +//! compare the last tile item (``input[ITEMS_PER_THREAD - 1]`` from +//! *thread*\ :sub:`BLOCK_THREADS - 1`). +//! + + +**Example** + +The code snippet below illustrates the tail-flagging of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +`{ [0,0,1,1], [1,1,1,1], [2,3,3,3], ..., [124,125,125,125] }` and that `tile_successor_item` is `125`. The corresponding output `tail_flags` in those threads will be `{ [0,1,0,0], [0,0,0,1], [1,0,0,...], ..., [1,0,0,0] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockDiscontinuity for a 1D block of 128 threads of type int + using BlockDiscontinuity = cub::BlockDiscontinuity; + + // Allocate shared memory for BlockDiscontinuity + __shared__ typename BlockDiscontinuity::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Have thread127 obtain the successor item for the entire tile + int tile_successor_item; + if (threadIdx.x == 127) tile_successor_item == ... + + // Collectively compute tail flags for discontinuities in the segment + int tail_flags[4]; + BlockDiscontinuity(temp_storage).FlagTails(tail_flags, thread_data, + cub::Inequality(), tile_successor_item); +} +``` + + + + +--- + +## Head & tail flag operations + +### FlagHeadsAndTails inline + + + + +Sets both head and tail flags indicating discontinuities between items partitioned across the thread block. + + +```cpp showLineNumbers={false} +template +void cub::BlockDiscontinuity::FlagHeadsAndTails( + FlagT (&head_flags)[ITEMS_PER_THREAD], + FlagT (&tail_flags)[ITEMS_PER_THREAD], + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The flag `head_flags[i]` is set for item `input[i]` when `flag_op(previous-item, input[i])` returns `true` (where `previous-item` is either the preceding item in the same thread or the last item in the previous thread). +For *thread*0, item `input[0]` is always flagged. +The flag `tail_flags[i]` is set for item `input[i]` when `flag_op(input[i], next-item)` returns `true` (where next-item is either the next item in the same thread or the first item in the next thread). +For *thread*BLOCK_THREADS - 1, item `input[ITEMS_PER_THREAD - 1]` is always flagged. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** The flag type (must be an integer type) + + + +**[inferred]** Binary predicate functor type having member `T operator()(const T &a, const T &b)` or member `T operator()(const T &a, const T &b, unsigned int b_index)`, and returning `true` if a discontinuity exists between `a` and `b`, otherwise `false`. `b_index` is the rank of `b` in the aggregate tile of data. + + +**Parameters** + + +Calling thread's discontinuity head_flags + + + +Calling thread's discontinuity tail_flags + + + +Calling thread's input items + + + +Binary boolean flag predicate + + +**Example** + +The code snippet below illustrates the head- and tail-flagging of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +`{ [0,0,1,1], [1,1,1,1], [2,3,3,3], ..., [124,125,125,125] }` and that the tile_successor_item is `125`. The corresponding output `head_flags` in those threads will be `{ [1,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }`. and the corresponding output `tail_flags` in those threads will be `{ [0,1,0,0], [0,0,0,1], [1,0,0,...], ..., [1,0,0,1] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockDiscontinuity for a 1D block of 128 threads of type int + using BlockDiscontinuity = cub::BlockDiscontinuity; + + // Allocate shared memory for BlockDiscontinuity + __shared__ typename BlockDiscontinuity::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Collectively compute head and flags for discontinuities in the segment + int head_flags[4]; + int tail_flags[4]; + BlockDiscontinuity(temp_storage).FlagHeadsAndTails(head_flags, tail_flags, thread_data, + cub::Inequality()); +} +``` + + + + +Sets both head and tail flags indicating discontinuities between items partitioned across the thread block. + + +```cpp showLineNumbers={false} +template +void cub::BlockDiscontinuity::FlagHeadsAndTails( + FlagT (&head_flags)[ITEMS_PER_THREAD], + FlagT (&tail_flags)[ITEMS_PER_THREAD], + T tile_successor_item, + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The flag `head_flags[i]` is set for item `input[i]` when `flag_op(previous-item, input[i])` returns `true` (where `previous-item` is either the preceding item in the same thread or the last item in the previous thread). +For *thread*0, item `input[0]` is always flagged. +The flag `tail_flags[i]` is set for item `input[i]` when `flag_op(input[i], next-item)` returns `true` (where `next-item` is either the next item in the same thread or the first item in the next thread). +For *thread*BLOCK_THREADS - 1, item `input[ITEMS_PER_THREAD - 1]` is compared against `tile_predecessor_item`. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** The flag type (must be an integer type) + + + +**[inferred]** Binary predicate functor type having member `T operator()(const T &a, const T &b)` or member `T operator()(const T &a, const T &b, unsigned int b_index)`, and returning `true` if a discontinuity exists between `a` and `b`, otherwise `false`. `b_index` is the rank of b in the aggregate tile of data. + + +**Parameters** + + +Calling thread's discontinuity head_flags + + + +Calling thread's discontinuity tail_flags + + + +Embed:rst:leading-asterisk +//! *thread*\ :sub:`BLOCK_THREADS - 1` only item with which to compare +//! the last tile item (``input[ITEMS_PER_THREAD - 1]`` from +//! *thread*\ :sub:`BLOCK_THREADS - 1`). +//! + + + +Calling thread's input items + + + +Binary boolean flag predicate + + +**Example** + +The code snippet below illustrates the head- and tail-flagging of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +`{ [0,0,1,1], [1,1,1,1], [2,3,3,3], ..., [124,125,125,125] }` and that the tile_successor_item is `125`. The corresponding output `head_flags` in those threads will be `{ [1,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }`. and the corresponding output `tail_flags` in those threads will be `{ [0,1,0,0], [0,0,0,1], [1,0,0,...], ..., [1,0,0,0] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockDiscontinuity for a 1D block of 128 threads of type int + using BlockDiscontinuity = cub::BlockDiscontinuity; + + // Allocate shared memory for BlockDiscontinuity + __shared__ typename BlockDiscontinuity::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Have thread127 obtain the successor item for the entire tile + int tile_successor_item; + if (threadIdx.x == 127) tile_successor_item == ... + + // Collectively compute head and flags for discontinuities in the segment + int head_flags[4]; + int tail_flags[4]; + BlockDiscontinuity(temp_storage).FlagHeadsAndTails(head_flags, tail_flags, + tile_successor_item, thread_data, + cub::Inequality()); +} +``` + + + + +Sets both head and tail flags indicating discontinuities between items partitioned across the thread block. + + +```cpp showLineNumbers={false} +template +void cub::BlockDiscontinuity::FlagHeadsAndTails( + FlagT (&head_flags)[ITEMS_PER_THREAD], + T tile_predecessor_item, + FlagT (&tail_flags)[ITEMS_PER_THREAD], + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The flag `head_flags[i]` is set for item `input[i]` when `flag_op(previous-item, input[i])` returns `true` (where `previous-item` is either the preceding item in the same thread or the last item in the previous thread). +For *thread*0, item `input[0]` is compared against `tile_predecessor_item`. +The flag `tail_flags[i]` is set for item `input[i]` when `flag_op(input[i], next-item)` returns `true` (where `next-item` is either the next item in the same thread or the first item in the next thread). +For *thread*BLOCK_THREADS - 1, item `input[ITEMS_PER_THREAD - 1]` is always flagged. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** The flag type (must be an integer type) + + + +**[inferred]** Binary predicate functor type having member `T operator()(const T &a, const T &b)` or member `T operator()(const T &a, const T &b, unsigned int b_index)`, and returning `true` if a discontinuity exists between `a` and `b`, otherwise `false`. `b_index` is the rank of b in the aggregate tile of data. + + +**Parameters** + + +Calling thread's discontinuity head_flags + + + +Embed:rst:leading-asterisk +//! *thread*\ :sub:`0` only item with which to compare the first tile item (``input[0]`` from *thread*\ :sub:`0`). +//! + + + +Calling thread's discontinuity tail_flags + + + +Calling thread's input items + + + +Binary boolean flag predicate + + +**Example** + +The code snippet below illustrates the head- and tail-flagging of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +`{ [0,0,1,1], [1,1,1,1], [2,3,3,3], ..., [124,125,125,125] }`, that the `tile_predecessor_item` is `0`, and that the `tile_successor_item` is `125`. The corresponding output `head_flags` in those threads will be `{ [0,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }`, and the corresponding output `tail_flags` in those threads will be `{ [0,1,0,0], [0,0,0,1], [1,0,0,...], ..., [1,0,0,1] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockDiscontinuity for a 1D block of 128 threads of type int + using BlockDiscontinuity = cub::BlockDiscontinuity; + + // Allocate shared memory for BlockDiscontinuity + __shared__ typename BlockDiscontinuity::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Have thread0 obtain the predecessor item for the entire tile + int tile_predecessor_item; + if (threadIdx.x == 0) tile_predecessor_item == ... + + // Have thread127 obtain the successor item for the entire tile + int tile_successor_item; + if (threadIdx.x == 127) tile_successor_item == ... + + // Collectively compute head and flags for discontinuities in the segment + int head_flags[4]; + int tail_flags[4]; + BlockDiscontinuity(temp_storage).FlagHeadsAndTails(head_flags, tile_predecessor_item, + tail_flags, tile_successor_item, + thread_data, cub::Inequality()); +} +``` + + + + +Sets both head and tail flags indicating discontinuities between items partitioned across the thread block. + + +```cpp showLineNumbers={false} +template +void cub::BlockDiscontinuity::FlagHeadsAndTails( + FlagT (&head_flags)[ITEMS_PER_THREAD], + T tile_predecessor_item, + FlagT (&tail_flags)[ITEMS_PER_THREAD], + T tile_successor_item, + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The flag `head_flags[i]` is set for item `input[i]` when `flag_op(previous-item, input[i])` returns `true` (where `previous-item` is either the preceding item in the same thread or the last item in the previous thread). +For *thread*0, item `input[0]` is compared against `tile_predecessor_item`. +The flag `tail_flags[i]` is set for item `input[i]` when `flag_op(input[i], next-item)` returns `true` (where `next-item` is either the next item in the same thread or the first item in the next thread). +For *thread*BLOCK_THREADS - 1, item `input[ITEMS_PER_THREAD - 1]` is compared against `tile_successor_item`. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** The flag type (must be an integer type) + + + +**[inferred]** Binary predicate functor type having member `T operator()(const T &a, const T &b)` or member `T operator()(const T &a, const T &b, unsigned int b_index)`, and returning `true` if a discontinuity exists between `a` and `b`, otherwise `false`. `b_index` is the rank of `b` in the aggregate tile of data. + + +**Parameters** + + +Calling thread's discontinuity head_flags + + + +Embed:rst:leading-asterisk +//! *thread*\ :sub:`0` only item with which to compare the first tile item (``input[0]`` from *thread*\ :sub:`0`). +//! + + + +Calling thread's discontinuity tail_flags + + + +Embed:rst:leading-asterisk +//! *thread*\ :sub:`BLOCK_THREADS - 1` only item with which to compare the last tile item +//! (``input[ITEMS_PER_THREAD - 1]`` from *thread*\ :sub:`BLOCK_THREADS - 1`). +//! + + + +Calling thread's input items + + + +Binary boolean flag predicate + + +**Example** + +The code snippet below illustrates the head- and tail-flagging of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +`{ [0,0,1,1], [1,1,1,1], [2,3,3,3], ..., [124,125,125,125] }`, that the `tile_predecessor_item` is `0`, and that the `tile_successor_item` is `125`. The corresponding output `head_flags` in those threads will be `{ [0,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }`. and the corresponding output `tail_flags` in those threads will be `{ [0,1,0,0], [0,0,0,1], [1,0,0,...], ..., [1,0,0,0] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockDiscontinuity for a 1D block of 128 threads of type int + using BlockDiscontinuity = cub::BlockDiscontinuity; + + // Allocate shared memory for BlockDiscontinuity + __shared__ typename BlockDiscontinuity::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Have thread0 obtain the predecessor item for the entire tile + int tile_predecessor_item; + if (threadIdx.x == 0) tile_predecessor_item == ... + + // Have thread127 obtain the successor item for the entire tile + int tile_successor_item; + if (threadIdx.x == 127) tile_successor_item == ... + + // Collectively compute head and flags for discontinuities in the segment + int head_flags[4]; + int tail_flags[4]; + BlockDiscontinuity(temp_storage).FlagHeadsAndTails(head_flags, tile_predecessor_item, + tail_flags, tile_successor_item, + thread_data, cub::Inequality()); +} +``` + + + + +--- + +## Utility methods + +### PrivateStorage inline + +Internal storage allocator. + + +```cpp showLineNumbers={false} +_TempStorage & cub::BlockDiscontinuity::PrivateStorage() +``` + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | The thread block size in threads. | +| `temp_storage` | `_TempStorage &` | Shared storage reference. | +| `linear_tid` | `unsigned int` | Linear thread-id. | + +--- + +## Inner classes + +### _TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockDiscontinuity::_TempStorage +``` + + +Shared memory storage layout type (last element from each thread's input). + +| Name | Type | Description | +|---|---|---| +| `first_items` | `T` | | +| `last_items` | `T` | | + +### ApplyOp + + +```cpp showLineNumbers={false} +struct cub::BlockDiscontinuity::ApplyOp +``` + + +Specialization for when FlagOp has third index param. + +### ApplyOp< FlagOp, false > + + +```cpp showLineNumbers={false} +struct cub::BlockDiscontinuity::ApplyOp< FlagOp, false > +``` + + +Specialization for when FlagOp does not have a third index param. + +### Iterate + + +```cpp showLineNumbers={false} +struct cub::BlockDiscontinuity::Iterate +``` + + +Templated unrolling of item comparison (inductive case). + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockDiscontinuity::TempStorage +``` + + +The operations exposed by `BlockDiscontinuity` require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/cudapages/cub/cub/cub/BlockExchange.mdx b/fern/cudapages/cub/cub/cub/BlockExchange.mdx new file mode 100644 index 0000000..f213ba7 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BlockExchange.mdx @@ -0,0 +1,972 @@ +--- +title: cub::BlockExchange +description: "" +--- + +The BlockExchange class provides collective methods for rearranging data partitioned across a CUDA thread block. + +## Performance considerations + +- Proper device-specific padding ensures zero bank conflicts for most types. + +## Example + +The code snippet below illustrates the conversion from a "blocked" to a "striped" arrangement of 512 integer items partitioned across 128 threads where each thread owns 4 items. + +[1,129,257,385], ..., [127,255,383,511] }`. The corresponding output `thread_data`` in those threads will be `{ [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int *d_data, ...) +{ + // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each + using BlockExchange = cub::BlockExchange; + + // Allocate shared memory for BlockExchange + __shared__ typename BlockExchange::TempStorage temp_storage; + + // Load a tile of data striped across threads + int thread_data[4]; + cub::LoadDirectStriped<128>(threadIdx.x, d_data, thread_data); + + // Collectively exchange data into a blocked arrangement across threads + BlockExchange(temp_storage).StripedToBlocked(thread_data); +} +``` + + + + + +The data type to be exchanged + + + +The thread block length in threads along the X dimension + + + +The number of items partitioned onto each thread. + + + +**[optional]** When `true`, only use enough shared memory for a single warp's worth of tile data, time-slicing the block-wide exchange over multiple synchronized rounds. Yields a smaller memory footprint at the expense of decreased parallelism. (Default: false) + + + +**[optional]** The thread block length in threads along the Y dimension (default: 1) + + + +**[optional]** The thread block length in threads along the Z dimension (default: 1) + + + + + +--- + +## Collective constructors + +### BlockExchange inline + + + + +Collective constructor using a private static allocation of shared memory as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockExchange::BlockExchange() +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + +Collective constructor using the specified memory allocation as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockExchange::BlockExchange( + TempStorage &temp_storage +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Parameters** + + +Reference to memory allocation having layout type [TempStorage](/library/api/cub::_block_exchange::TempStorage) + + + + + +--- + +## Structured exchanges + +### StripedToBlocked inline + +Transposes data items from **striped** arrangement to **blocked** arrangement. + + +```cpp showLineNumbers={false} +template +void cub::BlockExchange::StripedToBlocked( + const T (&input_items)[ItemsPerThread], + OutputT (&output_items)[ItemsPerThread] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Items to exchange, converting between **striped** and **blocked** arrangements. + + + +Items from exchange, converting between **striped** and **blocked** arrangements. + + +**Example** + +The code snippet below illustrates the conversion from a "striped" to a "blocked" arrangement of 512 integer items partitioned across 128 threads where each thread owns 4 items. + +[1,129,257,385], ..., [127,255,383,511] }`` after loading from device-accessible memory. The corresponding output `thread_data` in those threads will be `{ [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int *d_data, ...) +{ + // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each + using BlockExchange = cub::BlockExchange; + + // Allocate shared memory for BlockExchange + __shared__ typename BlockExchange::TempStorage temp_storage; + + // Load a tile of ordered data into a striped arrangement across block threads + int thread_data[4]; + cub::LoadDirectStriped<128>(threadIdx.x, d_data, thread_data); + + // Collectively exchange data into a blocked arrangement across threads + BlockExchange(temp_storage).StripedToBlocked(thread_data, thread_data); +} +``` + +### BlockedToStriped inline + +Transposes data items from **blocked** arrangement to **striped** arrangement. + + +```cpp showLineNumbers={false} +template +void cub::BlockExchange::BlockedToStriped( + const T (&input_items)[ItemsPerThread], + OutputT (&output_items)[ItemsPerThread] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Items to exchange, converting between **striped** and **blocked** arrangements. + + + +Items from exchange, converting between **striped** and **blocked** arrangements. + + +**Example** + +The code snippet below illustrates the conversion from a "blocked" to a "striped" arrangement of 512 integer items partitioned across 128 threads where each thread owns 4 items. + +[8,9,10,11], ..., [508,509,510,511] }`. The corresponding output `thread_data`` in those threads will be `{ [0,128,256,384], [1,129,257,385], ..., [127,255,383,511] }` in preparation for storing to device-accessible memory. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int *d_data, ...) +{ + // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each + using BlockExchange = cub::BlockExchange; + + // Allocate shared memory for BlockExchange + __shared__ typename BlockExchange::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Collectively exchange data into a striped arrangement across threads + BlockExchange(temp_storage).BlockedToStriped(thread_data, thread_data); + + // Store data striped across block threads into an ordered tile + cub::StoreDirectStriped(threadIdx.x, d_data, thread_data); +} +``` + +### WarpStripedToBlocked inline + +Transposes data items from **warp-striped** arrangement to **blocked** arrangement. + + +```cpp showLineNumbers={false} +template +void cub::BlockExchange::WarpStripedToBlocked( + const T (&input_items)[ItemsPerThread], + OutputT (&output_items)[ItemsPerThread] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Items to exchange, converting between **striped** and **blocked** arrangements. + + + +Items from exchange, converting between **striped** and **blocked** arrangements. + + +**Example** + +The code snippet below illustrates the conversion from a "warp-striped" to a "blocked" arrangement of 512 integer items partitioned across 128 threads where each thread owns 4 items. + +[1,33,65,97], [2,34,66,98], ..., [415,447,479,511] }`` after loading from device-accessible memory. (The first 128 items are striped across the first warp of 32 threads, the second 128 items are striped across the second warp, etc.) The corresponding output `thread_data` in those threads will be `{ [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int *d_data, ...) +{ + // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each + using BlockExchange = cub::BlockExchange; + + // Allocate shared memory for BlockExchange + __shared__ typename BlockExchange::TempStorage temp_storage; + + // Load a tile of ordered data into a warp-striped arrangement across warp threads + int thread_data[4]; + cub::LoadSWarptriped(threadIdx.x, d_data, thread_data); + + // Collectively exchange data into a blocked arrangement across threads + BlockExchange(temp_storage).WarpStripedToBlocked(thread_data); +} +``` + +### BlockedToWarpStriped inline + +Transposes data items from **blocked** arrangement to **warp-striped** arrangement. + + +```cpp showLineNumbers={false} +template +void cub::BlockExchange::BlockedToWarpStriped( + const T (&input_items)[ItemsPerThread], + OutputT (&output_items)[ItemsPerThread] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Items to exchange, converting between **striped** and **blocked** arrangements. + + + +Items from exchange, converting between **striped** and **blocked** arrangements. + + +**Example** + +The code snippet below illustrates the conversion from a "blocked" to a "warp-striped" arrangement of 512 integer items partitioned across 128 threads where each thread owns 4 items. + +[8,9,10,11], ..., [508,509,510,511] }`. The corresponding output `thread_data`` in those threads will be `{ [0,32,64,96], [1,33,65,97], [2,34,66,98], ..., [415,447,479,511] }` in preparation for storing to device-accessible memory. (The first 128 items are striped across the first warp of 32 threads, the second 128 items are striped across the second warp, etc.) + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int *d_data, ...) +{ + // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each + using BlockExchange = cub::BlockExchange; + + // Allocate shared memory for BlockExchange + __shared__ typename BlockExchange::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Collectively exchange data into a warp-striped arrangement across threads + BlockExchange(temp_storage).BlockedToWarpStriped(thread_data, thread_data); + + // Store data striped across warp threads into an ordered tile + cub::StoreDirectStriped(threadIdx.x, d_data, thread_data); +} +``` + +--- + +## Scatter exchanges + +### ScatterToBlocked inline + +Exchanges data items annotated by rank into **blocked** arrangement. + + +```cpp showLineNumbers={false} +template +void cub::BlockExchange::ScatterToBlocked( + const T (&input_items)[ItemsPerThread], + OutputT (&output_items)[ItemsPerThread], + OffsetT (&ranks)[ItemsPerThread] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Signed integer type for local offsets + + +**Parameters** + + +Items to exchange, converting between **striped** and **blocked** arrangements. + + + +Items from exchange, converting between **striped** and **blocked** arrangements. + + + +Corresponding scatter ranks + + +### ScatterToStriped inline + +Exchanges data items annotated by rank into **striped** arrangement. + + +```cpp showLineNumbers={false} +template +void cub::BlockExchange::ScatterToStriped( + const T (&input_items)[ItemsPerThread], + OutputT (&output_items)[ItemsPerThread], + OffsetT (&ranks)[ItemsPerThread] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Signed integer type for local offsets + + +**Parameters** + + +Items to exchange, converting between **striped** and **blocked** arrangements. + + + +Items from exchange, converting between **striped** and **blocked** arrangements. + + + +Corresponding scatter ranks + + +### ScatterToStripedGuarded inline + +Exchanges data items annotated by rank into **striped** arrangement. Items with rank -1 are not exchanged. + + +```cpp showLineNumbers={false} +template +void cub::BlockExchange::ScatterToStripedGuarded( + const T (&input_items)[ItemsPerThread], + OutputT (&output_items)[ItemsPerThread], + OffsetT (&ranks)[ItemsPerThread] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Signed integer type for local offsets + + +**Parameters** + + +Items to exchange, converting between **striped** and **blocked** arrangements. + + + +Items from exchange, converting between **striped** and **blocked** arrangements. + + + +Corresponding scatter ranks + + +### ScatterToStripedFlagged inline + +Exchanges valid data items annotated by rank into **striped** arrangement. + + +```cpp showLineNumbers={false} +template +void cub::BlockExchange::ScatterToStripedFlagged( + const T (&input_items)[ItemsPerThread], + OutputT (&output_items)[ItemsPerThread], + OffsetT (&ranks)[ItemsPerThread], + ValidFlag (&is_valid)[ItemsPerThread] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Signed integer type for local offsets + + + +**[inferred]** FlagT type denoting which items are valid + + +**Parameters** + + +Items to exchange, converting between **striped** and **blocked** arrangements. + + + +Items from exchange, converting between **striped** and **blocked** arrangements. + + + +Corresponding scatter ranks + + + +Corresponding flag denoting item validity + + +--- + +## Utility methods + +### PrivateStorage inline + +Internal storage allocator. + + +```cpp showLineNumbers={false} +_TempStorage & cub::BlockExchange::PrivateStorage() +``` + + +### BlockedToStriped inline + + + + +Transposes data items from **blocked** arrangement to **striped** arrangement. + +Specialized for no timeslicing. + + +```cpp showLineNumbers={false} +template +void cub::BlockExchange::BlockedToStriped( + const T (&input_items)[ItemsPerThread], + OutputT (&output_items)[ItemsPerThread], + ::cuda::std::false_type +) +``` + + +**Parameters** + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + + + +Transposes data items from **blocked** arrangement to **striped** arrangement. + +Specialized for warp-timeslicing. + + +```cpp showLineNumbers={false} +template +void cub::BlockExchange::BlockedToStriped( + const T (&input_items)[ItemsPerThread], + OutputT (&output_items)[ItemsPerThread], + ::cuda::std::true_type +) +``` + + +**Parameters** + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + + + +### BlockedToWarpStriped inline + + + + +Transposes data items from **blocked** arrangement to **warp-striped** arrangement. + +Specialized for no timeslicing + + +```cpp showLineNumbers={false} +template +void cub::BlockExchange::BlockedToWarpStriped( + const T (&input_items)[ItemsPerThread], + OutputT (&output_items)[ItemsPerThread], + ::cuda::std::false_type +) +``` + + +**Parameters** + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + + + +Transposes data items from **blocked** arrangement to **warp-striped** arrangement. + +Specialized for warp-timeslicing + + +```cpp showLineNumbers={false} +template +void cub::BlockExchange::BlockedToWarpStriped( + const T (&input_items)[ItemsPerThread], + OutputT (&output_items)[ItemsPerThread], + ::cuda::std::true_type +) +``` + + +**Parameters** + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + + + +### StripedToBlocked inline + + + + +Transposes data items from **striped** arrangement to **blocked** arrangement. + +Specialized for no timeslicing. + + +```cpp showLineNumbers={false} +template +void cub::BlockExchange::StripedToBlocked( + const T (&input_items)[ItemsPerThread], + OutputT (&output_items)[ItemsPerThread], + ::cuda::std::false_type +) +``` + + +**Parameters** + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + + + +Transposes data items from **striped** arrangement to **blocked** arrangement. + +Specialized for warp-timeslicing. + + +```cpp showLineNumbers={false} +template +void cub::BlockExchange::StripedToBlocked( + const T (&input_items)[ItemsPerThread], + OutputT (&output_items)[ItemsPerThread], + ::cuda::std::true_type +) +``` + + +**Parameters** + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + + + +### WarpStripedToBlocked inline + + + + +Transposes data items from **warp-striped** arrangement to **blocked** arrangement. + +Specialized for no timeslicing + + +```cpp showLineNumbers={false} +template +void cub::BlockExchange::WarpStripedToBlocked( + const T (&input_items)[ItemsPerThread], + OutputT (&output_items)[ItemsPerThread], + ::cuda::std::false_type +) +``` + + +**Parameters** + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + + + +Transposes data items from **warp-striped** arrangement to **blocked** arrangement. + +Specialized for warp-timeslicing + + +```cpp showLineNumbers={false} +template +void cub::BlockExchange::WarpStripedToBlocked( + const T (&input_items)[ItemsPerThread], + OutputT (&output_items)[ItemsPerThread], + ::cuda::std::true_type +) +``` + + +**Parameters** + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + + + +### ScatterToBlocked inline + + + + +Exchanges data items annotated by rank into **blocked** arrangement. + +Specialized for no timeslicing. + + +```cpp showLineNumbers={false} +template +void cub::BlockExchange::ScatterToBlocked( + const T (&input_items)[ItemsPerThread], + OutputT (&output_items)[ItemsPerThread], + OffsetT (&ranks)[ItemsPerThread], + ::cuda::std::false_type +) +``` + + +**Parameters** + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + +Corresponding scatter ranks + + + + + +Exchanges data items annotated by rank into **blocked** arrangement. + +Specialized for warp-timeslicing. + + +```cpp showLineNumbers={false} +template +void cub::BlockExchange::ScatterToBlocked( + const T (&input_items)[ItemsPerThread], + OutputT (&output_items)[ItemsPerThread], + OffsetT ranks[ItemsPerThread], + ::cuda::std::true_type +) +``` + + +**Parameters** + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + +Corresponding scatter ranks + + + + + +### ScatterToStriped inline + + + + +Exchanges data items annotated by rank into **striped** arrangement. + +Specialized for no timeslicing. + + +```cpp showLineNumbers={false} +template +void cub::BlockExchange::ScatterToStriped( + const T (&input_items)[ItemsPerThread], + OutputT (&output_items)[ItemsPerThread], + OffsetT (&ranks)[ItemsPerThread], + ::cuda::std::false_type +) +``` + + +**Parameters** + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + +Corresponding scatter ranks + + + + + +Exchanges data items annotated by rank into **striped** arrangement. + +Specialized for warp-timeslicing. + + +```cpp showLineNumbers={false} +template +void cub::BlockExchange::ScatterToStriped( + const T (&input_items)[ItemsPerThread], + OutputT (&output_items)[ItemsPerThread], + OffsetT (&ranks)[ItemsPerThread], + ::cuda::std::true_type +) +``` + + +**Parameters** + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + +Items to exchange, converting between **blocked** and **striped** arrangements. + + + +Corresponding scatter ranks + + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `TempStorage` | `Uninitialized< _TempStorage >` | The operations exposed by `BlockExchange` require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | The thread block size in threads. | +| `WARP_THREADS` static constexpr | `int` | | +| `WARPS` static constexpr | `int` | | +| `LOG_SMEM_BANKS` static constexpr | `int` | | +| `TILE_ITEMS` static constexpr | `int` | | +| `TIME_SLICES` static constexpr | `int` | | +| `TIME_SLICED_THREADS` static constexpr | `int` | | +| `TIME_SLICED_ITEMS` static constexpr | `int` | | +| `WARP_TIME_SLICED_THREADS` static constexpr | `int` | | +| `WARP_TIME_SLICED_ITEMS` static constexpr | `int` | | +| `INSERT_PADDING` static constexpr | `bool` | | +| `PADDING_ITEMS` static constexpr | `int` | | +| `temp_storage` | `_TempStorage &` | | +| `linear_tid` | `unsigned int` | | +| `lane_id` | `unsigned int` | | +| `warp_id` | `unsigned int` | | +| `warp_offset` | `unsigned int` | | + +--- + +## Inner classes + +### _TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockExchange::_TempStorage +``` + + +Shared memory storage layout type. + +| Name | Type | Description | +|---|---|---| +| `buff` | `T` | | diff --git a/fern/cudapages/cub/cub/cub/BlockHistogram.mdx b/fern/cudapages/cub/cub/cub/BlockHistogram.mdx new file mode 100644 index 0000000..075d741 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BlockHistogram.mdx @@ -0,0 +1,360 @@ +--- +title: cub::BlockHistogram +description: "" +--- + +The BlockHistogram class provides collective methods for constructing block-wide histograms from data samples partitioned across a CUDA thread block. + +## Performance considerations + +- Performance is sensitive to the degree of data movement across the block. +- All input values must fall between `[0, Bins)`, or behavior is undefined. +- The histogram output can be constructed in shared or device-accessible memory +- See `cub::BlockHistogramAlgorithm` for performance details regarding algorithmic alternatives + +## Example + +The code snippet below illustrates a 256-bin histogram of 512 integer samples that are partitioned across 128 threads where each thread owns 4 samples. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize a 256-bin BlockHistogram type for a 1D block of 128 threads having 4 character samples each + using BlockHistogram = cub::BlockHistogram; + + // Allocate shared memory for BlockHistogram + __shared__ typename BlockHistogram::TempStorage temp_storage; + + // Allocate shared memory for block-wide histogram bin counts + __shared__ unsigned int smem_histogram[256]; + + // Obtain input samples per thread + unsigned char data[4]; + ... + + // Compute the block-wide histogram + BlockHistogram(temp_storage).Histogram(data, smem_histogram); +} +``` + + + + + +The sample type being histogrammed (must be castable to an integer bin identifier) + + + +The thread block length in threads along the X dimension + + + +The number of items per thread + + + +The number bins within the histogram + + + +**[optional]** [cub::BlockHistogramAlgorithm](/library/api/cub::BlockHistogramAlgorithm) enumerator specifying the underlying algorithm to use (default: [cub::BLOCK_HISTO_SORT](/library/api/cub::BLOCK_HISTO_SORT)) + + + +**[optional]** The thread block length in threads along the Y dimension (default: 1) + + + +**[optional]** The thread block length in threads along the Z dimension (default: 1) + + + + + +--- + +## Collective constructors + +### BlockHistogram inline + + + + +Collective constructor using a private static allocation of shared memory as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockHistogram::BlockHistogram() +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + +Collective constructor using the specified memory allocation as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockHistogram::BlockHistogram( + TempStorage &temp_storage +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Parameters** + + +Reference to memory allocation having layout type [TempStorage](/library/api/cub::BlockHistogram::TempStorage) + + + + + +--- + +## Histogram operations + +### InitHistogram inline + +Initialize the shared histogram counters to zero. + + +```cpp showLineNumbers={false} +template +void cub::BlockHistogram::InitHistogram( + CounterT histogram[Bins] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +**[inferred]** Histogram counter type + + +**Example** + +The code snippet below illustrates a the initialization and update of a histogram of 512 integer samples that are partitioned across 128 threads where each thread owns 4 samples. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize a 256-bin BlockHistogram type for a 1D block of 128 threads having 4 character samples each + using BlockHistogram = cub::BlockHistogram; + + // Allocate shared memory for BlockHistogram + __shared__ typename BlockHistogram::TempStorage temp_storage; + + // Allocate shared memory for block-wide histogram bin counts + __shared__ unsigned int smem_histogram[256]; + + // Obtain input samples per thread + unsigned char thread_samples[4]; + ... + + // Initialize the block-wide histogram + BlockHistogram(temp_storage).InitHistogram(smem_histogram); + + // Update the block-wide histogram + BlockHistogram(temp_storage).Composite(thread_samples, smem_histogram); +} +``` + +### Histogram inline + +Constructs a block-wide histogram in shared/device-accessible memory. Each thread contributes an array of input elements. + + +```cpp showLineNumbers={false} +template +void cub::BlockHistogram::Histogram( + T (&items)[ItemsPerThread], + CounterT histogram[Bins] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Histogram counter type + + +**Parameters** + + +Calling thread's input values to histogram + + + +Reference to shared/device-accessible memory histogram + + +**Example** + +The code snippet below illustrates a 256-bin histogram of 512 integer samples that are partitioned across 128 threads where each thread owns 4 samples. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize a 256-bin BlockHistogram type for a 1D block of 128 threads having 4 character samples each + using BlockHistogram = cub::BlockHistogram; + + // Allocate shared memory for BlockHistogram + __shared__ typename BlockHistogram::TempStorage temp_storage; + + // Allocate shared memory for block-wide histogram bin counts + __shared__ unsigned int smem_histogram[256]; + + // Obtain input samples per thread + unsigned char thread_samples[4]; + ... + + // Compute the block-wide histogram + BlockHistogram(temp_storage).Histogram(thread_samples, smem_histogram); +} +``` + +### Composite inline + +Updates an existing block-wide histogram in shared/device-accessible memory. Each thread composites an array of input elements. + + +```cpp showLineNumbers={false} +template +void cub::BlockHistogram::Composite( + T (&items)[ItemsPerThread], + CounterT histogram[Bins] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Histogram counter type + + +**Parameters** + + +Calling thread's input values to histogram + + + +Reference to shared/device-accessible memory histogram + + +**Example** + +The code snippet below illustrates a the initialization and update of a histogram of 512 integer samples that are partitioned across 128 threads where each thread owns 4 samples. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize a 256-bin BlockHistogram type for a 1D block of 128 threads having 4 character samples each + using BlockHistogram = cub::BlockHistogram; + + // Allocate shared memory for BlockHistogram + __shared__ typename BlockHistogram::TempStorage temp_storage; + + // Allocate shared memory for block-wide histogram bin counts + __shared__ unsigned int smem_histogram[256]; + + // Obtain input samples per thread + unsigned char thread_samples[4]; + ... + + // Initialize the block-wide histogram + BlockHistogram(temp_storage).InitHistogram(smem_histogram); + + // Update the block-wide histogram + BlockHistogram(temp_storage).Composite(thread_samples, smem_histogram); +} +``` + +--- + +## Utility methods + +### PrivateStorage inline + +Internal storage allocator. + + +```cpp showLineNumbers={false} +_TempStorage & cub::BlockHistogram::PrivateStorage() +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `InternalBlockHistogram` | `::cuda::std::_If< Algorithm==BLOCK_HISTO_SORT, detail::BlockHistogramSort< T, BlockDimX, ItemsPerThread, Bins, BlockDimY, BlockDimZ >, detail::BlockHistogramAtomic< Bins > >` | Internal specialization. | +| `_TempStorage` | `typename InternalBlockHistogram::TempStorage` | Shared memory storage layout type for `BlockHistogram`. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | The thread block size in threads. | +| `temp_storage` | `_TempStorage &` | Shared storage reference. | +| `linear_tid` | `unsigned int` | Linear thread-id. | + +--- + +## Inner classes + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockHistogram::TempStorage +``` + + +The operations exposed by `BlockHistogram` require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/cudapages/cub/cub/cub/BlockLoad.mdx b/fern/cudapages/cub/cub/cub/BlockLoad.mdx new file mode 100644 index 0000000..17184e9 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BlockLoad.mdx @@ -0,0 +1,419 @@ +--- +title: cub::BlockLoad +description: "" +--- + +The BlockLoad class provides collective data movement methods for loading a linear segment of items from memory into a blocked arrangement across a CUDA thread block. + +## Example + +The code snippet below illustrates the loading of a linear segment of 512 integers into a "blocked" arrangement across 128 threads where each thread owns 4 consecutive items. The load is specialized for `BLOCK_LOAD_WARP_TRANSPOSE`, meaning memory references are efficiently coalesced using a warp-striped access pattern (after which items are locally reordered among threads). + +those threads will be `{ [0,1,2,3], [4,5,6,7], ..., [508,509,510,511] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int *d_data, ...) +{ + // Specialize BlockLoad for a 1D block of 128 threads owning 4 integer items each + using BlockLoad = cub::BlockLoad; + + // Allocate shared memory for BlockLoad + __shared__ typename BlockLoad::TempStorage temp_storage; + + // Load a segment of consecutive items that are blocked across threads + int thread_data[4]; + BlockLoad(temp_storage).Load(d_data, thread_data); +} +``` + + + + + +The data type to read into (which must be convertible from the input iterator's value type). + + + + + + +The number of consecutive items partitioned onto each thread. + + + + + + + + + + + + + + +--- + +## Collective constructors + +### BlockLoad inline + + + + +Collective constructor using a private static allocation of shared memory as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockLoad::BlockLoad() +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + +Collective constructor using the specified memory allocation as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockLoad::BlockLoad( + TempStorage &temp_storage +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Parameters** + + +Reference to memory allocation having layout type [TempStorage](/library/api/cub::_block_load::TempStorage) + + + + + +--- + +## Data movement + +### Load inline + + + + +Load a linear segment of items from memory. + + +```cpp showLineNumbers={false} +template +void cub::BlockLoad::Load( + RandomAccessIterator block_src_it, + T (&dst_items)[ItemsPerThread] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Data is in a blocked arrangement across threads. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +The thread block's base iterator for loading from + + + +Destination to load data into + + +**Example** + +The code snippet below illustrates the loading of a linear segment of 512 integers into a "blocked" arrangement across 128 threads where each thread owns 4 consecutive items. The load is specialized for `BLOCK_LOAD_WARP_TRANSPOSE`, meaning memory references are efficiently coalesced using a warp-striped access pattern (after which items are locally reordered among threads). + +in those threads will be `{ [0,1,2,3], [4,5,6,7], ..., [508,509,510,511] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int *d_data, ...) +{ + // Specialize BlockLoad for a 1D block of 128 threads owning 4 integer items each + using BlockLoad = cub::BlockLoad; + + // Allocate shared memory for BlockLoad + __shared__ typename BlockLoad::TempStorage temp_storage; + + // Load a segment of consecutive items that are blocked across threads + int thread_data[4]; + BlockLoad(temp_storage).Load(d_data, thread_data); +} +``` + + + + +Load a linear segment of items from memory, guarded by range. + + +```cpp showLineNumbers={false} +template +void cub::BlockLoad::Load( + RandomAccessIterator block_src_it, + T (&dst_items)[ItemsPerThread], + int block_items_end +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Data is in a blocked arrangement across threads. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +The thread block's base iterator for loading from + + + +Destination to load data into + + + +Number of valid items to load + + +**Example** + +The code snippet below illustrates the guarded loading of a linear segment of 512 integers into a "blocked" arrangement across 128 threads where each thread owns 4 consecutive items. The load is specialized for `BLOCK_LOAD_WARP_TRANSPOSE`, meaning memory references are efficiently coalesced using a warp-striped access pattern (after which items are locally reordered among threads). + +`thread_data` across the block of threads in those threads will be `{ [0,1,2,3], [4,?,?,?], ..., [?,?,?,?] }`, with only the first two threads being unmasked to load portions of valid data (and other items remaining unassigned). + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int *d_data, int block_items_end, ...) +{ + // Specialize BlockLoad for a 1D block of 128 threads owning 4 integer items each + using BlockLoad = cub::BlockLoad; + + // Allocate shared memory for BlockLoad + __shared__ typename BlockLoad::TempStorage temp_storage; + + // Load a segment of consecutive items that are blocked across threads + int thread_data[4]; + BlockLoad(temp_storage).Load(d_data, thread_data, block_items_end); +} +``` + + + + +Load a linear segment of items from memory, guarded by range, with a fall-back assignment of out-of-bound elements + + +```cpp showLineNumbers={false} +template +void cub::BlockLoad::Load( + RandomAccessIterator block_src_it, + T (&dst_items)[ItemsPerThread], + int block_items_end, + DefaultT oob_default +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Data is in a blocked arrangement across threads. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +The thread block's base iterator for loading from + + + +Destination to load data into + + + +Number of valid items to load + + + +Default value to assign out-of-bound items + + +**Example** + +The code snippet below illustrates the guarded loading of a linear segment of 512 integers into a "blocked" arrangement across 128 threads where each thread owns 4 consecutive items. The load is specialized for `BLOCK_LOAD_WARP_TRANSPOSE`, meaning memory references are efficiently coalesced using a warp-striped access pattern (after which items are locally reordered among threads). + +default is `-1`. The set of `thread_data` across the block of threads in those threads will be `{ [0,1,2,3], [4,-1,-1,-1], ..., [-1,-1,-1,-1] }`, with only the first two threads being unmasked to load portions of valid data (and other items are assigned `-1`) + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int *d_data, int block_items_end, ...) +{ + // Specialize BlockLoad for a 1D block of 128 threads owning 4 integer items each + using BlockLoad = cub::BlockLoad; + + // Allocate shared memory for BlockLoad + __shared__ typename BlockLoad::TempStorage temp_storage; + + // Load a segment of consecutive items that are blocked across threads + int thread_data[4]; + BlockLoad(temp_storage).Load(d_data, thread_data, block_items_end, -1); +} +``` + + + + +--- + +## Utility methods + +### PrivateStorage inline + + +```cpp showLineNumbers={false} +_TempStorage & cub::BlockLoad::PrivateStorage() +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `InternalLoad` | `LoadInternal< Algorithm, 0 >` | | +| `_TempStorage` | `typename InternalLoad::TempStorage` | | +| `TempStorage` | `Uninitialized< _TempStorage >` | The operations exposed by `BlockLoad` require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BlockThreads` static constexpr | `int` | | +| `temp_storage` | `_TempStorage &` | | +| `linear_tid` | `int` | | + +--- + +## Inner classes + +### LoadInternal + + +```cpp showLineNumbers={false} +struct cub::BlockLoad::LoadInternal +``` + + +### LoadInternal< BLOCK_LOAD_DIRECT, Dummy > + + +```cpp showLineNumbers={false} +struct cub::BlockLoad::LoadInternal< BLOCK_LOAD_DIRECT, Dummy > +``` + + +| Name | Type | Description | +|---|---|---| +| `linear_tid` | `int` | | + +### LoadInternal< BLOCK_LOAD_STRIPED, Dummy > + + +```cpp showLineNumbers={false} +struct cub::BlockLoad::LoadInternal< BLOCK_LOAD_STRIPED, Dummy > +``` + + +| Name | Type | Description | +|---|---|---| +| `linear_tid` | `int` | | + +### LoadInternal< BLOCK_LOAD_VECTORIZE, Dummy > + + +```cpp showLineNumbers={false} +struct cub::BlockLoad::LoadInternal< BLOCK_LOAD_VECTORIZE, Dummy > +``` + + +| Name | Type | Description | +|---|---|---| +| `linear_tid` | `int` | | + +### LoadInternal< BLOCK_LOAD_TRANSPOSE, Dummy > + + +```cpp showLineNumbers={false} +struct cub::BlockLoad::LoadInternal< BLOCK_LOAD_TRANSPOSE, Dummy > +``` + + +| Name | Type | Description | +|---|---|---| +| `temp_storage` | `_TempStorage &` | | +| `linear_tid` | `int` | | + +### LoadInternal< BLOCK_LOAD_WARP_TRANSPOSE, Dummy > + + +```cpp showLineNumbers={false} +struct cub::BlockLoad::LoadInternal< BLOCK_LOAD_WARP_TRANSPOSE, Dummy > +``` + + +| Name | Type | Description | +|---|---|---| +| `WARP_THREADS` static constexpr | `int` | | +| `temp_storage` | `_TempStorage &` | | +| `linear_tid` | `int` | | + +### LoadInternal< BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED, Dummy > + + +```cpp showLineNumbers={false} +struct cub::BlockLoad::LoadInternal< BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED, Dummy > +``` + + +| Name | Type | Description | +|---|---|---| +| `WARP_THREADS` static constexpr | `int` | | +| `temp_storage` | `_TempStorage &` | | +| `linear_tid` | `int` | | diff --git a/fern/cudapages/cub/cub/cub/BlockLoadType.mdx b/fern/cudapages/cub/cub/cub/BlockLoadType.mdx new file mode 100644 index 0000000..ba43cc8 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BlockLoadType.mdx @@ -0,0 +1,29 @@ +--- +title: cub::BlockLoadType +description: "" +--- + + + + + + + + + + + + + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `cub::BlockLoad< T, Policy::BLOCK_THREADS, Policy::ITEMS_PER_THREAD, Policy::LOAD_ALGORITHM >` | diff --git a/fern/cudapages/cub/cub/cub/BlockMergeSort.mdx b/fern/cudapages/cub/cub/cub/BlockMergeSort.mdx new file mode 100644 index 0000000..8ae5e16 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BlockMergeSort.mdx @@ -0,0 +1,224 @@ +--- +title: cub::BlockMergeSort +description: "The [BlockMergeSort](/library/api/cub::_block_merge_sort) class provides methods for sorting items partitioned across a CUDA thread block using a merge sorting method." +--- + +The `BlockMergeSort` class provides methods for sorting items partitioned across a CUDA thread block using a merge sorting method. + +**Overview** + +`BlockMergeSort` arranges items into ascending order using a comparison functor with less-than semantics. Merge sort can handle arbitrary types and comparison functors, but is slower than [BlockRadixSort](/library/api/cub::_block_radix_sort) when sorting arithmetic types into ascending/descending order. + +**A Simple Example** + +Every thread in the block uses the `BlockMergeSort` class by first specializing the `BlockMergeSort` type, then instantiating an instance with parameters for communication, and finally invoking one or more collective member functions. + +The code snippet below illustrates a sort of 512 integer keys that are partitioned across 128 threads * where each thread owns 4 consecutive items. + +```cpp showLineNumbers={false} +#include // or equivalently + +struct CustomLess +{ + template + __device__ bool operator()(const DataType &lhs, const DataType &rhs) + { + return lhs < rhs; + } +}; + +__global__ void ExampleKernel(...) +{ + // Specialize BlockMergeSort for a 1D block of 128 threads owning 4 integer items each + using BlockMergeSort = cub::BlockMergeSort; + + // Allocate shared memory for BlockMergeSort + __shared__ typename BlockMergeSort::TempStorage temp_storage_shuffle; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_keys[4]; + ... + + BlockMergeSort(temp_storage_shuffle).Sort(thread_keys, CustomLess()); + ... +} +``` + +Suppose the set of input `thread_keys` across the block of threads is `{ [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }`. The corresponding output `thread_keys` in those threads will be `{ [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }`. + +**Re-using dynamically allocating shared memory** + +The `block/example_block_reduce_dyn_smem.cu` example illustrates usage of dynamically shared memory with [BlockReduce](/library/api/cub::_block_reduce) and how to re-purpose the same memory region. + +This example can be easily adapted to the storage required by `BlockMergeSort`. + + + + + +KeyT type + + + + + + +The number of items per thread + + + +**[optional]** ValueT type (default: `cub::NullType`, which indicates a keys-only sort) + + + + + + + + + + + +**Inherits from:** `cub::BlockMergeSortStrategy< KeyT, NullType, BlockDimX *1 *1, ItemsPerThread, BlockMergeSort< KeyT, BlockDimX, ItemsPerThread, NullType, 1, 1 > >` (public) + +--- + +## Constructors + +### BlockMergeSort inline + + + + + +```cpp showLineNumbers={false} +cub::BlockMergeSort::BlockMergeSort() +``` + + + + + +explicit + + +```cpp showLineNumbers={false} +cub::BlockMergeSort::BlockMergeSort( + typename BlockMergeSortStrategyT::TempStorage &temp_storage +) +``` + + + + + +--- + +## Methods + +### get_linear_tid inline const + + +```cpp showLineNumbers={false} +unsigned int cub::BlockMergeSortStrategy>::get_linear_tid() const +``` + + +### Sort inline + +Sorts items partitioned across a CUDA thread block using a merge sorting method. + + +```cpp showLineNumbers={false} +void cub::BlockMergeSortStrategy>::Sort( + KeyT (&keys)[ItemsPerThread], + CompareOp compare_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Parameters** + + +Keys to sort + + + +Comparison function object which returns true if the first argument is ordered before the second + + +### StableSort inline + +Sorts items partitioned across a CUDA thread block using a merge sorting method. + + +```cpp showLineNumbers={false} +void cub::BlockMergeSortStrategy>::StableSort( + KeyT (&keys)[ItemsPerThread], + CompareOp compare_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Parameters** + + +Keys to sort + + + +Comparison function object which returns true if the first argument is ordered before the second + + +### SyncImplementation inline const + + +```cpp showLineNumbers={false} +void cub::BlockMergeSort::SyncImplementation() const +``` + + +### PrivateStorage inline + +Internal storage allocator. + + +```cpp showLineNumbers={false} +_TempStorage & cub::BlockMergeSortStrategy>::PrivateStorage() +``` + + +### Sync inline const + + +```cpp showLineNumbers={false} +void cub::BlockMergeSortStrategy>::Sync() const +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `BlockMergeSortStrategyT` | `BlockMergeSortStrategy< KeyT, ValueT, BLOCK_THREADS, ItemsPerThread, BlockMergeSort >` | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | | +| `ITEMS_PER_TILE` static constexpr | `int` | | +| `KEYS_ONLY` static constexpr | `bool` | | +| `BlockMergeSortStrategyT` | `friend` | | +| `temp_storage` | `_TempStorage &` | Shared storage reference. | +| `linear_tid` | `const unsigned int` | | diff --git a/fern/cudapages/cub/cub/cub/BlockMergeSortStrategy.mdx b/fern/cudapages/cub/cub/cub/BlockMergeSortStrategy.mdx new file mode 100644 index 0000000..0e59e02 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BlockMergeSortStrategy.mdx @@ -0,0 +1,507 @@ +--- +title: cub::BlockMergeSortStrategy +description: "Generalized merge sort algorithm." +--- + +Generalized merge sort algorithm. + +This class is used to reduce code duplication. Warp and Block merge sort differ only in how they compute thread index and how they synchronize threads. Since synchronization might require access to custom data (like member mask), CRTP is used. + +The code snippet below illustrates the way this class can be used. + +```cpp showLineNumbers={false} +#include // or equivalently + +constexpr int BLOCK_THREADS = 256; +constexpr int ItemsPerThread = 9; + +class BlockMergeSort : public BlockMergeSortStrategy +{ + using BlockMergeSortStrategyT = + BlockMergeSortStrategy; +public: + __device__ __forceinline__ explicit BlockMergeSort( + typename BlockMergeSortStrategyT::TempStorage &temp_storage) + : BlockMergeSortStrategyT(temp_storage, threadIdx.x) + {} + + __device__ __forceinline__ void SyncImplementation() const + { + __syncthreads(); + } +}; +``` + + + + + +KeyT type + + + +ValueT type. cub::NullType indicates a keys-only sort + + + + + + + + + +Provides a way of synchronizing threads. Should be derived from `BlockMergeSortStrategy`. + + + + + +--- + +## Constructors + +### BlockMergeSortStrategy inline + + + + +explicit + + +```cpp showLineNumbers={false} +cub::BlockMergeSortStrategy::BlockMergeSortStrategy( + unsigned int linear_tid +) +``` + + + + + + +```cpp showLineNumbers={false} +cub::BlockMergeSortStrategy::BlockMergeSortStrategy( + TempStorage &temp_storage, + unsigned int linear_tid +) +``` + + + + + + +```cpp showLineNumbers={false} +cub::BlockMergeSortStrategy::BlockMergeSortStrategy() = delete +``` + + + + + +--- + +## Methods + +### PrivateStorage inline + +Internal storage allocator. + + +```cpp showLineNumbers={false} +_TempStorage & cub::BlockMergeSortStrategy::PrivateStorage() +``` + + +### Sync inline const + + +```cpp showLineNumbers={false} +void cub::BlockMergeSortStrategy::Sync() const +``` + + +### get_linear_tid inline const + + +```cpp showLineNumbers={false} +unsigned int cub::BlockMergeSortStrategy::get_linear_tid() const +``` + + +### Sort inline + + + + +Sorts items partitioned across a CUDA thread block using a merge sorting method. + + +```cpp showLineNumbers={false} +template +void cub::BlockMergeSortStrategy::Sort( + KeyT (&keys)[ItemsPerThread], + CompareOp compare_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +Functor type having member `bool operator()(KeyT lhs, KeyT rhs)`. `CompareOp` is a model of [Strict Weak Ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order). + + +**Parameters** + + +Keys to sort + + + +Comparison function object which returns true if the first argument is ordered before the second + + + + + +Sorts items partitioned across a CUDA thread block using a merge sorting method. + + +```cpp showLineNumbers={false} +template +void cub::BlockMergeSortStrategy::Sort( + KeyT (&keys)[ItemsPerThread], + CompareOp compare_op, + int valid_items, + KeyT oob_default +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +Functor type having member `bool operator()(KeyT lhs, KeyT rhs)`. `CompareOp` is a model of [Strict Weak Ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order). + + +**Parameters** + + +Keys to sort + + + +Comparison function object which returns true if the first argument is ordered before the second + + + +Number of valid items to sort + + + +Default value to assign out-of-bound items + + + + + +Sorts items partitioned across a CUDA thread block using a merge sorting method. + + +```cpp showLineNumbers={false} +template +void cub::BlockMergeSortStrategy::Sort( + KeyT (&keys)[ItemsPerThread], + ValueT (&items)[ItemsPerThread], + CompareOp compare_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +Functor type having member `bool operator()(KeyT lhs, KeyT rhs)`. `CompareOp` is a model of [Strict Weak Ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order). + + +**Parameters** + + +Keys to sort + + + +Values to sort + + + +Comparison function object which returns true if the first argument is ordered before the second + + + + + +Sorts items partitioned across a CUDA thread block using a merge sorting method. + + +```cpp showLineNumbers={false} +template +void cub::BlockMergeSortStrategy::Sort( + KeyT (&keys)[ItemsPerThread], + ValueT (&items)[ItemsPerThread], + CompareOp compare_op, + int valid_items, + KeyT oob_default +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +Functor type having member `bool operator()(KeyT lhs, KeyT rhs)` `CompareOp` is a model of [Strict Weak Ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order). + + + +True if `valid_items` isn't equal to the [`ITEMS_PER_TILE`](/library/api/cub::_block_merge_sort_strategy::ITEMS_PER_TILE) + + +**Parameters** + + +Keys to sort + + + +Values to sort + + + +Comparison function object which returns true if the first argument is ordered before the second + + + +Number of valid items to sort + + + +Default value to assign out-of-bound items + + + + + +### StableSort inline + + + + +Sorts items partitioned across a CUDA thread block using a merge sorting method. + + +```cpp showLineNumbers={false} +template +void cub::BlockMergeSortStrategy::StableSort( + KeyT (&keys)[ItemsPerThread], + CompareOp compare_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +Functor type having member `bool operator()(KeyT lhs, KeyT rhs)`. `CompareOp` is a model of [Strict Weak Ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order). + + +**Parameters** + + +Keys to sort + + + +Comparison function object which returns true if the first argument is ordered before the second + + + + + +Sorts items partitioned across a CUDA thread block using a merge sorting method. + + +```cpp showLineNumbers={false} +template +void cub::BlockMergeSortStrategy::StableSort( + KeyT (&keys)[ItemsPerThread], + ValueT (&items)[ItemsPerThread], + CompareOp compare_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +Functor type having member `bool operator()(KeyT lhs, KeyT rhs)`. `CompareOp` is a model of [Strict Weak Ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order). + + +**Parameters** + + +Keys to sort + + + +Values to sort + + + +Comparison function object which returns true if the first argument is ordered before the second + + + + + +Sorts items partitioned across a CUDA thread block using a merge sorting method. + + +```cpp showLineNumbers={false} +template +void cub::BlockMergeSortStrategy::StableSort( + KeyT (&keys)[ItemsPerThread], + CompareOp compare_op, + int valid_items, + KeyT oob_default +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +Functor type having member `bool operator()(KeyT lhs, KeyT rhs)`. `CompareOp` is a model of [Strict Weak Ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order). + + +**Parameters** + + +Keys to sort + + + +Comparison function object which returns true if the first argument is ordered before the second + + + +Number of valid items to sort + + + +Default value to assign out-of-bound items + + + + + +Sorts items partitioned across a CUDA thread block using a merge sorting method. + + +```cpp showLineNumbers={false} +template +void cub::BlockMergeSortStrategy::StableSort( + KeyT (&keys)[ItemsPerThread], + ValueT (&items)[ItemsPerThread], + CompareOp compare_op, + int valid_items, + KeyT oob_default +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +Functor type having member `bool operator()(KeyT lhs, KeyT rhs)`. `CompareOp` is a model of [Strict Weak Ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order). + + + +True if `valid_items` isn't equal to the [`ITEMS_PER_TILE`](/library/api/cub::_block_merge_sort_strategy::ITEMS_PER_TILE) + + +**Parameters** + + +Keys to sort + + + +Values to sort + + + +Comparison function object which returns true if the first argument is ordered before the second + + + +Number of valid items to sort + + + +Default value to assign out-of-bound items + + + + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `ITEMS_PER_TILE` static constexpr | `int` | | +| `KEYS_ONLY` static constexpr | `bool` | | +| `temp_storage` | `_TempStorage &` | Shared storage reference. | +| `linear_tid` | `const unsigned int` | | + +--- + +## Inner classes + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockMergeSortStrategy::TempStorage +``` + + +The operations exposed by [BlockMergeSort](/library/api/cub::_block_merge_sort) require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/cudapages/cub/cub/cub/BlockRadixRank.mdx b/fern/cudapages/cub/cub/cub/BlockRadixRank.mdx new file mode 100644 index 0000000..5133d5a --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BlockRadixRank.mdx @@ -0,0 +1,284 @@ +--- +title: cub::BlockRadixRank +description: "" +--- + +BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block. + +## Performance considerations + +- Performance is sensitive to the degree of data movement across the block. + +Suppose the set of input `keys` across the block of threads is `{ [16,10], [9,11] }`. The extractor will rank only the lowest 5 bits: `{ [16,10], [9,11] }` (bits 0-4). The corresponding output `ranks` in those threads will be `{ [3,1], [0,2] }`. + + + + + +The thread block length in threads along the X dimension + + + +The number of radix bits per digit place + + + +Whether or not the sorted-order is high-to-low + + + +**[optional]** Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure (default: true for architectures SM35 and newer, false otherwise). See [`BlockScanAlgorithm::BLOCK_SCAN_RAKING_MEMOIZE`](/library/api/cub::BLOCK_SCAN_RAKING_MEMOIZE) for more details. + + + +**[optional]** The [cub::BlockScanAlgorithm](/library/api/cub::BlockScanAlgorithm) algorithm to use (default: [cub::BLOCK_SCAN_WARP_SCANS](/library/api/cub::BLOCK_SCAN_WARP_SCANS)) + + + +**[optional]** Shared memory bank mode (default: `cudaSharedMemBankSizeFourByte`) + + + +**[optional]** The thread block length in threads along the Y dimension (default: 1) + + + +**[optional]** The thread block length in threads along the Z dimension (default: 1) + + + + + +--- + +## Collective constructors + +### BlockRadixRank inline + + + + +Collective constructor using a private static allocation of shared memory as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockRadixRank::BlockRadixRank() +``` + + + + + +Collective constructor using the specified memory allocation as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockRadixRank::BlockRadixRank( + TempStorage &temp_storage +) +``` + + +**Parameters** + + +Reference to memory allocation having layout type [TempStorage](/library/api/cub::BlockRadixRank::TempStorage) + + + + + +--- + +## Raking + +### RankKeys inline + + + + +Rank keys. + + +```cpp showLineNumbers={false} +template +void cub::BlockRadixRank::RankKeys( + UnsignedBits (&keys)[KEYS_PER_THREAD], + int (&ranks)[KEYS_PER_THREAD], + DigitExtractorT digit_extractor +) +``` + + +**Parameters** + + +Keys for this tile + + + +For each key, the local rank within the tile + + + +The digit extractor + + + + + +Rank keys. + +For the lower `RADIX_DIGITS` threads, digit counts for each digit are provided for the corresponding thread. + + +```cpp showLineNumbers={false} +template +void cub::BlockRadixRank::RankKeys( + UnsignedBits (&keys)[KEYS_PER_THREAD], + int (&ranks)[KEYS_PER_THREAD], + DigitExtractorT digit_extractor, + int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD] +) +``` + + +**Parameters** + + +Keys for this tile + + + +For each key, the local rank within the tile (out parameter) + + + +The digit extractor + + + +The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... (threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1] + + + + + +--- + +## Utility methods + +### PrivateStorage inline + +Internal storage allocator. + + +```cpp showLineNumbers={false} +_TempStorage & cub::BlockRadixRank::PrivateStorage() +``` + + +### Upsweep inline + +Performs upsweep raking reduction, returning the aggregate. + + +```cpp showLineNumbers={false} +PackedCounter cub::BlockRadixRank::Upsweep() +``` + + +### ExclusiveDownsweep inline + +Performs exclusive downsweep raking scan. + + +```cpp showLineNumbers={false} +void cub::BlockRadixRank::ExclusiveDownsweep( + PackedCounter raking_partial +) +``` + + +### ResetCounters inline + +Reset shared memory digit counters. + + +```cpp showLineNumbers={false} +void cub::BlockRadixRank::ResetCounters() +``` + + +### ScanCounters inline + +Scan shared memory digit counters. + + +```cpp showLineNumbers={false} +void cub::BlockRadixRank::ScanCounters() +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `DigitCounter` | `unsigned short` | | +| `PackedCounter` | `::cuda::std::_If< SMemConfig==cudaSharedMemBankSizeEightByte, unsigned long long, unsigned int >` | | +| `BlockScan` | `BlockScan< PackedCounter, BlockDimX, InnerScanAlgorithm, BlockDimY, BlockDimZ >` | [BlockScan](/library/api/cub::BlockRadixRank::BlockScan) type. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `max_tile_size` static constexpr | `DigitCounter` | | +| `BLOCK_THREADS` static constexpr | `int` | | +| `RADIX_DIGITS` static constexpr | `int` | | +| `LOG_WARP_THREADS` static constexpr | `int` | | +| `WARP_THREADS` static constexpr | `int` | | +| `WARPS` static constexpr | `int` | | +| `BYTES_PER_COUNTER` static constexpr | `int` | | +| `LOG_BYTES_PER_COUNTER` static constexpr | `int` | | +| `PACKING_RATIO` static constexpr | `int` | | +| `LOG_PACKING_RATIO` static constexpr | `int` | | +| `LOG_COUNTER_LANES` static constexpr | `int` | | +| `COUNTER_LANES` static constexpr | `int` | | +| `PADDED_COUNTER_LANES` static constexpr | `int` | | +| `RAKING_SEGMENT` static constexpr | `int` | | +| `BINS_TRACKED_PER_THREAD` static constexpr | `int` | Number of bin-starting offsets tracked per thread. | +| `temp_storage` | `_TempStorage &` | Shared storage reference. | +| `linear_tid` | `unsigned int` | Linear thread-id. | +| `cached_segment` | `PackedCounter` | Copy of raking segment, promoted to registers. | + +--- + +## Inner classes + +### PrefixCallBack + + +```cpp showLineNumbers={false} +struct cub::BlockRadixRank::PrefixCallBack +``` + + +Block-scan prefix callback. + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockRadixRank::TempStorage +``` + + +The operations exposed by [BlockScan](/library/api/cub::BlockRadixRank::BlockScan) require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/cudapages/cub/cub/cub/BlockRadixRankEmptyCallback.mdx b/fern/cudapages/cub/cub/cub/BlockRadixRankEmptyCallback.mdx new file mode 100644 index 0000000..4450e4c --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BlockRadixRankEmptyCallback.mdx @@ -0,0 +1,29 @@ +--- +title: cub::BlockRadixRankEmptyCallback +description: "Empty callback implementation." +--- + +Empty callback implementation. + + + + + + + + + + +--- + +## Methods + +### operator() inline + + +```cpp showLineNumbers={false} +void cub::BlockRadixRankEmptyCallback::operator()( + int (&bins)[BINS_PER_THREAD] +) +``` + diff --git a/fern/cudapages/cub/cub/cub/BlockRadixRankMatch.mdx b/fern/cudapages/cub/cub/cub/BlockRadixRankMatch.mdx new file mode 100644 index 0000000..31e19bc --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BlockRadixRankMatch.mdx @@ -0,0 +1,246 @@ +--- +title: cub::BlockRadixRankMatch +description: "Radix-rank using match.any." +--- + +Radix-rank using match.any. + + + + + + + + + + + + + + + + + + + + + + + + + +--- + +## Collective constructors + +### BlockRadixRankMatch inline + +Collective constructor using the specified memory allocation as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockRadixRankMatch::BlockRadixRankMatch( + TempStorage &temp_storage +) +``` + + +**Parameters** + + +Reference to memory allocation having layout type [TempStorage](/library/api/cub::BlockRadixRankMatch::TempStorage) + + +--- + +## Raking + +### CallBack inline + +Computes the count of keys for each digit value, and calls the callback with the array of key counts. + + +```cpp showLineNumbers={false} +template +void cub::BlockRadixRankMatch::CallBack( + CountsCallback callback +) +``` + + +**Template parameters** + + +The callback type. It should implement an instance overload of operator()(int (&bins)[BINS_TRACKED_PER_THREAD]), where bins is an array of key counts for each digit value distributed in block distribution among the threads of the thread block. Key counts can be used, to update other data structures in global or shared memory. Depending on the implementation of the ranking algoirhtm (see [BlockRadixRankMatchEarlyCounts](/library/api/cub::_block_radix_rank_match_early_counts)), key counts may become available early, therefore, they are returned through a callback rather than a separate output parameter of [RankKeys()](/library/api/cub::_block_radix_rank_match::RankKeys()). + + +### RankKeys inline + + + + +Rank keys. + + +```cpp showLineNumbers={false} +template +void cub::BlockRadixRankMatch::RankKeys( + UnsignedBits (&keys)[KEYS_PER_THREAD], + int (&ranks)[KEYS_PER_THREAD], + DigitExtractorT digit_extractor, + CountsCallback callback +) +``` + + +**Parameters** + + +Keys for this tile + + + +For each key, the local rank within the tile + + + +The digit extractor + + + + + + +```cpp showLineNumbers={false} +template +void cub::BlockRadixRankMatch::RankKeys( + UnsignedBits (&keys)[KEYS_PER_THREAD], + int (&ranks)[KEYS_PER_THREAD], + DigitExtractorT digit_extractor +) +``` + + + + + +Rank keys. + +For the lower `RADIX_DIGITS` threads, digit counts for each digit are provided for the corresponding thread. + + +```cpp showLineNumbers={false} +template +void cub::BlockRadixRankMatch::RankKeys( + UnsignedBits (&keys)[KEYS_PER_THREAD], + int (&ranks)[KEYS_PER_THREAD], + DigitExtractorT digit_extractor, + int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD], + CountsCallback callback +) +``` + + +**Parameters** + + +Keys for this tile + + + +For each key, the local rank within the tile (out parameter) + + + +The digit extractor + + + +The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... (threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1] + + + + + + +```cpp showLineNumbers={false} +template +void cub::BlockRadixRankMatch::RankKeys( + UnsignedBits (&keys)[KEYS_PER_THREAD], + int (&ranks)[KEYS_PER_THREAD], + DigitExtractorT digit_extractor, + int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD] +) +``` + + +**Parameters** + + +Keys for this tile + + + +For each key, the local rank within the tile (out parameter) + + + +The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... (threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1] + + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `RankT` | `int32_t` | | +| `DigitCounterT` | `int32_t` | | +| `BlockScanT` | `BlockScan< DigitCounterT, BLOCK_THREADS, InnerScanAlgorithm, BlockDimY, BlockDimZ >` | [BlockScan](/library/api/cub::_block_scan) type. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | | +| `RADIX_DIGITS` static constexpr | `int` | | +| `LOG_WARP_THREADS` static constexpr | `int` | | +| `WARP_THREADS` static constexpr | `int` | | +| `PARTIAL_WARP_THREADS` static constexpr | `int` | | +| `WARPS` static constexpr | `int` | | +| `PADDED_WARPS` static constexpr | `int` | | +| `COUNTERS` static constexpr | `int` | | +| `RAKING_SEGMENT` static constexpr | `int` | | +| `PADDED_RAKING_SEGMENT` static constexpr | `int` | | +| `BINS_TRACKED_PER_THREAD` static constexpr | `int` | Number of bin-starting offsets tracked per thread. | +| `temp_storage` | `_TempStorage &` | Shared storage reference. | +| `linear_tid` | `unsigned int` | Linear thread-id. | + +--- + +## Inner classes + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockRadixRankMatch::TempStorage +``` + + +The operations exposed by `BlockRadixRankMatch` require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/cudapages/cub/cub/cub/BlockRadixRankMatchEarlyCounts.mdx b/fern/cudapages/cub/cub/cub/BlockRadixRankMatchEarlyCounts.mdx new file mode 100644 index 0000000..de40352 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BlockRadixRankMatchEarlyCounts.mdx @@ -0,0 +1,173 @@ +--- +title: cub::BlockRadixRankMatchEarlyCounts +description: "Radix-rank using matching which computes the counts of keys for each digit value early, at the expense of doing more work." +--- + +Radix-rank using matching which computes the counts of keys for each digit value early, at the expense of doing more work. + +This may be useful e.g. for decoupled look-back, where it reduces the time other thread blocks need to wait for digit counts to become available. + + + + + + + + + + + + + + + + + + + + + + + + + +--- + +## Constructors + +### BlockRadixRankMatchEarlyCounts inline + + +```cpp showLineNumbers={false} +cub::BlockRadixRankMatchEarlyCounts::BlockRadixRankMatchEarlyCounts( + TempStorage &temp_storage +) +``` + + +--- + +## Methods + +### RankKeys inline + + + + +Rank keys. + +For the lower `RADIX_DIGITS` threads, digit counts for each digit are provided for the corresponding thread. + + +```cpp showLineNumbers={false} +template +void cub::BlockRadixRankMatchEarlyCounts::RankKeys( + UnsignedBits (&keys)[KEYS_PER_THREAD], + int (&ranks)[KEYS_PER_THREAD], + DigitExtractorT digit_extractor, + int (&exclusive_digit_prefix)[BINS_PER_THREAD], + CountsCallback callback +) +``` + + + + + + +```cpp showLineNumbers={false} +template +void cub::BlockRadixRankMatchEarlyCounts::RankKeys( + UnsignedBits (&keys)[KEYS_PER_THREAD], + int (&ranks)[KEYS_PER_THREAD], + DigitExtractorT digit_extractor, + int (&exclusive_digit_prefix)[BINS_PER_THREAD] +) +``` + + + + + + +```cpp showLineNumbers={false} +template +void cub::BlockRadixRankMatchEarlyCounts::RankKeys( + UnsignedBits (&keys)[KEYS_PER_THREAD], + int (&ranks)[KEYS_PER_THREAD], + DigitExtractorT digit_extractor +) +``` + + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `BlockScan` | `cub::BlockScan< int, BLOCK_THREADS, InnerScanAlgorithm >` | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | | +| `RADIX_DIGITS` static constexpr | `int` | | +| `BINS_PER_THREAD` static constexpr | `int` | | +| `BINS_TRACKED_PER_THREAD` static constexpr | `int` | | +| `FULL_BINS` static constexpr | `int` | | +| `WARP_THREADS` static constexpr | `int` | | +| `PARTIAL_WARP_THREADS` static constexpr | `int` | | +| `BLOCK_WARPS` static constexpr | `int` | | +| `PARTIAL_WARP_ID` static constexpr | `int` | | +| `WARP_MASK` static constexpr | `int` | | +| `NUM_MATCH_MASKS` static constexpr | `int` | | +| `MATCH_MASKS_ALLOC_SIZE` static constexpr | `int` | | +| `temp_storage` | `TempStorage &` | | + +--- + +## Inner classes + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockRadixRankMatchEarlyCounts::TempStorage +``` + + +| Name | Type | Description | +|---|---|---| +| `warp_offsets` | `int` | | +| `warp_histograms` | `int` | | +| `` | `union cub::BlockRadixRankMatchEarlyCounts::TempStorage` | | +| `match_masks` | `::cuda::std::uint32_t` | | +| `prefix_tmp` | `BlockScan::TempStorage` | | + +### BlockRadixRankMatchInternal + + +```cpp showLineNumbers={false} +struct cub::BlockRadixRankMatchEarlyCounts::BlockRadixRankMatchInternal +``` + + +| Name | Type | Description | +|---|---|---| +| `s` | `TempStorage &` | | +| `digit_extractor` | `DigitExtractorT` | | +| `callback` | `CountsCallback` | | +| `warp` | `int` | | +| `lane` | `int` | | diff --git a/fern/cudapages/cub/cub/cub/BlockRadixSort.mdx b/fern/cudapages/cub/cub/cub/BlockRadixSort.mdx new file mode 100644 index 0000000..53da08e --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BlockRadixSort.mdx @@ -0,0 +1,1753 @@ +--- +title: cub::BlockRadixSort +description: "" +--- + +BlockRadixSort class provides collective methods for sorting items partitioned across a CUDA thread block using a radix sorting method. + +![](../../img/sorting_logo.png) + +The [radix sorting method](http://en.wikipedia.org/wiki/Radix_sort) arranges items into ascending order. It relies upon a positional representation for keys, i.e., each key is comprised of an ordered sequence of symbols (e.g., digits, characters, etc.) specified from least-significant to most-significant. For a given input sequence of keys and a set of rules specifying a total ordering of the symbolic alphabet, the radix sorting method produces a lexicographic ordering of those keys. + +Assumes threads are in row-major order. + + +BlockRadixSort can sort all of the built-in C++ numeric primitive types (`unsigned char`, `int`, `double`, etc.) as well as CUDA's `__half` half-precision floating-point type. User-defined types are supported as long as decomposer object is provided. + + +- Positive and negative zeros are considered equivalent, and will be treated + as such in the output. +- No special handling is implemented for NaN values; these are sorted + according to their bit representations after any transformations. + + +Although the direct radix sorting method can only be applied to unsigned integral types, BlockRadixSort is able to sort signed and floating-point types via simple bit-wise transformations that ensure lexicographic key ordering. + +These transformations must be considered when restricting the `[begin_bit, end_bit)` range, as the bitwise transformations will occur before the bit-range truncation. + +Any transformations applied to the keys prior to sorting are reversed while writing to the final output buffer. + + +To convert the input values into a radix-sortable bitwise representation, the following transformations take place prior to sorting: + +* For unsigned integral values, the keys are used directly. +* For signed integral values, the sign bit is inverted. +* For positive floating point values, the sign bit is inverted. +* For negative floating point values, the full key is inverted. + + +Unlike `DeviceRadixSort`, `BlockRadixSort` does not invert the input key bits when performing a descending sort. Instead, it has special logic to reverse the order of the keys while sorting. + + +BlockRadixSort is stable. For floating-point types -0.0 and +0.0 are considered equal and appear in the result in the same order as they appear in the input. + + + +* Performance is sensitive to the degree of data movement across the block. + + + + +The code snippet below illustrates a sort of 512 integer keys that are partitioned in a [blocked arrangement](../index.html#sec5sec3) across 128 threads where each thread owns 4 consecutive items. + +.. tab-set-code:: + +Suppose the set of input `thread_keys` across the block of threads is `{ [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }`. The corresponding output `thread_keys` in those threads will be `{ [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }`. + + +The `block/example_block_reduce_dyn_smem.cu` example illustrates usage of dynamically shared memory with BlockReduce and how to re-purpose the same memory region. + +This example can be easily adapted to the storage required by BlockRadixSort. + + + + + +KeyT type + + + +The thread block length in threads along the X dimension + + + +The number of items per thread + + + +**[optional]** ValueT type (default: cub::NullType, which indicates a keys-only sort) + + + +**[optional]** The number of radix bits per digit place (default: 4 bits) + + + +**[optional]** Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure (default: true for architectures SM35 and newer, false otherwise). + + + +**[optional]** The [cub::BlockScanAlgorithm](/library/api/cub::BlockScanAlgorithm) algorithm to use (default: [cub::BLOCK_SCAN_WARP_SCANS](/library/api/cub::BLOCK_SCAN_WARP_SCANS)) + + + +**[Optional]*8 Shared memory bank mode (default: `cudaSharedMemBankSizeFourByte`) + + + +**[optional]** The thread block length in threads along the Y dimension (default: 1) + + + +**[optional]** The thread block length in threads along the Z dimension (default: 1) + + + + + +--- + +## Collective constructors + +### BlockRadixSort inline + + + + +Collective constructor using a private static allocation of shared memory as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockRadixSort::BlockRadixSort() +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + +Collective constructor using the specified memory allocation as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockRadixSort::BlockRadixSort( + TempStorage &temp_storage +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Parameters** + + +Reference to memory allocation having layout type [TempStorage](/library/api/cub::BlockRadixSort::TempStorage) + + + + + +--- + +## Sorting (blocked arrangements) + +### Sort inline + + + + +Performs an ascending block-wide radix sort over a blocked arrangement of keys. + + +```cpp showLineNumbers={false} +void cub::BlockRadixSort::Sort( + KeyT (&keys)[ItemsPerThread], + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Keys to sort + + + +**[optional]** The beginning (least-significant) bit index needed for key comparison + + + +**[optional]** The past-the-end (most-significant) bit index needed for key comparison + + +**Example** + +The code snippet below illustrates a sort of 512 integer keys that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive keys. + +`{ [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }`. The corresponding output `thread_keys` in those threads will be `{ [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockRadixSort for a 1D block of 128 threads owning 4 integer keys each + using BlockRadixSort = cub::BlockRadixSort; + + // Allocate shared memory for BlockRadixSort + __shared__ typename BlockRadixSort::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_keys[4]; + ... + + // Collectively sort the keys + BlockRadixSort(temp_storage).Sort(thread_keys); +``` + + + + +Performs an ascending block-wide radix sort over a blocked arrangement of keys. + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The code snippet below illustrates a sort of 2 keys that are partitioned in a blocked arrangement across 2 threads where each thread owns 1 key. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin keys-bits :end-before: example-end keys-bits + + +```cpp showLineNumbers={false} +template +::cuda::std::enable_if_t> cub::BlockRadixSort::Sort( + KeyT (&keys)[ItemsPerThread], + DecomposerT decomposer, + int begin_bit, + int end_bit +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* Performance is sensitive to the degree of data movement across the block. * The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Keys to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +The least-significant bit index (inclusive) needed for key comparison + + + +The most-significant bit index (exclusive) needed for key comparison (e.g., `(sizeof(float) + sizeof(long long int)) * 8`) + + + + + +Performs an ascending block-wide radix sort over a blocked arrangement of keys. + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The code snippet below illustrates a sort of 6 keys that are partitioned in a blocked arrangement across 2 threads where each thread owns 3 consecutive keys. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin keys :end-before: example-end keys + + +```cpp showLineNumbers={false} +template +::cuda::std::enable_if_t> cub::BlockRadixSort::Sort( + KeyT (&keys)[ItemsPerThread], + DecomposerT decomposer +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* Performance is sensitive to the degree of data movement across the block. * The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Keys to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + + + +Performs an ascending block-wide radix sort across a blocked arrangement of keys and values. + + +```cpp showLineNumbers={false} +void cub::BlockRadixSort::Sort( + KeyT (&keys)[ItemsPerThread], + ValueT (&values)[ItemsPerThread], + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +BlockRadixSort can only accommodate one associated tile of values. To "truck along" more than one tile of values, simply perform a key-value sort of the keys paired with a temporary value array that enumerates the key indices. The reordered indices can then be used as a gather-vector for exchanging other associated tile data through shared memory. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Keys to sort + + + +Values to sort + + + +**[optional]** The beginning (least-significant) bit index needed for key comparison + + + +**[optional]** The past-the-end (most-significant) bit index needed for key comparison + + +**Example** + +The code snippet below illustrates a sort of 512 integer keys and values that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive pairs. + +Suppose the set of input `thread_keys` across the block of threads is `{ [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }`. The corresponding output `thread_keys` in those threads will be `{ [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockRadixSort for a 1D block of 128 threads owning 4 integer keys and values each + using BlockRadixSort = cub::BlockRadixSort; + + // Allocate shared memory for BlockRadixSort + __shared__ typename BlockRadixSort::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_keys[4]; + int thread_values[4]; + ... + + // Collectively sort the keys and values among block threads + BlockRadixSort(temp_storage).Sort(thread_keys, thread_values); +``` + + + + +Performs an ascending block-wide radix sort over a blocked arrangement of keys and values. + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The code snippet below illustrates a sort of 2 keys and values that are partitioned in a blocked arrangement across 2 threads where each thread owns 1 pair. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin pairs-bits :end-before: example-end pairs-bits + + +```cpp showLineNumbers={false} +template +::cuda::std::enable_if_t> cub::BlockRadixSort::Sort( + KeyT (&keys)[ItemsPerThread], + ValueT (&values)[ItemsPerThread], + DecomposerT decomposer, + int begin_bit, + int end_bit +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* BlockRadixSort can only accommodate one associated tile of values. To "truck along" more than one tile of values, simply perform a key-value sort of the keys paired with a temporary value array that enumerates the key indices. The reordered indices can then be used as a gather-vector for exchanging other associated tile data through shared memory. * Performance is sensitive to the degree of data movement across the block. * The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Keys to sort + + + +Values to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +The least-significant bit index (inclusive) needed for key comparison + + + +The most-significant bit index (exclusive) needed for key comparison (e.g., `(sizeof(float) + sizeof(long long int)) * 8`) + + + + + +Performs an ascending block-wide radix sort over a blocked arrangement of keys and values. + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The code snippet below illustrates a sort of 6 keys and values that are partitioned in a blocked arrangement across 2 threads where each thread owns 3 consecutive pairs. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin pairs :end-before: example-end pairs + + +```cpp showLineNumbers={false} +template +::cuda::std::enable_if_t> cub::BlockRadixSort::Sort( + KeyT (&keys)[ItemsPerThread], + ValueT (&values)[ItemsPerThread], + DecomposerT decomposer +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* BlockRadixSort can only accommodate one associated tile of values. To "truck along" more than one tile of values, simply perform a key-value sort of the keys paired with a temporary value array that enumerates the key indices. The reordered indices can then be used as a gather-vector for exchanging other associated tile data through shared memory. * Performance is sensitive to the degree of data movement across the block. * The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Keys to sort + + + +Values to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + + + +### SortDescending inline + + + + +Performs a descending block-wide radix sort over a blocked arrangement of keys. + + +```cpp showLineNumbers={false} +void cub::BlockRadixSort::SortDescending( + KeyT (&keys)[ItemsPerThread], + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Keys to sort + + + +**[optional]** The beginning (least-significant) bit index needed for key comparison + + + +**[optional]** The past-the-end (most-significant) bit index needed for key comparison + + +**Example** + +The code snippet below illustrates a sort of 512 integer keys that are partitioned in a [blocked arrangement](../index.html#sec5sec3) across 128 threads where each thread owns 4 consecutive keys. + +`{ [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }`. The corresponding output `thread_keys` in those threads will be `{ [511,510,509,508], [11,10,9,8], [7,6,5,4], ..., [3,2,1,0] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockRadixSort for a 1D block of 128 threads owning 4 integer keys each + using BlockRadixSort = cub::BlockRadixSort; + + // Allocate shared memory for BlockRadixSort + __shared__ typename BlockRadixSort::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_keys[4]; + ... + + // Collectively sort the keys + BlockRadixSort(temp_storage).Sort(thread_keys); +``` + + + + +Performs a descending block-wide radix sort over a blocked arrangement of keys. + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The code snippet below illustrates a sort of 2 keys that are partitioned in a blocked arrangement across 2 threads where each thread owns 1 key. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin keys-descending-bits :end-before: example-end keys-descending-bits + + +```cpp showLineNumbers={false} +template +::cuda::std::enable_if_t> cub::BlockRadixSort::SortDescending( + KeyT (&keys)[ItemsPerThread], + DecomposerT decomposer, + int begin_bit, + int end_bit +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* Performance is sensitive to the degree of data movement across the block. * The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Keys to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +The least-significant bit index (inclusive) needed for key comparison + + + +The most-significant bit index (exclusive) needed for key comparison (e.g., `(sizeof(float) + sizeof(long long int)) * 8`) + + + + + +Performs a descending block-wide radix sort over a blocked arrangement of keys. + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The code snippet below illustrates a sort of 6 keys that are partitioned in a blocked arrangement across 2 threads where each thread owns 3 consecutive keys. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin keys-descending :end-before: example-end keys-descending + + +```cpp showLineNumbers={false} +template +::cuda::std::enable_if_t> cub::BlockRadixSort::SortDescending( + KeyT (&keys)[ItemsPerThread], + DecomposerT decomposer +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* Performance is sensitive to the degree of data movement across the block. * The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Keys to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + + + +Performs a descending block-wide radix sort across a blocked arrangement of keys and values. + + +```cpp showLineNumbers={false} +void cub::BlockRadixSort::SortDescending( + KeyT (&keys)[ItemsPerThread], + ValueT (&values)[ItemsPerThread], + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +BlockRadixSort can only accommodate one associated tile of values. To "truck along" more than one tile of values, simply perform a key-value sort of the keys paired with a temporary value array that enumerates the key indices. The reordered indices can then be used as a gather-vector for exchanging other associated tile data through shared memory. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Keys to sort + + + +Values to sort + + + +**[optional]** The beginning (least-significant) bit index needed for key comparison + + + +**[optional]** The past-the-end (most-significant) bit index needed for key comparison + + +**Example** + +The code snippet below illustrates a sort of 512 integer keys and values that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive pairs. + +`{ [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }`. The corresponding output `thread_keys` in those threads will be `{ [511,510,509,508], [11,10,9,8], [7,6,5,4], ..., [3,2,1,0] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockRadixSort for a 1D block of 128 threads owning 4 integer keys and values each + using BlockRadixSort = cub::BlockRadixSort; + + // Allocate shared memory for BlockRadixSort + __shared__ typename BlockRadixSort::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_keys[4]; + int thread_values[4]; + ... + + // Collectively sort the keys and values among block threads + BlockRadixSort(temp_storage).Sort(thread_keys, thread_values); +``` + + + + +Performs a descending block-wide radix sort over a blocked arrangement of keys and values. + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The code snippet below illustrates a sort of 2 pairs that are partitioned in a blocked arrangement across 2 threads where each thread owns 1 pair. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin pairs-descending-bits :end-before: example-end pairs-descending-bits + + +```cpp showLineNumbers={false} +template +::cuda::std::enable_if_t> cub::BlockRadixSort::SortDescending( + KeyT (&keys)[ItemsPerThread], + ValueT (&values)[ItemsPerThread], + DecomposerT decomposer, + int begin_bit, + int end_bit +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* BlockRadixSort can only accommodate one associated tile of values. To "truck along" more than one tile of values, simply perform a key-value sort of the keys paired with a temporary value array that enumerates the key indices. The reordered indices can then be used as a gather-vector for exchanging other associated tile data through shared memory. * Performance is sensitive to the degree of data movement across the block. * The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Keys to sort + + + +Values to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +The least-significant bit index (inclusive) needed for key comparison + + + +The most-significant bit index (exclusive) needed for key comparison (e.g., `(sizeof(float) + sizeof(long long int)) * 8`) + + + + + +Performs a descending block-wide radix sort over a blocked arrangement of keys and values. + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The code snippet below illustrates a sort of 6 keys and values that are partitioned in a blocked arrangement across 2 threads where each thread owns 3 consecutive pairs. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin pairs-descending :end-before: example-end pairs-descending + + +```cpp showLineNumbers={false} +template +::cuda::std::enable_if_t> cub::BlockRadixSort::SortDescending( + KeyT (&keys)[ItemsPerThread], + ValueT (&values)[ItemsPerThread], + DecomposerT decomposer +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* BlockRadixSort can only accommodate one associated tile of values. To "truck along" more than one tile of values, simply perform a key-value sort of the keys paired with a temporary value array that enumerates the key indices. The reordered indices can then be used as a gather-vector for exchanging other associated tile data through shared memory. * Performance is sensitive to the degree of data movement across the block. * The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Keys to sort + + + +Values to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + + + +--- + +## Sorting (blocked arrangement -> striped arrangement) + +### SortBlockedToStriped inline + + + + +Performs an ascending radix sort across a blocked arrangement of keys, leaving them in a striped arrangement. + + +```cpp showLineNumbers={false} +void cub::BlockRadixSort::SortBlockedToStriped( + KeyT (&keys)[ItemsPerThread], + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Keys to sort + + + +**[optional]** The beginning (least-significant) bit index needed for key comparison + + + +**[optional]** The past-the-end (most-significant) bit index needed for key comparison + + +**Example** + +The code snippet below illustrates a sort of 512 integer keys that are initially partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive keys. The final partitioning is striped. + +`{ [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }`. The corresponding output `thread_keys` in those threads will be `{ [0,128,256,384], [1,129,257,385], [2,130,258,386], ..., [127,255,383,511] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockRadixSort for a 1D block of 128 threads owning 4 integer keys each + using BlockRadixSort = cub::BlockRadixSort; + + // Allocate shared memory for BlockRadixSort + __shared__ typename BlockRadixSort::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_keys[4]; + ... + + // Collectively sort the keys + BlockRadixSort(temp_storage).SortBlockedToStriped(thread_keys); +``` + + + + +Performs an ascending block-wide radix sort over a blocked arrangement of keys, leaving them in a striped arrangement. + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The code snippet below illustrates a sort of 4 keys that are partitioned in a blocked arrangement across 2 threads where each thread owns 2 consecutive keys. The final partitioning is striped. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin keys-striped-bits :end-before: example-end keys-striped-bits + + +```cpp showLineNumbers={false} +template +::cuda::std::enable_if_t> cub::BlockRadixSort::SortBlockedToStriped( + KeyT (&keys)[ItemsPerThread], + DecomposerT decomposer, + int begin_bit, + int end_bit +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* Performance is sensitive to the degree of data movement across the block. * The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Keys to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +The least-significant bit index (inclusive) needed for key comparison + + + +The most-significant bit index (exclusive) needed for key comparison (e.g., `(sizeof(float) + sizeof(long long int)) * 8`) + + + + + +Performs an ascending block-wide radix sort over a blocked arrangement of keys, leaving them in a striped arrangement. + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The code snippet below illustrates a sort of 6 keys that are partitioned in a blocked arrangement across 2 threads where each thread owns 3 consecutive keys. The final partitioning is striped. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin keys-striped :end-before: example-end keys-striped + + +```cpp showLineNumbers={false} +template +::cuda::std::enable_if_t> cub::BlockRadixSort::SortBlockedToStriped( + KeyT (&keys)[ItemsPerThread], + DecomposerT decomposer +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* Performance is sensitive to the degree of data movement across the block. * The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Keys to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + + + +Performs an ascending radix sort across a blocked arrangement of keys and values, leaving them in a striped arrangement. + + +```cpp showLineNumbers={false} +void cub::BlockRadixSort::SortBlockedToStriped( + KeyT (&keys)[ItemsPerThread], + ValueT (&values)[ItemsPerThread], + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +BlockRadixSort can only accommodate one associated tile of values. To "truck along" more than one tile of values, simply perform a key-value sort of the keys paired with a temporary value array that enumerates the key indices. The reordered indices can then be used as a gather-vector for exchanging other associated tile data through shared memory. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Keys to sort + + + +Values to sort + + + +**[optional]** The beginning (least-significant) bit index needed for key comparison + + + +**[optional]** The past-the-end (most-significant) bit index needed for key comparison + + +**Example** + +The code snippet below illustrates a sort of 512 integer keys and values that are initially partitioned in a [blocked arrangement](../index.html#sec5sec3) across 128 threads where each thread owns 4 consecutive pairs. The final partitioning is striped. + +`{ [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }`. The corresponding output `thread_keys` in those threads will be `{ [0,128,256,384], [1,129,257,385], [2,130,258,386], ..., [127,255,383,511] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockRadixSort for a 1D block of 128 threads owning 4 integer keys and values each + using BlockRadixSort = cub::BlockRadixSort; + + // Allocate shared memory for BlockRadixSort + __shared__ typename BlockRadixSort::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_keys[4]; + int thread_values[4]; + ... + + // Collectively sort the keys and values among block threads + BlockRadixSort(temp_storage).SortBlockedToStriped(thread_keys, thread_values); +``` + + + + +Performs an ascending block-wide radix sort over a blocked arrangement of keys and values, leaving them in a striped arrangement. + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The code snippet below illustrates a sort of 4 pairs that are partitioned in a blocked arrangement across 2 threads where each thread owns 2 consecutive pairs. The final partitioning is striped. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin pairs-striped-bits :end-before: example-end pairs-striped-bits + + +```cpp showLineNumbers={false} +template +::cuda::std::enable_if_t> cub::BlockRadixSort::SortBlockedToStriped( + KeyT (&keys)[ItemsPerThread], + ValueT (&values)[ItemsPerThread], + DecomposerT decomposer, + int begin_bit, + int end_bit +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* Performance is sensitive to the degree of data movement across the block. * The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Keys to sort + + + +Values to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +The least-significant bit index (inclusive) needed for key comparison + + + +The most-significant bit index (exclusive) needed for key comparison (e.g., `(sizeof(float) + sizeof(long long int)) * 8`) + + + + + +Performs an ascending block-wide radix sort over a blocked arrangement of keys and values, leaving them in a striped arrangement. + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The code snippet below illustrates a sort of 6 pairs that are partitioned in a blocked arrangement across 2 threads where each thread owns 3 consecutive pairs. The final partitioning is striped. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin pairs-striped :end-before: example-end pairs-striped + + +```cpp showLineNumbers={false} +template +::cuda::std::enable_if_t> cub::BlockRadixSort::SortBlockedToStriped( + KeyT (&keys)[ItemsPerThread], + ValueT (&values)[ItemsPerThread], + DecomposerT decomposer +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* Performance is sensitive to the degree of data movement across the block. * The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Keys to sort + + + +Values to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + + + +### SortDescendingBlockedToStriped inline + + + + +Performs a descending radix sort across a blocked arrangement of keys, leaving them in a striped arrangement. + + +```cpp showLineNumbers={false} +void cub::BlockRadixSort::SortDescendingBlockedToStriped( + KeyT (&keys)[ItemsPerThread], + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Keys to sort + + + +**[optional]** The beginning (least-significant) bit index needed for key comparison + + + +**[optional]** The past-the-end (most-significant) bit index needed for key comparison + + +**Example** + +The code snippet below illustrates a sort of 512 integer keys that are initially partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive keys. The final partitioning is striped. + +`{ [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }`. The corresponding output `thread_keys` in those threads will be `{ [511,383,255,127], [386,258,130,2], [385,257,128,1], ..., [384,256,128,0] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockRadixSort for a 1D block of 128 threads owning 4 integer keys each + using BlockRadixSort = cub::BlockRadixSort; + + // Allocate shared memory for BlockRadixSort + __shared__ typename BlockRadixSort::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_keys[4]; + ... + + // Collectively sort the keys + BlockRadixSort(temp_storage).SortBlockedToStriped(thread_keys); +``` + + + + +Performs a descending block-wide radix sort over a blocked arrangement of keys, leaving them in a striped arrangement. + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The code snippet below illustrates a sort of 4 keys that are partitioned in a blocked arrangement across 2 threads where each thread owns 2 consecutive keys. The final partitioning is striped. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin keys-striped-descending-bits :end-before: example-end keys-striped-descending-bits + + +```cpp showLineNumbers={false} +template +::cuda::std::enable_if_t> cub::BlockRadixSort::SortDescendingBlockedToStriped( + KeyT (&keys)[ItemsPerThread], + DecomposerT decomposer, + int begin_bit, + int end_bit +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* Performance is sensitive to the degree of data movement across the block. * The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Keys to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +The least-significant bit index (inclusive) needed for key comparison + + + +The most-significant bit index (exclusive) needed for key comparison (e.g., `(sizeof(float) + sizeof(long long int)) * 8`) + + + + + +Performs a descending block-wide radix sort over a blocked arrangement of keys, leaving them in a striped arrangement. + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The code snippet below illustrates a sort of 6 keys that are partitioned in a blocked arrangement across 2 threads where each thread owns 3 consecutive keys. The final partitioning is striped. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin keys-striped-descending :end-before: example-end keys-striped-descending + + +```cpp showLineNumbers={false} +template +::cuda::std::enable_if_t> cub::BlockRadixSort::SortDescendingBlockedToStriped( + KeyT (&keys)[ItemsPerThread], + DecomposerT decomposer +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* Performance is sensitive to the degree of data movement across the block. * The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Keys to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + + + +Performs a descending radix sort across a blocked arrangement of keys and values, leaving them in a striped arrangement + + +```cpp showLineNumbers={false} +void cub::BlockRadixSort::SortDescendingBlockedToStriped( + KeyT (&keys)[ItemsPerThread], + ValueT (&values)[ItemsPerThread], + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +BlockRadixSort can only accommodate one associated tile of values. To "truck along" more than one tile of values, simply perform a key-value sort of the keys paired with a temporary value array that enumerates the key indices. The reordered indices can then be used as a gather-vector for exchanging other associated tile data through shared memory. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Keys to sort + + + +Values to sort + + + +**[optional]** The beginning (least-significant) bit index needed for key comparison + + + +**[optional]** The past-the-end (most-significant) bit index needed for key comparison + + +**Example** + +The code snippet below illustrates a sort of 512 integer keys and values that are initially partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive pairs. The final partitioning is striped. + +`{ [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }`. The corresponding output `thread_keys` in those threads will be `{ [511,383,255,127], [386,258,130,2], [385,257,128,1], ..., [384,256,128,0] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockRadixSort for a 1D block of 128 threads owning 4 integer keys and values each + using BlockRadixSort = cub::BlockRadixSort; + + // Allocate shared memory for BlockRadixSort + __shared__ typename BlockRadixSort::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_keys[4]; + int thread_values[4]; + ... + + // Collectively sort the keys and values among block threads + BlockRadixSort(temp_storage).SortBlockedToStriped(thread_keys, thread_values); +``` + + + + +Performs a descending block-wide radix sort over a blocked arrangement of keys and values, leaving them in a striped arrangement. + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The code snippet below illustrates a sort of 4 keys and values that are partitioned in a blocked arrangement across 2 threads where each thread owns 2 consecutive pairs. The final partitioning is striped. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin pairs-striped-descending-bits :end-before: example-end pairs-striped-descending-bits + + +```cpp showLineNumbers={false} +template +::cuda::std::enable_if_t> cub::BlockRadixSort::SortDescendingBlockedToStriped( + KeyT (&keys)[ItemsPerThread], + ValueT (&values)[ItemsPerThread], + DecomposerT decomposer, + int begin_bit, + int end_bit +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* Performance is sensitive to the degree of data movement across the block. * The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Keys to sort + + + +Values to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +The least-significant bit index (inclusive) needed for key comparison + + + +The most-significant bit index (exclusive) needed for key comparison (e.g., `(sizeof(float) + sizeof(long long int)) * 8`) + + + + + +Performs a descending block-wide radix sort over a blocked arrangement of keys and values, leaving them in a striped arrangement. + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The code snippet below illustrates a sort of 6 keys and values that are partitioned in a blocked arrangement across 2 threads where each thread owns 3 consecutive pairs. The final partitioning is striped. + +.. literalinclude:: ../../../cub/test/catch2_test_block_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin pairs-striped-descending :end-before: example-end pairs-striped-descending + + +```cpp showLineNumbers={false} +template +::cuda::std::enable_if_t> cub::BlockRadixSort::SortDescendingBlockedToStriped( + KeyT (&keys)[ItemsPerThread], + ValueT (&values)[ItemsPerThread], + DecomposerT decomposer +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* Performance is sensitive to the degree of data movement across the block. * The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Keys to sort + + + +Values to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + + + +--- + +## Utility methods + +### PrivateStorage inline + +Internal storage allocator. + + +```cpp showLineNumbers={false} +_TempStorage & cub::BlockRadixSort::PrivateStorage() +``` + + +### RankKeys inline + + + + +Rank keys (specialized for ascending sort). + + +```cpp showLineNumbers={false} +template +void cub::BlockRadixSort::RankKeys( + bit_ordered_type (&unsigned_keys)[ItemsPerThread], + int (&ranks)[ItemsPerThread], + DigitExtractorT digit_extractor, + ::cuda::std::false_type +) +``` + + + + + +Rank keys (specialized for descending sort). + + +```cpp showLineNumbers={false} +template +void cub::BlockRadixSort::RankKeys( + bit_ordered_type (&unsigned_keys)[ItemsPerThread], + int (&ranks)[ItemsPerThread], + DigitExtractorT digit_extractor, + ::cuda::std::true_type +) +``` + + + + + +### ExchangeValues inline + + + + +ExchangeValues (specialized for key-value sort, to-blocked arrangement). + + +```cpp showLineNumbers={false} +void cub::BlockRadixSort::ExchangeValues( + ValueT (&values)[ItemsPerThread], + int (&ranks)[ItemsPerThread], + ::cuda::std::false_type, + ::cuda::std::true_type +) +``` + + + + + +ExchangeValues (specialized for key-value sort, to-striped arrangement). + + +```cpp showLineNumbers={false} +void cub::BlockRadixSort::ExchangeValues( + ValueT (&values)[ItemsPerThread], + int (&ranks)[ItemsPerThread], + ::cuda::std::false_type, + ::cuda::std::false_type +) +``` + + + + + +ExchangeValues (specialized for keys-only sort). + + +```cpp showLineNumbers={false} +template +void cub::BlockRadixSort::ExchangeValues( + ValueT (&)[ItemsPerThread], + int (&)[ItemsPerThread], + ::cuda::std::true_type, + ::cuda::std::bool_constant +) +``` + + + + + +### SortBlocked inline + +Sort blocked arrangement. + + +```cpp showLineNumbers={false} +template +void cub::BlockRadixSort::SortBlocked( + KeyT (&keys)[ItemsPerThread], + ValueT (&values)[ItemsPerThread], + int begin_bit, + int end_bit, + ::cuda::std::bool_constant is_descending, + ::cuda::std::bool_constant is_keys_only, + DecomposerT decomposer = {} +) +``` + + +**Parameters** + + +Keys to sort + + + +Values to sort + + + +The beginning (least-significant) bit index needed for key comparison + + + +The past-the-end (most-significant) bit index needed for key comparison + + + +Tag whether is a descending-order sort + + + +Tag whether is keys-only sort + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `traits` | `detail::radix::traits_t< KeyT >` | | +| `bit_ordered_type` | `typename traits::bit_ordered_type` | | +| `bit_ordered_conversion` | `typename traits::bit_ordered_conversion_policy` | | +| `AscendingBlockRadixRank` | `BlockRadixRank< BlockDimX, RadixBits, false, MemoizeOuterScan, InnerScanAlgorithm, SMemConfig, BlockDimY, BlockDimZ >` | Ascending [BlockRadixRank](/library/api/cub::_block_radix_rank) utility type. | +| `DescendingBlockRadixRank` | `BlockRadixRank< BlockDimX, RadixBits, true, MemoizeOuterScan, InnerScanAlgorithm, SMemConfig, BlockDimY, BlockDimZ >` | Descending [BlockRadixRank](/library/api/cub::_block_radix_rank) utility type. | +| `fundamental_digit_extractor_t` | `BFEDigitExtractor< KeyT >` | Digit extractor type. | +| `BlockExchangeKeys` | `BlockExchange< KeyT, BlockDimX, ItemsPerThread, false, BlockDimY, BlockDimZ >` | [BlockExchange](/library/api/cub::_block_exchange) utility type for keys. | +| `BlockExchangeValues` | `BlockExchange< ValueT, BlockDimX, ItemsPerThread, false, BlockDimY, BlockDimZ >` | [BlockExchange](/library/api/cub::_block_exchange) utility type for values. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | | +| `KEYS_ONLY` static constexpr | `bool` | | +| `temp_storage` | `_TempStorage &` | Shared storage reference. | +| `linear_tid` | `unsigned int` | Linear thread-id. | + +--- + +## Inner classes + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockRadixSort::TempStorage +``` + + +The operations exposed by `BlockRadixSort` require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/cudapages/cub/cub/cub/BlockRakingLayout.mdx b/fern/cudapages/cub/cub/cub/BlockRakingLayout.mdx new file mode 100644 index 0000000..8417726 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BlockRakingLayout.mdx @@ -0,0 +1,96 @@ +--- +title: cub::BlockRakingLayout +description: "" +--- + +BlockRakingLayout provides a conflict-free shared memory layout abstraction for 1D raking across thread block data. + + + + + +The data type to be exchanged. + + + +The thread block size in threads. + + + + + +--- + +## Static methods + +### PlacementPtr inline static + +Returns the location for the calling thread to place data into the grid. + + +```cpp showLineNumbers={false} +static T * cub::BlockRakingLayout::PlacementPtr( + TempStorage &temp_storage, + unsigned int linear_tid +) +``` + + +### RakingPtr inline static + +Returns the location for the calling thread to begin sequential raking. + + +```cpp showLineNumbers={false} +static T * cub::BlockRakingLayout::RakingPtr( + TempStorage &temp_storage, + unsigned int linear_tid +) +``` + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `SHARED_ELEMENTS` static constexpr | `int` | The total number of elements that need to be cooperatively reduced. | +| `MAX_RAKING_THREADS` static constexpr | `int` | Maximum number of warp-synchronous raking threads. | +| `SEGMENT_LENGTH` static constexpr | `int` | Number of raking elements per warp-synchronous raking thread (rounded up). | +| `RAKING_THREADS` static constexpr | `int` | Never use a raking thread that will have no valid data (e.g., when BlockThreads is 62 and SEGMENT_LENGTH is 2, we should only use 31 raking threads). | +| `HAS_CONFLICTS` static constexpr | `bool` | Whether we will have bank conflicts (technically we should find out if the GCD is > 1). | +| `CONFLICT_DEGREE` static constexpr | `int` | Degree of bank conflicts (e.g., 4-way). | +| `USE_SEGMENT_PADDING` static constexpr | `bool` | Pad each segment length with one element if segment length is not relatively prime to warp size and can't be optimized as a vector load. | +| `GRID_ELEMENTS` static constexpr | `int` | Total number of elements in the raking grid. | +| `UNGUARDED` static constexpr | `int` | Whether or not we need bounds checking during raking (the number of reduction elements is not a multiple of the number of raking threads). | + +--- + +## Inner classes + +### _TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockRakingLayout::_TempStorage +``` + + +Shared memory storage type. + +| Name | Type | Description | +|---|---|---| +| `buff` | `T` | | + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockRakingLayout::TempStorage +``` + + +Alias wrapper allowing storage to be unioned. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/cudapages/cub/cub/cub/BlockReduce.mdx b/fern/cudapages/cub/cub/cub/BlockReduce.mdx new file mode 100644 index 0000000..2dbf17d --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BlockReduce.mdx @@ -0,0 +1,565 @@ +--- +title: cub::BlockReduce +description: "" +--- + +The BlockReduce class provides collective methods for computing a parallel reduction of items partitioned across a CUDA thread block. + +## Performance considerations + +- Performance is sensitive to the degree of data movement across the block. +- Very efficient (only one synchronization barrier). +- Incurs zero bank conflicts for most types +- Computation is slightly more efficient (i.e., having lower instruction overhead) for: + - Summation (vs. generic reduction) + - `BLOCK_THREADS` is a multiple of the architecture's warp size + - Every thread has a valid input (i.e., full vs. partial-tiles) +- See cub::BlockReduceAlgorithm for performance details regarding algorithmic alternatives + +## Example + +The code snippet below illustrates a sum reduction of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Compute the block-wide sum for thread0 + int aggregate = BlockReduce(temp_storage).Sum(thread_data); +} +``` + + + + + +Data type being reduced + + + +The thread block length in threads along the X dimension + + + +**[optional]** [cub::BlockReduceAlgorithm](/library/api/cub::BlockReduceAlgorithm) enumerator specifying the underlying algorithm to use (default: [cub::BLOCK_REDUCE_WARP_REDUCTIONS](/library/api/cub::BLOCK_REDUCE_WARP_REDUCTIONS)) + + + +**[optional]** The thread block length in threads along the Y dimension (default: 1) + + + +**[optional]** The thread block length in threads along the Z dimension (default: 1) + + + + + +--- + +## Collective constructors + +### BlockReduce inline + + + + +Collective constructor using a private static allocation of shared memory as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockReduce::BlockReduce() +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + +Collective constructor using the specified memory allocation as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockReduce::BlockReduce( + TempStorage &temp_storage +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Parameters** + + +Reference to memory allocation having layout type [TempStorage](/library/api/cub::BlockReduce::TempStorage) + + + + + +--- + +## Generic reductions + +### Reduce inline + + + + +Computes a block-wide reduction for thread0 using the specified binary reduction functor. Each thread contributes one input element. + + +```cpp showLineNumbers={false} +template +T cub::BlockReduce::Reduce( + T input, + ReductionOp reduction_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Assumes threads are in row-major order. +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + + +The return value is undefined in threads other than thread0. + + +**Template parameters** + + +**[inferred]** Binary reduction functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input + + + +Binary reduction functor + + +**Example** + +The code snippet below illustrates a max reduction of 128 integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Each thread obtains an input item + int thread_data; + ... + + // Compute the block-wide max for thread0 + int aggregate = BlockReduce(temp_storage).Reduce(thread_data, cuda::maximum<>{}); +} +``` + + + + +Computes a block-wide reduction for thread0 using the specified binary reduction functor. Each thread contributes an array of consecutive input elements. + + +```cpp showLineNumbers={false} +template +T cub::BlockReduce::Reduce( + T (&inputs)[ITEMS_PER_THREAD], + ReductionOp reduction_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Performance is sensitive to the degree of data movement across the block. +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + + +The return value is undefined in threads other than thread0. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary reduction functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input segment + + + +Binary reduction functor + + +**Example** + +The code snippet below illustrates a max reduction of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Compute the block-wide max for thread0 + int aggregate = BlockReduce(temp_storage).Reduce(thread_data, cuda::maximum<>{}); +} +``` + + + + +Computes a block-wide reduction for thread0 using the specified binary reduction functor. The first `num_valid` threads each contribute one input element. + + +```cpp showLineNumbers={false} +template +T cub::BlockReduce::Reduce( + T input, + ReductionOp reduction_op, + int num_valid +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Assumes threads are in row-major order. +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + + +The return value is undefined in threads other than thread0. + + +**Template parameters** + + +**[inferred]** Binary reduction functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input + + + +Binary reduction functor + + + +Number of threads containing valid elements (may be less than BLOCK_THREADS) + + +**Example** + +The code snippet below illustrates a max reduction of a partially-full tile of integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int num_valid, ...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Each thread obtains an input item + int thread_data; + if (threadIdx.x < num_valid) thread_data = ... + + // Compute the block-wide max for thread0 + int aggregate = BlockReduce(temp_storage).Reduce(thread_data, cuda::maximum<>{}, num_valid); +} +``` + + + + +--- + +## Summation reductions + +### Sum inline + + + + +Computes a block-wide reduction for thread0 using addition (+) as the reduction operator. Each thread contributes one input element. + + +```cpp showLineNumbers={false} +T cub::BlockReduce::Sum( + T input +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Assumes threads are in row-major order. +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + + +The return value is undefined in threads other than thread0. + + +**Parameters** + + +Calling thread's input + + +**Example** + +The code snippet below illustrates a sum reduction of 128 integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Each thread obtains an input item + int thread_data; + ... + + // Compute the block-wide sum for thread0 + int aggregate = BlockReduce(temp_storage).Sum(thread_data); +} +``` + + + + +Computes a block-wide reduction for thread0 using addition (+) as the reduction operator. Each thread contributes an array of consecutive input elements. + + +```cpp showLineNumbers={false} +template +T cub::BlockReduce::Sum( + T (&inputs)[ITEMS_PER_THREAD] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Performance is sensitive to the degree of data movement across the block. +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + + +The return value is undefined in threads other than thread0. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + +**Parameters** + + +Calling thread's input segment + + +**Example** + +The code snippet below illustrates a sum reduction of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Compute the block-wide sum for thread0 + int aggregate = BlockReduce(temp_storage).Sum(thread_data); +} +``` + + + + +Computes a block-wide reduction for thread0 using addition (+) as the reduction operator. The first `num_valid` threads each contribute one input element. + + +```cpp showLineNumbers={false} +T cub::BlockReduce::Sum( + T input, + int num_valid +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Assumes threads are in row-major order. +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + + +The return value is undefined in threads other than thread0. + + +**Parameters** + + +Calling thread's input + + + +Number of threads containing valid elements (may be less than BLOCK_THREADS) + + +**Example** + +The code snippet below illustrates a sum reduction of a partially-full tile of integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int num_valid, ...) +{ + // Specialize BlockReduce for a 1D block of 128 threads of type int + using BlockReduce = cub::BlockReduce; + + // Allocate shared memory for BlockReduce + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Each thread obtains an input item (up to num_items) + int thread_data; + if (threadIdx.x < num_valid) + thread_data = ... + + // Compute the block-wide sum for thread0 + int aggregate = BlockReduce(temp_storage).Sum(thread_data, num_valid); +} +``` + + + + +--- + +## Utility methods + +### PrivateStorage inline + +Internal storage allocator. + + +```cpp showLineNumbers={false} +_TempStorage & cub::BlockReduce::PrivateStorage() +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `WarpReductions` | `detail::BlockReduceWarpReductions< T, BlockDimX, BlockDimY, BlockDimZ >` | | +| `WarpReductionsNondeterministic` | `detail::BlockReduceWarpReductions< T, BlockDimX, BlockDimY, BlockDimZ, false >` | | +| `RakingCommutativeOnly` | `detail::BlockReduceRakingCommutativeOnly< T, BlockDimX, BlockDimY, BlockDimZ >` | | +| `Raking` | `detail::BlockReduceRaking< T, BlockDimX, BlockDimY, BlockDimZ >` | | +| `InternalBlockReduce` | `::cuda::std::_If< Algorithm==BLOCK_REDUCE_WARP_REDUCTIONS, WarpReductions, ::cuda::std::_If< Algorithm==BLOCK_REDUCE_WARP_REDUCTIONS_NONDETERMINISTIC, WarpReductionsNondeterministic, ::cuda::std::_If< Algorithm==BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY, RakingCommutativeOnly, Raking > > >` | Internal specialization type. | +| `_TempStorage` | `typename InternalBlockReduce::TempStorage` | Shared memory storage layout type for `BlockReduce`. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | The thread block size in threads. | +| `temp_storage` | `_TempStorage &` | Shared storage reference. | +| `linear_tid` | `unsigned int` | Linear thread-id. | + +--- + +## Inner classes + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockReduce::TempStorage +``` + + +The operations exposed by `BlockReduce` require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/cudapages/cub/cub/cub/BlockRunLengthDecode.mdx b/fern/cudapages/cub/cub/cub/BlockRunLengthDecode.mdx new file mode 100644 index 0000000..e16de79 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BlockRunLengthDecode.mdx @@ -0,0 +1,296 @@ +--- +title: cub::BlockRunLengthDecode +description: "" +--- + +The BlockRunLengthDecode class supports decoding a run-length encoded array of items. That is, given the two arrays `run_value[N]` and `run_lengths[N]`, `run_value[i]` is repeated `run_lengths[i]` many times in the output array. Due to the nature of the run-length decoding algorithm ("decompression"), the output size of the run-length decoded array is runtime-dependent and potentially without any upper bound. To address this, BlockRunLengthDecode allows retrieving a "window" from the run-length decoded array. The window's offset can be specified and BLOCK_THREADS * DecodedItemsPerThread (i.e., referred to as window_size) decoded items from the specified window will be returned. + +.. note:: + + Trailing runs of length 0 are supported (i.e., they may only appear at the end of the run_lengths array). A run of length zero may not be followed by a run length that is not zero. + +Suppose the set of input `run_values` across the block of threads is `{ [0, 1], [2, 3], [4, 5], [6, 7], ..., [254, 255] }` and `run_lengths` is `{ [1, 2], [3, 4], [5, 1], [2, 3], ..., [5, 1] }`. The corresponding output `decoded_items` in those threads will be `{ [0, 1, 1, 2], [2, 2, 3, 3], [3, 3, 4, 4], [4, 4, 4, 5], ..., [169, 169, 170, 171] }` and `relative_offsets` will be `{ [0, 0, 1, 0], [1, 2, 0, 1], [2, 3, 0, 1], [2, 3, 4, 0], ..., [3, 4, 0, 0] }` during the first iteration of the while loop. + + + + + +The data type of the items being run-length decoded + + + +The thread block length in threads along the X dimension + + + +The number of consecutive runs that each thread contributes + + + +The maximum number of decoded items that each thread holds + + + +Type used to index into the block's decoded items (large enough to hold the sum over all the runs' lengths) + + + +The thread block length in threads along the Y dimension + + + +The thread block length in threads along the Z dimension + + + + + +--- + +## Constructors + +### BlockRunLengthDecode inline + + + + +Constructor specialised for user-provided temporary storage, initializing using the runs' lengths. + + +```cpp showLineNumbers={false} +template +cub::BlockRunLengthDecode::BlockRunLengthDecode( + TempStorage &temp_storage, + ItemT (&run_values)[RunsPerThread], + RunLengthT (&run_lengths)[RunsPerThread], + TotalDecodedSizeT &total_decoded_size +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + +Constructor specialised for user-provided temporary storage, initializing using the runs' offsets. + + +```cpp showLineNumbers={false} +template +cub::BlockRunLengthDecode::BlockRunLengthDecode( + TempStorage &temp_storage, + ItemT (&run_values)[RunsPerThread], + UserRunOffsetT (&run_offsets)[RunsPerThread] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + +Constructor specialised for static temporary storage, initializing using the runs' lengths. + + +```cpp showLineNumbers={false} +template +cub::BlockRunLengthDecode::BlockRunLengthDecode( + ItemT (&run_values)[RunsPerThread], + RunLengthT (&run_lengths)[RunsPerThread], + TotalDecodedSizeT &total_decoded_size +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + +Constructor specialised for static temporary storage, initializing using the runs' offsets. + + +```cpp showLineNumbers={false} +template +cub::BlockRunLengthDecode::BlockRunLengthDecode( + ItemT (&run_values)[RunsPerThread], + UserRunOffsetT (&run_offsets)[RunsPerThread] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + +--- + +## Methods + +### PrivateStorage inline + +Internal storage allocator (used when the user does not provide pre-allocated shared memory). + + +```cpp showLineNumbers={false} +_TempStorage & cub::BlockRunLengthDecode::PrivateStorage() +``` + + +### StaticUpperBound inline + +Returns the offset of the first value within `input` which compares greater than `val`. + +This version takes `MAX_NUM_ITEMS`, an upper bound of the array size, which will be used to determine the number of binary search iterations at compile time. + + +```cpp showLineNumbers={false} +template +OffsetT cub::BlockRunLengthDecode::StaticUpperBound( + InputIteratorT input, + OffsetT num_items, + T val +) +``` + + +**Parameters** + + +Input sequence + + + +Input sequence length + + + +Search key + + +### InitWithRunOffsets inline + + +```cpp showLineNumbers={false} +template +void cub::BlockRunLengthDecode::InitWithRunOffsets( + ItemT (&run_values)[RunsPerThread], + RunOffsetT (&run_offsets)[RunsPerThread] +) +``` + + +### InitWithRunLengths inline + + +```cpp showLineNumbers={false} +template +void cub::BlockRunLengthDecode::InitWithRunLengths( + ItemT (&run_values)[RunsPerThread], + RunLengthT (&run_lengths)[RunsPerThread], + TotalDecodedSizeT &total_decoded_size +) +``` + + +### RunLengthDecode inline + + + + +Run-length decodes the runs previously passed via a call to Init(...) and returns the run-length decoded items in a blocked arrangement to `decoded_items`. + + +```cpp showLineNumbers={false} +template +void cub::BlockRunLengthDecode::RunLengthDecode( + ItemT (&decoded_items)[DecodedItemsPerThread], + RelativeOffsetT (&item_offsets)[DecodedItemsPerThread], + DecodedOffsetT from_decoded_offset = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Parameters** + + +The run-length decoded items to be returned in a blocked arrangement + + + +The run-length decoded items' relative offset within the run they belong to + + + +If invoked with from_decoded_offset that is larger than total_decoded_size results in undefined behavior. + + + + + +Run-length decodes the runs previously passed via a call to Init(...) and returns the run-length decoded items in a blocked arrangement to `decoded_items`. + + +```cpp showLineNumbers={false} +void cub::BlockRunLengthDecode::RunLengthDecode( + ItemT (&decoded_items)[DecodedItemsPerThread], + DecodedOffsetT from_decoded_offset = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Parameters** + + +The run-length decoded items to be returned in a blocked arrangement + + + +If invoked with from_decoded_offset that is larger than total_decoded_size results in undefined behavior. + + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `RunOffsetScanT` | `BlockScan< DecodedOffsetT, BlockDimX, BLOCK_SCAN_RAKING_MEMOIZE, BlockDimY, BlockDimZ >` | [BlockScan](/library/api/cub::_block_scan) used to determine the beginning of each run (i.e., prefix sum over the runs' length). | +| `RunOffsetT` | `uint32_t` | Type used to index into the block's runs. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | The thread block size in threads. | +| `BLOCK_RUNS` static constexpr | `int` | The number of runs that the block decodes (out-of-bounds items may be padded with run lengths of '0'). | +| `temp_storage` | `_TempStorage &` | Shared storage reference. | +| `linear_tid` | `uint32_t` | Linear thread-id. | + +--- + +## Inner classes + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockRunLengthDecode::TempStorage +``` + + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/cudapages/cub/cub/cub/BlockScan.mdx b/fern/cudapages/cub/cub/cub/BlockScan.mdx new file mode 100644 index 0000000..e36ca1d --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BlockScan.mdx @@ -0,0 +1,1797 @@ +--- +title: cub::BlockScan +description: "" +--- + +The BlockScan class provides collective methods for computing a parallel prefix sum/scan of items partitioned across a CUDA thread block. + +## Performance considerations + +- Performance is sensitive to the degree of data movement across the block. +- Uses special instructions when applicable (e.g., warp `SHFL`) +- Uses synchronization-free communication between warp lanes when applicable +- Invokes a minimal number of minimal block-wide synchronization barriers (only + one or two depending on algorithm selection) +- Incurs zero bank conflicts for most types +- Computation is slightly more efficient (i.e., having lower instruction overhead) for: + + - Prefix sum variants (vs. generic scan) + - `BLOCK_THREADS` is a multiple of the architecture's warp size + +- See cub::BlockScanAlgorithm for performance details regarding algorithmic alternatives + + + + + +Data type being scanned + + + +The thread block length in threads along the X dimension + + + +**[optional]** [cub::BlockScanAlgorithm](/library/api/cub::BlockScanAlgorithm) enumerator specifying the underlying algorithm to use (default: [cub::BLOCK_SCAN_RAKING](/library/api/cub::BLOCK_SCAN_RAKING)) + + + +**[optional]** The thread block length in threads along the Y dimension (default: 1) + + + +**[optional]** The thread block length in threads along the Z dimension (default: 1) + + + + + +--- + +## Collective constructors + +### BlockScan inline + + + + +Collective constructor using a private static allocation of shared memory as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockScan::BlockScan() +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + +Collective constructor using the specified memory allocation as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockScan::BlockScan( + TempStorage &temp_storage +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Parameters** + + +Reference to memory allocation having layout type [TempStorage](/library/api/cub::BlockScan::TempStorage) + + + + + +--- + +## Exclusive prefix sum operations + +### ExclusiveSum inline + + + + +Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. The value of 0 is applied as the initial value, and is assigned to `output` in *thread*0. + + +```cpp showLineNumbers={false} +void cub::BlockScan::ExclusiveSum( + T input, + T &output +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Uses the identity element (zero) as the initial value. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + + + +Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. The value of 0 is applied as the initial value, and is assigned to `output` in *thread*0. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +void cub::BlockScan::ExclusiveSum( + T input, + T &output, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Uses the identity element (zero) as the initial value. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +Block-wide aggregate reduction of input items + + + + + +Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. Instead of using 0 as the block-wide prefix, the call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveSum( + T input, + T &output, + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Uses the identity element (zero) as the initial value. +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +Embed:rst:leading-asterisk +//! *warp*\ :sub:`0` only call-back functor for specifying a block-wide prefix to be applied to +//! the logical input sequence. +//! + + + + + +Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. The value of 0 is applied as the initial value, and is assigned to `output[0]` in *thread*0. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveSum( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Uses the identity element (zero) as the initial value. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + + + +Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. The value of 0 is applied as the initial value, and is assigned to `output[0]` in *thread*0. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveSum( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Uses the identity element (zero) as the initial value. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Block-wide aggregate reduction of input items + + + + + +Computes an exclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. Instead of using 0 as the block-wide prefix, the call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveSum( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Uses the identity element (zero) as the initial value. +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Embed:rst:leading-asterisk +//! *warp*\ :sub:`0` only call-back functor for specifying a block-wide prefix to be applied to +//! the logical input sequence. +//! + + + + + +--- + +## Exclusive prefix scan operations + +### ExclusiveScan inline + + + + +Computes an exclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes one input element. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveScan( + T input, + T &output, + T initial_value, + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative scan operators. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +Embed:rst:leading-asterisk +//! Initial value to seed the exclusive scan (and is assigned to `output[0]` in *thread*\ :sub:`0`) +//! + + + +Binary scan functor + + + + + +Computes an exclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes one input element. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveScan( + T input, + T &output, + T initial_value, + ScanOp scan_op, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative scan operators. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Embed:rst:leading-asterisk +//! Initial value to seed the exclusive scan (and is assigned to ``output[0]`` in *thread*\ :sub:`0`). It is not +//! taken into account for ``block_aggregate``. +//! +//! + + + +Binary scan functor + + + +Block-wide aggregate reduction of input items + + + + + +Computes an exclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes one input element. The call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveScan( + T input, + T &output, + ScanOp scan_op, + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Supports non-commutative scan operators. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +Binary scan functor + + + +Embed:rst:leading-asterisk +//! *warp*\ :sub:`0` only call-back functor for specifying a block-wide prefix to be applied to +//! the logical input sequence. +//! + + +**Example** + +The code snippet below illustrates a single thread block that progressively computes an exclusive prefix max scan over multiple "tiles" of input using a prefix functor to maintain a running total between block-wide scans. Each tile consists of 128 integer items that are partitioned across 128 threads. + +The corresponding output for the first segment will be `INT_MIN, 0, 0, 2, ..., 124, 126`. The output for the second segment will be `126, 128, 128, 130, ..., 252, 254`. + +```cpp showLineNumbers={false} +#include // or equivalently + +// A stateful callback functor that maintains a running prefix to be applied +// during consecutive scan operations. +struct BlockPrefixCallbackOp +{ + // Running prefix + int running_total; + + // Constructor + __device__ BlockPrefixCallbackOp(int running_total) : running_total(running_total) {} + + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. + __device__ int operator()(int block_aggregate) + { + int old_prefix = running_total; + running_total = (block_aggregate > old_prefix) ? block_aggregate : old_prefix; + return old_prefix; + } +}; + +__global__ void ExampleKernel(int *d_data, int num_items, ...) +{ + // Specialize BlockScan for a 1D block of 128 threads + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Initialize running total + BlockPrefixCallbackOp prefix_op(INT_MIN); + + // Have the block iterate over segments of items + for (int block_offset = 0; block_offset < num_items; block_offset += 128) + { + // Load a segment of consecutive items that are blocked across threads + int thread_data = d_data[block_offset + threadIdx.x]; + + // Collectively compute the block-wide exclusive prefix max scan + BlockScan(temp_storage).ExclusiveScan( + thread_data, thread_data, INT_MIN, cuda::maximum<>{}, prefix_op); + __syncthreads(); + + // Store scanned items to output segment + d_data[block_offset + threadIdx.x] = thread_data; + } +} +``` + + + + +Computes an exclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + T initial_value, + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Embed:rst:leading-asterisk +//! Initial value to seed the exclusive scan (and is assigned to `output[0]` in *thread*\ :sub:`0`) +//! + + + +Binary scan functor + + + + + +Computes an exclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + T initial_value, + ScanOp scan_op, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Embed:rst:leading-asterisk +//! Initial value to seed the exclusive scan (and is assigned to `output[0]` in *thread*\ :sub:`0`). It is not taken +//! into account for ``block_aggregate``. +//! + + + +Binary scan functor + + + +Block-wide aggregate reduction of input items + + +**Example** + +The code snippet below illustrates an exclusive prefix max scan of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +`{ [0,-1,2,-3], [4,-5,6,-7], ..., [508,-509,510,-511] }`. The corresponding output `thread_data` in those threads will be `{ [INT_MIN,0,0,2], [2,4,4,6], ..., [506,508,508,510] }`. Furthermore the value `510` will be stored in `block_aggregate` for all threads. + +.. note:: + +`initial_value` is not applied to the block-wide aggregate. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Collectively compute the block-wide exclusive prefix max scan + int block_aggregate; + BlockScan(temp_storage).ExclusiveScan( + thread_data, thread_data, INT_MIN, cuda::maximum<>{}, block_aggregate); +``` + + + + +Computes an exclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. The call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::ExclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + ScanOp scan_op, + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Binary scan functor + + + +Embed:rst:leading-asterisk +//! *warp*\ :sub:`0` only call-back functor for specifying a block-wide prefix to be applied to +//! the logical input sequence. +//! + + + + + +--- + +## Inclusive prefix sum operations + +### InclusiveSum inline + + + + +Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. + + +```cpp showLineNumbers={false} +void cub::BlockScan::InclusiveSum( + T input, + T &output +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + + + +Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +void cub::BlockScan::InclusiveSum( + T input, + T &output, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +Block-wide aggregate reduction of input items + + + + + +Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. Instead of using 0 as the block-wide prefix, the call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveSum( + T input, + T &output, + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +Embed:rst:leading-asterisk +//! *warp*\ :sub:`0` only call-back functor for specifying a block-wide prefix to be applied +//! to the logical input sequence. +//! + + +**Example** + +The code snippet below illustrates a single thread block that progressively computes an inclusive prefix sum over multiple "tiles" of input using a prefix functor to maintain a running total between block-wide scans. Each tile consists of 128 integer items that are partitioned across 128 threads. + +The corresponding output for the first segment will be `1, 2, ..., 128`. The output for the second segment will be `129, 130, ..., 256`. + +```cpp showLineNumbers={false} +#include // or equivalently + +// A stateful callback functor that maintains a running prefix to be applied +// during consecutive scan operations. +struct BlockPrefixCallbackOp +{ + // Running prefix + int running_total; + + // Constructor + __device__ BlockPrefixCallbackOp(int running_total) : running_total(running_total) {} + + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. + __device__ int operator()(int block_aggregate) + { + int old_prefix = running_total; + running_total += block_aggregate; + return old_prefix; + } +}; + +__global__ void ExampleKernel(int *d_data, int num_items, ...) +{ + // Specialize BlockScan for a 1D block of 128 threads + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Initialize running total + BlockPrefixCallbackOp prefix_op(0); + + // Have the block iterate over segments of items + for (int block_offset = 0; block_offset < num_items; block_offset += 128) + { + // Load a segment of consecutive items that are blocked across threads + int thread_data = d_data[block_offset + threadIdx.x]; + + // Collectively compute the block-wide inclusive prefix sum + BlockScan(temp_storage).InclusiveSum( + thread_data, thread_data, prefix_op); + __syncthreads(); + + // Store scanned items to output segment + d_data[block_offset + threadIdx.x] = thread_data; + } +``` + + + + +Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveSum( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + + + +Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveSum( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Block-wide aggregate reduction of input items + + + + + +Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes an array of consecutive input elements. Instead of using 0 as the block-wide prefix, the call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveSum( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Embed:rst:leading-asterisk +//! *warp*\ :sub:`0` only call-back functor for specifying a block-wide prefix to be applied to the +//! logical input sequence. +//! + + + + + +--- + +## Inclusive prefix scan operations + +### InclusiveScan inline + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes one input element. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T input, + T &output, + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative scan operators. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +Binary scan functor + + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes one input element. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T input, + T &output, + ScanOp scan_op, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative scan operators. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +Binary scan functor + + + +Block-wide aggregate reduction of input items + + +**Example** + +The code snippet below illustrates an inclusive prefix max scan of 128 integer items that are partitioned across 128 threads. + +`0, -1, 2, -3, ..., 126, -127`. The corresponding output `thread_data` in those threads will be `0, 0, 2, 2, ..., 126, 126`. Furthermore the value `126` will be stored in `block_aggregate` for all threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain input item for each thread + int thread_data; + ... + + // Collectively compute the block-wide inclusive prefix max scan + int block_aggregate; + BlockScan(temp_storage).InclusiveScan(thread_data, thread_data, cuda::maximum<>{}, block_aggregate); +``` + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes one input element. The call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T input, + T &output, + ScanOp scan_op, + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor's input parameter The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Supports non-commutative scan operators. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item (may be aliased to `input`) + + + +Binary scan functor + + + +Embed:rst:leading-asterisk +//! *warp*\ :sub:`0` only call-back functor for specifying a block-wide prefix to be applied to +//! the logical input sequence. +//! + + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Binary scan functor + + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + T initial_value, + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Initial value to seed the inclusive scan (uniform across block) + + + +Binary scan functor + + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + ScanOp scan_op, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Binary scan functor + + + +Block-wide aggregate reduction of input items + + +**Example** + +The code snippet below illustrates an inclusive prefix max scan of 512 integer items that are partitioned in a [blocked arrangement](../index.html#sec5sec3) across 128 threads where each thread owns 4 consecutive items. + +`{ [0,-1,2,-3], [4,-5,6,-7], ..., [508,-509,510,-511] }`. The corresponding output `thread_data` in those threads will be `{ [0,0,2,2], [4,4,6,6], ..., [508,508,510,510] }`. Furthermore the value `510` will be stored in `block_aggregate` for all threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Collectively compute the block-wide inclusive prefix max scan + int block_aggregate; + BlockScan(temp_storage).InclusiveScan(thread_data, thread_data, cuda::maximum<>{}, block_aggregate); +``` + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. Also provides every thread with the block-wide `block_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + T initial_value, + ScanOp scan_op, + T &block_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Initial value to seed the inclusive scan (uniform across block). It is not taken into account for `block_aggregate`. + + + +Binary scan functor + + + +Block-wide aggregate reduction of input items + + + + + +Computes an inclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. The call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. + + +```cpp showLineNumbers={false} +template +void cub::BlockScan::InclusiveScan( + T (&input)[ITEMS_PER_THREAD], + T (&output)[ITEMS_PER_THREAD], + ScanOp scan_op, + BlockPrefixCallbackOp &block_prefix_callback_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** The number of consecutive items partitioned onto each thread. + + + +**[inferred]** Binary scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Call-back functor type having member `T operator()(T block_aggregate)` + + +**Parameters** + + +Calling thread's input items + + + +Calling thread's output items (may be aliased to `input`) + + + +Binary scan functor + + + +Embed:rst:leading-asterisk +//! *warp*\ :sub:`0` only call-back functor for specifying a block-wide prefix to be applied to +//! the logical input sequence. +//! + + +**Example** + +The code snippet below illustrates a single thread block that progressively computes an inclusive prefix max scan over multiple "tiles" of input using a prefix functor to maintain a running total between block-wide scans. Each tile consists of 128 integer items that are partitioned across 128 threads. + +The corresponding output for the first segment will be `0, 0, 2, 2, 4, 4, ..., 510, 510`. The output for the second segment will be `512, 512, 514, 514, 516, 516, ..., 1022, 1022`. + +```cpp showLineNumbers={false} +#include // or equivalently + +// A stateful callback functor that maintains a running prefix to be applied +// during consecutive scan operations. +struct BlockPrefixCallbackOp +{ + // Running prefix + int running_total; + + // Constructor + __device__ BlockPrefixCallbackOp(int running_total) : running_total(running_total) {} + + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. + __device__ int operator()(int block_aggregate) + { + int old_prefix = running_total; + running_total = (block_aggregate > old_prefix) ? block_aggregate : old_prefix; + return old_prefix; + } +}; + +__global__ void ExampleKernel(int *d_data, int num_items, ...) +{ + // Specialize BlockLoad, BlockStore, and BlockScan for a 1D block of 128 threads, 4 ints per thread + using BlockLoad = cub::BlockLoad ; + using BlockStore = cub::BlockStore ; + using BlockScan = cub::BlockScan ; + + // Allocate aliased shared memory for BlockLoad, BlockStore, and BlockScan + __shared__ union { + typename BlockLoad::TempStorage load; + typename BlockScan::TempStorage scan; + typename BlockStore::TempStorage store; + } temp_storage; + + // Initialize running total + BlockPrefixCallbackOp prefix_op(0); + + // Have the block iterate over segments of items + for (int block_offset = 0; block_offset < num_items; block_offset += 128 * 4) + { + // Load a segment of consecutive items that are blocked across threads + int thread_data[4]; + BlockLoad(temp_storage.load).Load(d_data + block_offset, thread_data); + __syncthreads(); + + // Collectively compute the block-wide inclusive prefix max scan + BlockScan(temp_storage.scan).InclusiveScan( + thread_data, thread_data, cuda::maximum<>{}, prefix_op); + __syncthreads(); + + // Store scanned items to output segment + BlockStore(temp_storage.store).Store(d_data + block_offset, thread_data); + __syncthreads(); + } +``` + + + + +--- + +## Utility methods + +### PrivateStorage inline + +Internal storage allocator. + + +```cpp showLineNumbers={false} +_TempStorage & cub::BlockScan::PrivateStorage() +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `WarpScans` | `detail::BlockScanWarpScans< T, BlockDimX, BlockDimY, BlockDimZ >` | | +| `Raking` | `detail::BlockScanRaking< T, BlockDimX, BlockDimY, BlockDimZ,(SAFE_ALGORITHM==BLOCK_SCAN_RAKING_MEMOIZE)>` | | +| `InternalBlockScan` | `::cuda::std::_If< SAFE_ALGORITHM==BLOCK_SCAN_WARP_SCANS, WarpScans, Raking >` | Define the delegate type for the desired algorithm. | +| `_TempStorage` | `typename InternalBlockScan::TempStorage` | Shared memory storage layout type for `BlockScan`. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | The thread block size in threads. | +| `SAFE_ALGORITHM` static constexpr | `BlockScanAlgorithm` | Ensure the template parameterization meets the requirements of the specified algorithm. | +| `temp_storage` | `_TempStorage &` | Shared storage reference. | +| `linear_tid` | `unsigned int` | Linear thread-id. | + +--- + +## Inner classes + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockScan::TempStorage +``` + + +The operations exposed by `BlockScan` require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/cudapages/cub/cub/cub/BlockScanRunningPrefixOp.mdx b/fern/cudapages/cub/cub/cub/BlockScanRunningPrefixOp.mdx new file mode 100644 index 0000000..111eb37 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BlockScanRunningPrefixOp.mdx @@ -0,0 +1,91 @@ +--- +title: cub::BlockScanRunningPrefixOp +description: "Stateful callback operator type for supplying [BlockScan](/library/api/cub::_block_scan) prefixes." +--- + +Stateful callback operator type for supplying [BlockScan](/library/api/cub::_block_scan) prefixes. + +Maintains a running prefix that can be applied to consecutive [BlockScan](/library/api/cub::_block_scan) operations. + + + + + +[BlockScan](/library/api/cub::_block_scan) value type + + + +Wrapped scan operator type + + + + + +--- + +## Constructors + +### BlockScanRunningPrefixOp inline + + + + +Constructor. + + +```cpp showLineNumbers={false} +cub::BlockScanRunningPrefixOp::BlockScanRunningPrefixOp( + ScanOpT op +) +``` + + + + + +Constructor. + + +```cpp showLineNumbers={false} +cub::BlockScanRunningPrefixOp::BlockScanRunningPrefixOp( + T starting_prefix, + ScanOpT op +) +``` + + + + + +--- + +## Methods + +### operator() inline + +Prefix callback operator. + +Returns the block-wide running_total in thread-0. + + +```cpp showLineNumbers={false} +T cub::BlockScanRunningPrefixOp::operator()( + const T &block_aggregate +) +``` + + +**Parameters** + + +The aggregate sum of the [BlockScan](/library/api/cub::_block_scan) inputs + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `op` | `ScanOpT` | Wrapped scan operator. | +| `running_total` | `T` | Running block-wide prefix. | diff --git a/fern/cudapages/cub/cub/cub/BlockShuffle.mdx b/fern/cudapages/cub/cub/cub/BlockShuffle.mdx new file mode 100644 index 0000000..3942086 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BlockShuffle.mdx @@ -0,0 +1,344 @@ +--- +title: cub::BlockShuffle +description: "" +--- + +The BlockShuffle class provides collective methods for shuffling data partitioned across a CUDA thread block. + + + + + +The data type to be exchanged. + + + +The thread block length in threads along the X dimension + + + +**[optional]** The thread block length in threads along the Y dimension (default: 1) + + + +**[optional]** The thread block length in threads along the Z dimension (default: 1) + + + + + +--- + +## Collective constructors + +### BlockShuffle inline + + + + +Collective constructor using a private static allocation of shared memory as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockShuffle::BlockShuffle() +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + +Collective constructor using the specified memory allocation as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockShuffle::BlockShuffle( + TempStorage &temp_storage +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Parameters** + + +Reference to memory allocation having layout type [TempStorage](/library/api/cub::BlockShuffle::TempStorage) + + + + + +--- + +## Shuffle movement + +### Offset inline + +Each *thread*i obtains the `input` provided by *thread*i + distance. The offset `distance` may be negative. + + +```cpp showLineNumbers={false} +void cub::BlockShuffle::Offset( + T input, + T &output, + int distance = 1 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Embed:rst:leading-asterisk +//! The input item from the calling thread (*thread*\ :sub:`i`) +//! + + + +Embed:rst:leading-asterisk +//! The ``input`` item from the successor (or predecessor) thread +//! *thread*\ :sub:`i + distance` (may be aliased to ``input``). +//! This value is only updated for for *thread*\ :sub:`i` when +//! ``0 <= (i + distance) < BLOCK_THREADS - 1`` +//! + + + +Offset distance (may be negative) + + +### Up inline + + + + +Each *thread*i obtains the `input` provided by *thread*i + distance. + + +```cpp showLineNumbers={false} +template +void cub::BlockShuffle::Up( + T (&input)[ITEMS_PER_THREAD], + T (&prev)[ITEMS_PER_THREAD] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +The calling thread's input item + + + +Embed:rst:leading-asterisk +//! The corresponding predecessor items (may be aliased to ``input``). +//! The item ``prev[0]`` is not updated for *thread*\ :sub:`0`. +//! + + + + + +The thread block rotates its blocked arrangement of `input` items, shifting it up by one item. All threads receive the `input` provided by *thread*BLOCK_THREADS - 1. + + +```cpp showLineNumbers={false} +template +void cub::BlockShuffle::Up( + T (&input)[ITEMS_PER_THREAD], + T (&prev)[ITEMS_PER_THREAD], + T &block_suffix +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +The calling thread's input items + + + +Embed:rst:leading-asterisk +//! The corresponding predecessor items (may be aliased to ``input``). +//! The item ``prev[0]`` is not updated for *thread*\ :sub:`0`. +//! + + + +Embed:rst:leading-asterisk +//! The item ``input[ITEMS_PER_THREAD - 1]`` from *thread*\ :sub:`BLOCK_THREADS - 1`, provided to all threads +//! + + + + + +### Down inline + + + + +The thread block rotates its blocked arrangement of `input` items, shifting it down by one item. + + +```cpp showLineNumbers={false} +template +void cub::BlockShuffle::Down( + T (&input)[ITEMS_PER_THREAD], + T (&prev)[ITEMS_PER_THREAD] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +The calling thread's input items + + + +Embed:rst:leading-asterisk +//! The corresponding predecessor items (may be aliased to ``input``). +//! The value ``prev[0]`` is not updated for *thread*\ :sub:`BLOCK_THREADS - 1`. +//! + + + + + +The thread block rotates its blocked arrangement of input items, shifting it down by one item. All threads receive `input[0]` provided by *thread*0. + + +```cpp showLineNumbers={false} +template +void cub::BlockShuffle::Down( + T (&input)[ITEMS_PER_THREAD], + T (&prev)[ITEMS_PER_THREAD], + T &block_prefix +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +The calling thread's input items + + + +Embed:rst:leading-asterisk +//! The corresponding predecessor items (may be aliased to ``input``). +//! The value ``prev[0]`` is not updated for *thread*\ :sub:`BLOCK_THREADS - 1`. +//! + + + +Embed:rst:leading-asterisk +//! The item ``input[0]`` from *thread*\ :sub:`0`, provided to all threads +//! + + + + + +--- + +## Utility methods + +### PrivateStorage inline + +Internal storage allocator. + + +```cpp showLineNumbers={false} +_TempStorage & cub::BlockShuffle::PrivateStorage() +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `_TempStorage` | `T[BLOCK_THREADS]` | Shared memory storage layout type (last element from each thread's input). | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | | +| `LOG_WARP_THREADS` static constexpr | `int` | | +| `WARP_THREADS` static constexpr | `int` | | +| `WARPS` static constexpr | `int` | | +| `temp_storage` | `_TempStorage &` | Shared storage reference. | +| `linear_tid` | `unsigned int` | Linear thread-id. | + +--- + +## Inner classes + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockShuffle::TempStorage +``` + + +The operations exposed by `BlockShuffle` require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/cudapages/cub/cub/cub/BlockStore.mdx b/fern/cudapages/cub/cub/cub/BlockStore.mdx new file mode 100644 index 0000000..aa6978d --- /dev/null +++ b/fern/cudapages/cub/cub/cub/BlockStore.mdx @@ -0,0 +1,388 @@ +--- +title: cub::BlockStore +description: "" +--- + +The BlockStore class provides collective data movement methods for writing a blocked arrangement of items partitioned across a CUDA thread block to a linear segment of memory. + +## Example + +The code snippet below illustrates the storing of a "blocked" arrangement of 512 integers across 128 threads (where each thread owns 4 consecutive items) into a linear segment of memory. The store is specialized for `BLOCK_STORE_WARP_TRANSPOSE`, meaning items are locally reordered among threads so that memory references will be efficiently coalesced using a warp-striped access pattern. + +`{ [0,1,2,3], [4,5,6,7], ..., [508,509,510,511] }`. The output `d_data` will be `0, 1, 2, 3, 4, 5, ...`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int *d_data, ...) +{ + // Specialize BlockStore for a 1D block of 128 threads owning 4 integer items each + using BlockStore = cub::BlockStore; + + // Allocate shared memory for BlockStore + __shared__ typename BlockStore::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Store items to linear memory + BlockStore(temp_storage).Store(d_data, thread_data); +} +``` + + + + + +The type of data to be written. + + + +The thread block length in threads along the X dimension + + + +The number of consecutive items partitioned onto each thread. + + + + + + +**[optional]** The thread block length in threads along the Y dimension (default: 1) + + + +**[optional]** The thread block length in threads along the Z dimension (default: 1) + + + + + +--- + +## Collective constructors + +### BlockStore inline + + + + +Collective constructor using a private static allocation of shared memory as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockStore::BlockStore() +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + +Collective constructor using the specified memory allocation as temporary storage. + + +```cpp showLineNumbers={false} +cub::BlockStore::BlockStore( + TempStorage &temp_storage +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Parameters** + + +Reference to memory allocation having layout type [TempStorage](/library/api/cub::BlockStore::TempStorage) + + + + + +--- + +## Data movement + +### Store inline + + + + +Store items into a linear segment of memory + + +```cpp showLineNumbers={false} +template +void cub::BlockStore::Store( + OutputIteratorT block_itr, + T (&items)[ItemsPerThread] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Data is in a blocked arrangement across threads. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +The thread block's base output iterator for storing to + + + +Data to store + + +**Example** + +The code snippet below illustrates the storing of a "blocked" arrangement of 512 integers across 128 threads (where each thread owns 4 consecutive items) into a linear segment of memory. The store is specialized for `BLOCK_STORE_WARP_TRANSPOSE`, meaning items are locally reordered among threads so that memory references will be efficiently coalesced using a warp-striped access pattern. + +`{ [0,1,2,3], [4,5,6,7], ..., [508,509,510,511] }`. The output `d_data` will be `0, 1, 2, 3, 4, 5, ...`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int *d_data, ...) +{ + // Specialize BlockStore for a 1D block of 128 threads owning 4 integer items each + using BlockStore = cub::BlockStore; + + // Allocate shared memory for BlockStore + __shared__ typename BlockStore::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Store items to linear memory + BlockStore(temp_storage).Store(d_data, thread_data); +} +``` + + + + +Store items into a linear segment of memory, guarded by range. + + +```cpp showLineNumbers={false} +template +void cub::BlockStore::Store( + OutputIteratorT block_itr, + T (&items)[ItemsPerThread], + int valid_items +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Data is in a blocked arrangement across threads. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +The thread block's base output iterator for storing to + + + +Data to store + + + +Number of valid items to write + + +**Example** + +The code snippet below illustrates the guarded storing of a "blocked" arrangement of 512 integers across 128 threads (where each thread owns 4 consecutive items) into a linear segment of memory. The store is specialized for `BLOCK_STORE_WARP_TRANSPOSE`, meaning items are locally reordered among threads so that memory references will be efficiently coalesced using a warp-striped access pattern. + +`{ [0,1,2,3], [4,5,6,7], ..., [508,509,510,511] }` and `valid_items` is `5`. The output `d_data` will be `0, 1, 2, 3, 4, ?, ?, ?, ...`, with only the first two threads being unmasked to store portions of valid data. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int *d_data, int valid_items, ...) +{ + // Specialize BlockStore for a 1D block of 128 threads owning 4 integer items each + using BlockStore = cub::BlockStore; + + // Allocate shared memory for BlockStore + __shared__ typename BlockStore::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Store items to linear memory + BlockStore(temp_storage).Store(d_data, thread_data, valid_items); +} +``` + + + + +--- + +## Utility methods + +### PrivateStorage inline + +Internal storage allocator. + + +```cpp showLineNumbers={false} +_TempStorage & cub::BlockStore::PrivateStorage() +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `InternalStore` | `StoreInternal< Algorithm, 0 >` | Internal load implementation to use. | +| `_TempStorage` | `typename InternalStore::TempStorage` | Shared memory storage layout type. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | The thread block size in threads. | +| `temp_storage` | `_TempStorage &` | Thread reference to shared storage. | +| `linear_tid` | `int` | Linear thread-id. | + +--- + +## Inner classes + +### StoreInternal + + +```cpp showLineNumbers={false} +struct cub::BlockStore::StoreInternal +``` + + +Store helper. + +### StoreInternal< BLOCK_STORE_DIRECT, DUMMY > + + +```cpp showLineNumbers={false} +struct cub::BlockStore::StoreInternal< BLOCK_STORE_DIRECT, DUMMY > +``` + + +| Name | Type | Description | +|---|---|---| +| `linear_tid` | `int` | Linear thread-id. | + +### StoreInternal< BLOCK_STORE_STRIPED, DUMMY > + + +```cpp showLineNumbers={false} +struct cub::BlockStore::StoreInternal< BLOCK_STORE_STRIPED, DUMMY > +``` + + +BLOCK_STORE_STRIPED specialization of store helper. + +| Name | Type | Description | +|---|---|---| +| `linear_tid` | `int` | Linear thread-id. | + +### StoreInternal< BLOCK_STORE_VECTORIZE, DUMMY > + + +```cpp showLineNumbers={false} +struct cub::BlockStore::StoreInternal< BLOCK_STORE_VECTORIZE, DUMMY > +``` + + +BLOCK_STORE_VECTORIZE specialization of store helper. + +| Name | Type | Description | +|---|---|---| +| `linear_tid` | `int` | Linear thread-id. | + +### StoreInternal< BLOCK_STORE_TRANSPOSE, DUMMY > + + +```cpp showLineNumbers={false} +struct cub::BlockStore::StoreInternal< BLOCK_STORE_TRANSPOSE, DUMMY > +``` + + +BLOCK_STORE_TRANSPOSE specialization of store helper. + +| Name | Type | Description | +|---|---|---| +| `temp_storage` | `_TempStorage &` | Thread reference to shared storage. | +| `linear_tid` | `int` | Linear thread-id. | + +### StoreInternal< BLOCK_STORE_WARP_TRANSPOSE, DUMMY > + + +```cpp showLineNumbers={false} +struct cub::BlockStore::StoreInternal< BLOCK_STORE_WARP_TRANSPOSE, DUMMY > +``` + + +BLOCK_STORE_WARP_TRANSPOSE specialization of store helper. + +| Name | Type | Description | +|---|---|---| +| `WARP_THREADS` static constexpr | `int` | | +| `temp_storage` | `_TempStorage &` | Thread reference to shared storage. | +| `linear_tid` | `int` | Linear thread-id. | + +### StoreInternal< BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED, DUMMY > + + +```cpp showLineNumbers={false} +struct cub::BlockStore::StoreInternal< BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED, DUMMY > +``` + + +BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED specialization of store helper. + +| Name | Type | Description | +|---|---|---| +| `WARP_THREADS` static constexpr | `int` | | +| `temp_storage` | `_TempStorage &` | Thread reference to shared storage. | +| `linear_tid` | `int` | Linear thread-id. | + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::BlockStore::TempStorage +``` + + +The operations exposed by `BlockStore` require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/cudapages/cub/cub/cub/CacheModifiedInputIterator.mdx b/fern/cudapages/cub/cub/cub/CacheModifiedInputIterator.mdx new file mode 100644 index 0000000..fdda0b9 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/CacheModifiedInputIterator.mdx @@ -0,0 +1,265 @@ +--- +title: cub::CacheModifiedInputIterator +description: "A random-access input wrapper for dereferencing array values using a PTX cache load modifier." +--- + +A random-access input wrapper for dereferencing array values using a PTX cache load modifier. + +**Overview** + +- `CacheModifiedInputIterator` is a random-access input iterator that wraps a native device pointer of type `ValueType*`. `ValueType` references are made by reading `ValueType` values through loads modified by `MODIFIER`. +- Can be used to load any data type from memory using PTX cache load modifiers (e.g., "LOAD_LDG", "LOAD_CG", "LOAD_CA", "LOAD_CS", "LOAD_CV", etc.). +- Can be constructed, manipulated, and exchanged within and between host and device functions, but can only be dereferenced within device functions. +- Compatible with Thrust API v1.7 or newer. + +**Snippet** + +The code snippet below illustrates the use of `CacheModifiedInputIterator` to dereference a device array of double using the "ldg" PTX load modifier (i.e., load values through texture cache). + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize a device array +double *d_in; // e.g., [8.0, 6.0, 7.0, 5.0, 3.0, 0.0, 9.0] + +// Create an iterator wrapper +cub::CacheModifiedInputIterator itr(d_in); + +// Within device code: +printf("%f\n", itr[0]); // 8.0 +printf("%f\n", itr[1]); // 6.0 +printf("%f\n", itr[6]); // 9.0 +``` + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + + + + + +The value type of this iterator + + + +The difference type of this iterator (Default: `ptrdiff_t`) + + + + + +--- + +## Constructors + +### CacheModifiedInputIterator inline + +Constructor. + + +```cpp showLineNumbers={false} +template +cub::CacheModifiedInputIterator::CacheModifiedInputIterator( + QualifiedValueType *ptr +) +``` + + +--- + +## Methods + +### operator++ inline + + + + +Postfix increment. + + +```cpp showLineNumbers={false} +self_type cub::CacheModifiedInputIterator::operator++( + int +) +``` + + + + + +Prefix increment. + + +```cpp showLineNumbers={false} +self_type cub::CacheModifiedInputIterator::operator++() +``` + + + + + +### operator* inline const + +Indirection. + + +```cpp showLineNumbers={false} +reference cub::CacheModifiedInputIterator::operator*() const +``` + + +### operator+ inline const + +Addition. + + +```cpp showLineNumbers={false} +template +self_type cub::CacheModifiedInputIterator::operator+( + Distance n +) const +``` + + +### operator+= inline + +Addition assignment. + + +```cpp showLineNumbers={false} +template +self_type & cub::CacheModifiedInputIterator::operator+=( + Distance n +) +``` + + +### operator- inline const + + + + +Subtraction. + + +```cpp showLineNumbers={false} +template +self_type cub::CacheModifiedInputIterator::operator-( + Distance n +) const +``` + + + + + +Distance. + + +```cpp showLineNumbers={false} +difference_type cub::CacheModifiedInputIterator::operator-( + self_type other +) const +``` + + + + + +### operator-= inline + +Subtraction assignment. + + +```cpp showLineNumbers={false} +template +self_type & cub::CacheModifiedInputIterator::operator-=( + Distance n +) +``` + + +### operator[] inline const + +Array subscript. + + +```cpp showLineNumbers={false} +template +reference cub::CacheModifiedInputIterator::operator[]( + Distance n +) const +``` + + +### operator-> inline + +Structure dereference. + + +```cpp showLineNumbers={false} +pointer cub::CacheModifiedInputIterator::operator->() +``` + + +### operator== inline const + +Equal to. + + +```cpp showLineNumbers={false} +bool cub::CacheModifiedInputIterator::operator==( + const self_type &rhs +) const +``` + + +### operator!= inline const + +Not equal to. + + +```cpp showLineNumbers={false} +bool cub::CacheModifiedInputIterator::operator!=( + const self_type &rhs +) const +``` + + +### operator<< inline + +ostream operator + + +```cpp showLineNumbers={false} +friend::std::ostream & cub::CacheModifiedInputIterator::operator<<( + ::std::ostream &os, + const self_type & +) +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `self_type` | `CacheModifiedInputIterator` | My own type. | +| `difference_type` | `OffsetT` | Type to express the result of subtracting one iterator from another. | +| `value_type` | `ValueType` | The type of the element the iterator can point to. | +| `pointer` | `ValueType *` | The type of a pointer to an element the iterator can point to. | +| `reference` | `ValueType` | The type of a reference to an element the iterator can point to. | +| `iterator_category` | `THRUST_NS_QUALIFIER::detail::iterator_facade_category_t< THRUST_NS_QUALIFIER::device_system_tag, THRUST_NS_QUALIFIER::random_access_traversal_tag >` | | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `ptr` | `ValueType *` | Wrapped native pointer. | diff --git a/fern/cudapages/cub/cub/cub/CacheModifiedOutputIterator.mdx b/fern/cudapages/cub/cub/cub/CacheModifiedOutputIterator.mdx new file mode 100644 index 0000000..b0d3ad5 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/CacheModifiedOutputIterator.mdx @@ -0,0 +1,279 @@ +--- +title: cub::CacheModifiedOutputIterator +description: "A random-access output wrapper for storing array values using a PTX cache-modifier." +--- + +A random-access output wrapper for storing array values using a PTX cache-modifier. + +**Overview** + +- `CacheModifiedOutputIterator` is a random-access output iterator that wraps a native device pointer of type `ValueType*`. `ValueType` references are made by writing `ValueType` values through stores modified by `MODIFIER`. +- Can be used to store any data type to memory using PTX cache store modifiers (e.g., "STORE_WB", "STORE_CG", "STORE_CS", "STORE_WT", etc.). +- Can be constructed, manipulated, and exchanged within and between host and device functions, but can only be dereferenced within device functions. +- Compatible with Thrust API v1.7 or newer. + +**Snippet** + +The code snippet below illustrates the use of `CacheModifiedOutputIterator` to dereference a device array of doubles using the "wt" PTX load modifier (i.e., write-through to system memory). + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize a device array +double *d_out; // e.g., [, , , , , , ] + +// Create an iterator wrapper +cub::CacheModifiedOutputIterator itr(d_out); + +// Within device code: +itr[0] = 8.0; +itr[1] = 66.0; +itr[55] = 24.0; +``` + +**Usage Considerations** + +- Can only be dereferenced within device code + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + + + + + +The value type of this iterator + + + +The difference type of this iterator (Default: `ptrdiff_t`) + + + + + +--- + +## Constructors + +### CacheModifiedOutputIterator inline + + +```cpp showLineNumbers={false} +template +cub::CacheModifiedOutputIterator::CacheModifiedOutputIterator( + QualifiedValueType *ptr +) +``` + + +**Parameters** + + +Native pointer to wrap + + +--- + +## Methods + +### operator++ inline + + + + +Postfix increment. + + +```cpp showLineNumbers={false} +self_type cub::CacheModifiedOutputIterator::operator++( + int +) +``` + + + + + +Prefix increment. + + +```cpp showLineNumbers={false} +self_type cub::CacheModifiedOutputIterator::operator++() +``` + + + + + +### operator* inline const + +Indirection. + + +```cpp showLineNumbers={false} +reference cub::CacheModifiedOutputIterator::operator*() const +``` + + +### operator+ inline const + +Addition. + + +```cpp showLineNumbers={false} +template +self_type cub::CacheModifiedOutputIterator::operator+( + Distance n +) const +``` + + +### operator+= inline + +Addition assignment. + + +```cpp showLineNumbers={false} +template +self_type & cub::CacheModifiedOutputIterator::operator+=( + Distance n +) +``` + + +### operator- inline const + + + + +Subtraction. + + +```cpp showLineNumbers={false} +template +self_type cub::CacheModifiedOutputIterator::operator-( + Distance n +) const +``` + + + + + +Distance. + + +```cpp showLineNumbers={false} +difference_type cub::CacheModifiedOutputIterator::operator-( + self_type other +) const +``` + + + + + +### operator-= inline + +Subtraction assignment. + + +```cpp showLineNumbers={false} +template +self_type & cub::CacheModifiedOutputIterator::operator-=( + Distance n +) +``` + + +### operator[] inline const + +Array subscript. + + +```cpp showLineNumbers={false} +template +reference cub::CacheModifiedOutputIterator::operator[]( + Distance n +) const +``` + + +### operator== inline + +Equal to. + + +```cpp showLineNumbers={false} +bool cub::CacheModifiedOutputIterator::operator==( + const self_type &rhs +) +``` + + +### operator!= inline + +Not equal to. + + +```cpp showLineNumbers={false} +bool cub::CacheModifiedOutputIterator::operator!=( + const self_type &rhs +) +``` + + +### operator<< inline + +ostream operator + + +```cpp showLineNumbers={false} +friend::std::ostream & cub::CacheModifiedOutputIterator::operator<<( + ::std::ostream &os, + const self_type &itr +) +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `self_type` | `CacheModifiedOutputIterator` | My own type. | +| `difference_type` | `OffsetT` | Type to express the result of subtracting one iterator from another. | +| `value_type` | `void` | The type of the element the iterator can point to. | +| `pointer` | `void` | The type of a pointer to an element the iterator can point to. | +| `reference` | `Reference` | The type of a reference to an element the iterator can point to. | +| `iterator_category` | `THRUST_NS_QUALIFIER::detail::iterator_facade_category_t< THRUST_NS_QUALIFIER::device_system_tag, THRUST_NS_QUALIFIER::random_access_traversal_tag >` | The iterator category. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `ptr` | `ValueType *` | | + +--- + +## Inner classes + +### Reference + + +```cpp showLineNumbers={false} +struct cub::CacheModifiedOutputIterator::Reference +``` + + +| Name | Type | Description | +|---|---|---| +| `ptr` | `ValueType *` | | diff --git a/fern/cudapages/cub/cub/cub/CachingDeviceAllocator.mdx b/fern/cudapages/cub/cub/cub/CachingDeviceAllocator.mdx new file mode 100644 index 0000000..94091b9 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/CachingDeviceAllocator.mdx @@ -0,0 +1,253 @@ +--- +title: cub::CachingDeviceAllocator +description: "A simple caching allocator for device memory allocations." +--- + +A simple caching allocator for device memory allocations. + +**Overview** + +The allocator is thread-safe and stream-safe and is capable of managing cached device allocations on multiple devices. It behaves as follows: + +- Allocations from the allocator are associated with an `active_stream`. Once freed, the allocation becomes available immediately for reuse within the `active_stream` with which it was associated with during allocation, and it becomes available for reuse within other streams when all prior work submitted to `active_stream` has completed. +- Allocations are categorized and cached by bin size. A new allocation request of a given size will only consider cached allocations within the corresponding bin. +- Bin limits progress geometrically in accordance with the growth factor `bin_growth` provided during construction. Unused device allocations within a larger bin cache are not reused for allocation requests that categorize to smaller bin sizes. +- Allocation requests below ( `bin_growth` ^ `min_bin` ) are rounded up to ( `bin_growth` ^ `min_bin` ). +- Allocations above ( `bin_growth` ^ `max_bin` ) are not rounded up to the nearest bin and are simply freed when they are deallocated instead of being returned to a bin-cache. +- If the total storage of cached allocations on a given device will exceed `max_cached_bytes`, allocations for that device are simply freed when they are deallocated instead of being returned to their bin-cache. + +For example, the default-constructed `CachingDeviceAllocator` is configured with: +- `bin_growth` = 8 +- `min_bin` = 3 +- `max_bin` = 7 +- `max_cached_bytes` = 6MB - 1B + +which delineates five bin-sizes: 512B, 4KB, 32KB, 256KB, and 2MB and sets a maximum of 6,291,455 cached bytes per device + +--- + +## Constructors + +### CachingDeviceAllocator inline + + + + +Constructor. + + +```cpp showLineNumbers={false} +cub::CachingDeviceAllocator::CachingDeviceAllocator( + unsigned int bin_growth, + unsigned int min_bin = 1, + unsigned int max_bin = INVALID_BIN, + size_t max_cached_bytes = INVALID_SIZE, + bool skip_cleanup = false +) +``` + + +**Parameters** + + +Geometric growth factor for bin-sizes + + + +Minimum bin (default is bin_growth ^ 1) + + + +Maximum bin (default is no max bin) + + + +Maximum aggregate cached bytes per device (default is no limit) + + + +Whether or not to skip a call to [`FreeAllCached()`](/library/api/cub::_caching_device_allocator::FreeAllCached()) when the destructor is called (default is to deallocate) + + + + + +Default constructor. + +Configured with: + +- `bin_growth` = 8 +- `min_bin` = 3 +- `max_bin` = 7 +- `max_cached_bytes` = ( `bin_growth` ^ `max_bin`) * 3 ) - 1 = 6,291,455 bytes + +which delineates five bin-sizes: 512B, 4KB, 32KB, 256KB, and 2MB and sets a maximum of 6,291,455 cached bytes per device + + +```cpp showLineNumbers={false} +cub::CachingDeviceAllocator::CachingDeviceAllocator( + bool skip_cleanup = false, + bool debug = false +) +``` + + + + + +### Destructor + +### ~CachingDeviceAllocator inline virtual + +Destructor. + + +```cpp showLineNumbers={false} +virtual cub::CachingDeviceAllocator::~CachingDeviceAllocator() +``` + + +--- + +## Methods + +### SetMaxCachedBytes inline + +Sets the limit on the number bytes this allocator is allowed to cache per device. + +Changing the ceiling of cached bytes does not cause any allocations (in-use or cached-in-reserve) to be freed. See [`FreeAllCached()`](/library/api/cub::_caching_device_allocator::FreeAllCached()). + + +```cpp showLineNumbers={false} +cudaError_t cub::CachingDeviceAllocator::SetMaxCachedBytes( + size_t max_cached_bytes_ +) +``` + + +### DeviceAllocate inline + + + + +Provides a suitable allocation of device memory for the given size on the specified device. + +Once freed, the allocation becomes available immediately for reuse within the `active_stream` with which it was associated with during allocation, and it becomes available for reuse within other streams when all prior work submitted to `active_stream` has completed. + + +```cpp showLineNumbers={false} +cudaError_t cub::CachingDeviceAllocator::DeviceAllocate( + int device, + void **d_ptr, + size_t bytes, + cudaStream_t active_stream = 0 +) +``` + + +**Parameters** + + +Device on which to place the allocation + + + +Reference to pointer to the allocation + + + +Minimum number of bytes for the allocation + + + +The stream to be associated with this allocation + + + + + +Provides a suitable allocation of device memory for the given size on the current device. + +Once freed, the allocation becomes available immediately for reuse within the `active_stream` with which it was associated with during allocation, and it becomes available for reuse within other streams when all prior work submitted to `active_stream` has completed. + + +```cpp showLineNumbers={false} +cudaError_t cub::CachingDeviceAllocator::DeviceAllocate( + void **d_ptr, + size_t bytes, + cudaStream_t active_stream = 0 +) +``` + + +**Parameters** + + +Reference to pointer to the allocation + + + +Minimum number of bytes for the allocation + + + +The stream to be associated with this allocation + + + + + +### DeviceFree inline + + + + +Frees a live allocation of device memory on the specified device, returning it to the allocator. + +Once freed, the allocation becomes available immediately for reuse within the `active_stream` with which it was associated with during allocation, and it becomes available for reuse within other streams when all prior work submitted to `active_stream` has completed. + + +```cpp showLineNumbers={false} +cudaError_t cub::CachingDeviceAllocator::DeviceFree( + int device, + void *d_ptr +) +``` + + + + + +Frees a live allocation of device memory on the current device, returning it to the allocator. + +Once freed, the allocation becomes available immediately for reuse within the `active_stream` with which it was associated with during allocation, and it becomes available for reuse within other streams when all prior work submitted to `active_stream` has completed. + + +```cpp showLineNumbers={false} +cudaError_t cub::CachingDeviceAllocator::DeviceFree( + void *d_ptr +) +``` + + + + + +### FreeAllCached inline + +Frees all cached device allocations on all devices. + + +```cpp showLineNumbers={false} +cudaError_t cub::CachingDeviceAllocator::FreeAllCached() +``` + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `INVALID_BIN` static constexpr | `unsigned int` | Out-of-bounds bin. | +| `INVALID_SIZE` static constexpr | `size_t` | Invalid size. | diff --git a/fern/cudapages/cub/cub/cub/CastOp.mdx b/fern/cudapages/cub/cub/cub/CastOp.mdx new file mode 100644 index 0000000..acb2e6e --- /dev/null +++ b/fern/cudapages/cub/cub/cub/CastOp.mdx @@ -0,0 +1,32 @@ +--- +title: cub::CastOp +description: "Default cast functor." +--- + +Default cast functor. + + + + + + + + + + +--- + +## Methods + +### operator() inline const + +Cast operator, returns `(B) a`. + + +```cpp showLineNumbers={false} +template +B cub::CastOp::operator()( + A &&a +) const +``` + diff --git a/fern/cudapages/cub/cub/cub/ChainedPolicy.mdx b/fern/cudapages/cub/cub/cub/ChainedPolicy.mdx new file mode 100644 index 0000000..9ddf549 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/ChainedPolicy.mdx @@ -0,0 +1,84 @@ +--- +title: cub::ChainedPolicy +description: "Helper for dispatching into a policy chain." +--- + +Helper for dispatching into a policy chain. + + + + + + + + + + + + + + + + +--- + +## Methods + +### runtime_arch_to_compiletime inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::ChainedPolicy::runtime_arch_to_compiletime( + int device_ptx_version, + FunctorT &op +) +``` + + +### find_and_invoke_policy inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::ChainedPolicy::find_and_invoke_policy( + FunctorT &op +) +``` + + +--- + +## Static methods + +### Invoke inline static + +Specializes and dispatches op in accordance to the first policy in the chain of adequate PTX version. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::ChainedPolicy::Invoke( + int device_ptx_version, + FunctorT &op +) +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `ActivePolicy` | `typename ::cuda::std::_If<(CUB_PTX_ARCH< PolicyPtxVersion &&have_previous_policy), detail::get_active_policy< PrevPolicyT >, ::cuda::std::type_identity< PolicyT > >::type` | The policy for the active compiler pass. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `have_previous_policy` static constexpr | `bool` | | diff --git a/fern/cudapages/cub/cub/cub/DeviceAdjacentDifference.mdx b/fern/cudapages/cub/cub/cub/DeviceAdjacentDifference.mdx new file mode 100644 index 0000000..e7914dc --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DeviceAdjacentDifference.mdx @@ -0,0 +1,524 @@ +--- +title: cub::DeviceAdjacentDifference +description: "" +--- + +DeviceAdjacentDifference provides device-wide, parallel operations for computing the differences of adjacent elements residing within device-accessible memory. + +## Example + +The code snippet below illustrates how to use `DeviceAdjacentDifference` to compute the left difference between adjacent elements. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +int num_items; // e.g., 8 +int *d_values; // e.g., [1, 2, 1, 2, 1, 2, 1, 2] +//... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; + +cub::DeviceAdjacentDifference::SubtractLeft( + d_temp_storage, temp_storage_bytes, d_values, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run operation +cub::DeviceAdjacentDifference::SubtractLeft( + d_temp_storage, temp_storage_bytes, d_values, num_items); + +// d_values <-- [1, 1, -1, 1, -1, 1, -1, 1] +``` + +--- + +## Methods + +### AdjacentDifference inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceAdjacentDifference::AdjacentDifference( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_input, + OutputIteratorT d_output, + NumItemsT num_items, + DifferenceOpT difference_op, + cudaStream_t stream +) +``` + + +--- + +## Static methods + +### SubtractLeftCopy inline static + +Subtracts the left element of each adjacent pair of elements residing within device-accessible memory + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceAdjacentDifference::SubtractLeftCopy( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_input, + OutputIteratorT d_output, + NumItemsT num_items, + DifferenceOpT difference_op = {}, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +Embed:rst:leading-asterisk +//! is a model of `Input Iterator `_, +//! and ``x`` and ``y`` are objects of ``InputIteratorT``'s ``value_type``, then +//! ``x - y`` is defined, and ``InputIteratorT``'s ``value_type`` is convertible to +//! a type in ``OutputIteratorT``'s set of ``value_types``, and the return type +//! of ``x - y`` is convertible to a type in ``OutputIteratorT``'s set of +//! ``value_types``. +//! + + + +Embed:rst:leading-asterisk +//! is a model of `Output Iterator `_. +//! + + + +Its `result_type` is convertible to a type in `OutputIteratorT`'s set of `value_types`. + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence + + + +Pointer to the output sequence + + + +Number of items in the input sequence + + + +The binary function used to compute differences + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0` +//! + + +**Example** + +The code snippet below illustrates how to use `DeviceAdjacentDifference` to compute the difference between adjacent elements. + +```cpp showLineNumbers={false} +#include +// or equivalently + +struct CustomDifference +{ + template + __host__ DataType operator()(DataType &lhs, DataType &rhs) + { + return lhs - rhs; + } +}; + +// Declare, allocate, and initialize device-accessible pointers +int num_items; // e.g., 8 +int *d_input; // e.g., [1, 2, 1, 2, 1, 2, 1, 2] +int *d_output; +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; + +cub::DeviceAdjacentDifference::SubtractLeftCopy( + d_temp_storage, temp_storage_bytes, + d_input, d_output, + num_items, CustomDifference()); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run operation +cub::DeviceAdjacentDifference::SubtractLeftCopy( + d_temp_storage, temp_storage_bytes, + d_input, d_output, + num_items, CustomDifference()); + +// d_input <-- [1, 2, 1, 2, 1, 2, 1, 2] +// d_output <-- [1, 1, -1, 1, -1, 1, -1, 1] +``` + +### SubtractLeft inline static + +Subtracts the left element of each adjacent pair of elements residing within device-accessible memory. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceAdjacentDifference::SubtractLeft( + void *d_temp_storage, + size_t &temp_storage_bytes, + RandomAccessIteratorT d_input, + NumItemsT num_items, + DifferenceOpT difference_op = {}, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +Embed:rst:leading-asterisk +//! is a model of `Random Access Iterator `_, +//! ``RandomAccessIteratorT`` is mutable. If ``x`` and ``y`` are objects of +//! ``RandomAccessIteratorT``'s ``value_type``, and ``x - y`` is defined, then the +//! return type of ``x - y`` should be convertible to a type in +//! ``RandomAccessIteratorT``'s set of ``value_types``. +//! + + + +Its `result_type` is convertible to a type in `RandomAccessIteratorT`'s set of `value_types`. + + + +**[inferred]** Type of `num_items` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence and the result + + + +Number of items in the input sequence + + + +The binary function used to compute differences + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates how to use `DeviceAdjacentDifference` to compute the difference between adjacent elements. + +```cpp showLineNumbers={false} +#include +// or equivalently + +struct CustomDifference +{ + template + __host__ DataType operator()(DataType &lhs, DataType &rhs) + { + return lhs - rhs; + } +}; + +// Declare, allocate, and initialize device-accessible pointers +int num_items; // e.g., 8 +int *d_data; // e.g., [1, 2, 1, 2, 1, 2, 1, 2] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceAdjacentDifference::SubtractLeft( + d_temp_storage, temp_storage_bytes, + d_data, num_items, CustomDifference()); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run operation +cub::DeviceAdjacentDifference::SubtractLeft( + d_temp_storage, temp_storage_bytes, + d_data, num_items, CustomDifference()); + +// d_data <-- [1, 1, -1, 1, -1, 1, -1, 1] +``` + +### SubtractRightCopy inline static + +Subtracts the right element of each adjacent pair of elements residing within device-accessible memory. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceAdjacentDifference::SubtractRightCopy( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_input, + OutputIteratorT d_output, + NumItemsT num_items, + DifferenceOpT difference_op = {}, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +Embed:rst:leading-asterisk +//! is a model of `Input Iterator `_, +//! and ``x`` and ``y`` are objects of ``InputIteratorT``'s ``value_type``, then +//! ``x - y`` is defined, and ``InputIteratorT``'s ``value_type`` is convertible to +//! a type in ``OutputIteratorT``'s set of ``value_types``, and the return type +//! of ``x - y`` is convertible to a type in ``OutputIteratorT``'s set of +//! ``value_types``. +//! + + + +Embed:rst:leading-asterisk +//! is a model of `Output Iterator `_. +//! + + + +Its `result_type` is convertible to a type in `RandomAccessIteratorT`'s set of `value_types`. + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence + + + +Pointer to the output sequence + + + +Number of items in the input sequence + + + +The binary function used to compute differences. + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates how to use `DeviceAdjacentDifference` to compute the difference between adjacent elements. + +```cpp showLineNumbers={false} +#include +// or equivalently + +struct CustomDifference +{ + template + __host__ DataType operator()(DataType &lhs, DataType &rhs) + { + return lhs - rhs; + } +}; + +// Declare, allocate, and initialize device-accessible pointers +int num_items; // e.g., 8 +int *d_input; // e.g., [1, 2, 1, 2, 1, 2, 1, 2] +int *d_output; +.. + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceAdjacentDifference::SubtractRightCopy( + d_temp_storage, temp_storage_bytes, + d_input, d_output, num_items, CustomDifference()); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run operation +cub::DeviceAdjacentDifference::SubtractRightCopy( + d_temp_storage, temp_storage_bytes, + d_input, d_output, num_items, CustomDifference()); + +// d_input <-- [1, 2, 1, 2, 1, 2, 1, 2] +// d_data <-- [-1, 1, -1, 1, -1, 1, -1, 2] +``` + +### SubtractRight inline static + +Subtracts the right element of each adjacent pair of elements residing within device-accessible memory. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceAdjacentDifference::SubtractRight( + void *d_temp_storage, + size_t &temp_storage_bytes, + RandomAccessIteratorT d_input, + NumItemsT num_items, + DifferenceOpT difference_op = {}, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +Embed:rst:leading-asterisk +//! is a model of `Random Access Iterator `_, +//! ``RandomAccessIteratorT`` is mutable. If ``x`` and ``y`` are objects of +//! ``RandomAccessIteratorT``'s `value_type`, and ``x - y`` is defined, then the +//! return type of ``x - y`` should be convertible to a type in +//! ``RandomAccessIteratorT``'s set of ``value_types``. +//! + + + +Its `result_type` is convertible to a type in `RandomAccessIteratorT`'s set of `value_types`. + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence + + + +Number of items in the input sequence + + + +The binary function used to compute differences + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates how to use `DeviceAdjacentDifference` to compute the difference between adjacent elements. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +int num_items; // e.g., 8 +int *d_data; // e.g., [1, 2, 1, 2, 1, 2, 1, 2] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceAdjacentDifference::SubtractRight( + d_temp_storage, temp_storage_bytes, d_data, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run operation +cub::DeviceAdjacentDifference::SubtractRight( + d_temp_storage, temp_storage_bytes, d_data, num_items); + +// d_data <-- [-1, 1, -1, 1, -1, 1, -1, 2] +``` diff --git a/fern/cudapages/cub/cub/cub/DeviceCopy.mdx b/fern/cudapages/cub/cub/cub/DeviceCopy.mdx new file mode 100644 index 0000000..345a2cf --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DeviceCopy.mdx @@ -0,0 +1,235 @@ +--- +title: cub::DeviceCopy +description: "[cub::DeviceCopy](/library/api/cub::_device_copy) provides device-wide, parallel operations for copying data." +--- + +`cub::DeviceCopy` provides device-wide, parallel operations for copying data. + +--- + +## Static methods + +### Batched inline static + +Copies data from a batch of given source ranges to their corresponding destination ranges. + +.. note:: + +If any input range aliases any output range the behavior is undefined. If any output range aliases another output range the behavior is undefined. Input ranges can alias one another. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceCopy::Batched( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIt input_it, + OutputIt output_it, + SizeIteratorT sizes, + ::cuda::std::int64_t num_ranges, + cudaStream_t stream = nullptr +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +**[inferred]** Device-accessible random-access input iterator type providing the iterators to the source ranges + + + +**[inferred]** Device-accessible random-access input iterator type providing the iterators to the destination ranges + + + +**[inferred]** Device-accessible random-access input iterator type providing the number of items to be copied for each pair of ranges + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Device-accessible iterator providing the iterators to the source ranges + + + +Device-accessible iterator providing the iterators to the destination ranges + + + +Device-accessible iterator providing the number of elements to be copied for each pair of ranges + + + +The total number of range pairs + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates usage of DeviceCopy::Batched to perform a DeviceRunLength Decode operation. + +```cpp showLineNumbers={false} +struct GetIteratorToRange +{ + __host__ __device__ __forceinline__ auto operator()(uint32_t index) + { + return thrust::make_constant_iterator(d_data_in[index]); + } + int32_t *d_data_in; +}; + +struct GetPtrToRange +{ + __host__ __device__ __forceinline__ auto operator()(uint32_t index) + { + return d_data_out + d_offsets[index]; + } + int32_t *d_data_out; + uint32_t *d_offsets; +}; + +struct GetRunLength +{ + __host__ __device__ __forceinline__ uint32_t operator()(uint32_t index) + { + return d_offsets[index + 1] - d_offsets[index]; + } + uint32_t *d_offsets; +}; + +uint32_t num_ranges = 5; +int32_t *d_data_in; // e.g., [4, 2, 7, 3, 1] +int32_t *d_data_out; // e.g., [0, ... ] +uint32_t *d_offsets; // e.g., [0, 2, 5, 6, 9, 14] + +// Returns a constant iterator to the element of the i-th run +thrust::counting_iterator iota(0); +auto iterators_in = thrust::make_transform_iterator(iota, GetIteratorToRange{d_data_in}); + +// Returns the run length of the i-th run +auto sizes = thrust::make_transform_iterator(iota, GetRunLength{d_offsets}); + +// Returns pointers to the output range for each run +auto ptrs_out = thrust::make_transform_iterator(iota, GetPtrToRange{d_data_out, d_offsets}); + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceCopy::Batched(d_temp_storage, temp_storage_bytes, iterators_in, ptrs_out, sizes, +num_ranges); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run batched copy algorithm (used to perform runlength decoding) +cub::DeviceCopy::Batched(d_temp_storage, temp_storage_bytes, iterators_in, ptrs_out, sizes, +num_ranges); + +// d_data_out <-- [4, 4, 2, 2, 2, 7, 3, 3, 3, 1, 1, 1, 1, 1] +``` + +### Copy inline static nodiscard + +Copies data from a multidimensional source mdspan to a destination mdspan. + +This function performs a parallel copy operation between two mdspan objects with potentially different layouts but identical extents. The copy operation handles arbitrary-dimensional arrays and automatically manages layout transformations. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceCopy::Copy( + void *d_temp_storage, + size_t &temp_storage_bytes, + ::cuda::std::mdspan mdspan_in, + ::cuda::std::mdspan mdspan_out, + ::cudaStream_t stream = nullptr +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Returns:** embed:rst:leading-asterisk +//! **cudaSuccess** on success, **cudaErrorInvalidValue** if mdspan extents don't match, or error code on failure +//! + +**Template parameters** + + +**[inferred]** The element type of the source mdspan + + + +**[inferred]** The extents type of the source mdspan + + + +**[inferred]** The layout type of the source mdspan + + + +**[inferred]** The accessor type of the source mdspan + + + +**[inferred]** The element type of the destination mdspan + + + +**[inferred]** The extents type of the destination mdspan + + + +**[inferred]** The layout type of the destination mdspan + + + +**[inferred]** The accessor type of the destination mdspan + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Source mdspan containing the data to be copied + + + +Destination mdspan where the data will be copied + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + diff --git a/fern/cudapages/cub/cub/cub/DeviceFind.mdx b/fern/cudapages/cub/cub/cub/DeviceFind.mdx new file mode 100644 index 0000000..d777461 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DeviceFind.mdx @@ -0,0 +1,270 @@ +--- +title: cub::DeviceFind +description: "" +--- + +--- + +## Static methods + +### FindIf inline static + +Finds the first element in the input sequence that satisfies the given predicate. + +The code snippet below illustrates the finding of the first element that satisfies the predicate. + +.. literalinclude:: ../../../cub/test/catch2_test_device_find_if_api.cu :language: c++ :dedent: :start-after: example-begin find-if-predicate :end-before: example-end find-if-predicate + +.. literalinclude:: ../../../cub/test/catch2_test_device_find_if_api.cu :language: c++ :dedent: :start-after: example-begin device-find-if :end-before: example-end device-find-if + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceFind::FindIf( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ScanOpT scan_op, + NumItemsT num_items, + cudaStream_t stream = 0 +) +``` + + + +The search terminates at the first element where the predicate evaluates to true. +The index of the found element is written to `d_out`. +If no element satisfies the predicate, `num_items` is written to `d_out`. +The range `[d_out, d_out + 1)` shall not overlap `[d_in, d_in + num_items)` in any way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing the result index (may be a simple pointer type) + + + +**[inferred]** Unary predicate functor type having member `bool operator()(const T &a)` + + + +**[inferred]** An integral type representing the number of input elements + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the input sequence of data items + + + +Random-access iterator to the output location for the index of the found element + + + +Unary predicate functor for determining whether an element satisfies the search condition + + + +Total number of input items (i.e., the length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +### LowerBound inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceFind::LowerBound( + void *d_temp_storage, + size_t &temp_storage_bytes, + RangeIteratorT d_range, + RangeNumItemsT range_num_items, + ValuesIteratorT d_values, + ValuesNumItemsT values_num_items, + OutputIteratorT d_output, + CompareOpT comp, + cudaStream_t stream = 0 +) +``` + + +**Template parameters** + + +Is a model of [Random Access Iterator](https://en.cppreference.com/w/cpp/iterator/random_access_iterator), whose value type forms a [Relation](https://en.cppreference.com/w/cpp/concepts/relation) with the value type of `ValuesIteratorT` using `CompareOpT` as the predicate. + + + +Is an integral type representing the number of elements in the range to be searched. + + + +Is a model of [Random Access Iterator](https://en.cppreference.com/w/cpp/iterator/random_access_iterator), whose value type forms a [Relation](https://en.cppreference.com/w/cpp/concepts/relation) with the value type of `RangeIteratorT` using `CompareOpT` as the predicate. + + + +Is a model of integral type representing the number of elements in the range of values to be searched for. + + + +Is a model of [Random Access Iterator](https://en.cppreference.com/w/cpp/iterator/random_access_iterator), whose value type is assignable from `RangeIteratorT`'s difference type. + + + +Is a model of [Strict Weak Ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order), which forms a [Relation](https://en.cppreference.com/w/cpp/concepts/relation) with the value types of `RangeIteratorT` and `ValuesIteratorT`. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Iterator to the beginning of the ordered range to be searched. + + + +Number of elements in the ordered range to be searched. + + + +Iterator to the beginning of the range of values to be searched for. + + + +Number of elements in the range of values to be searched for. + + + +Iterator to the beginning of the output range. + + + +Comparison function object which returns true if its first argument is ordered before the second in the [Strict Weak Ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order) of the range to be searched. + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + +### UpperBound inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceFind::UpperBound( + void *d_temp_storage, + size_t &temp_storage_bytes, + RangeIteratorT d_range, + RangeNumItemsT range_num_items, + ValuesIteratorT d_values, + ValuesNumItemsT values_num_items, + OutputIteratorT d_output, + CompareOpT comp, + cudaStream_t stream = 0 +) +``` + + +**Template parameters** + + +Is a model of [Random Access Iterator](https://en.cppreference.com/w/cpp/iterator/random_access_iterator), whose value type forms a [Relation](https://en.cppreference.com/w/cpp/concepts/relation) with the value type of `ValuesIteratorT` using `CompareOpT` as the predicate. + + + +Is an integral type representing the number of elements in the range to be searched. + + + +Is a model of [Random Access Iterator](https://en.cppreference.com/w/cpp/iterator/random_access_iterator), whose value type forms a [Relation](https://en.cppreference.com/w/cpp/concepts/relation) with the value type of `RangeIteratorT` using `CompareOpT` as the predicate. + + + +Is a model of integral type representing the number of elements in the range of values to be searched for. + + + +Is a model of [Random Access Iterator](https://en.cppreference.com/w/cpp/iterator/random_access_iterator), whose value type is assignable from `RangeIteratorT`'s difference type. + + + +Is a model of [Strict Weak Ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order), which forms a [Relation](https://en.cppreference.com/w/cpp/concepts/relation) with the value types of `RangeIteratorT` and `ValuesIteratorT`. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Iterator to the beginning of the ordered range to be searched. + + + +Number of elements in the ordered range to be searched. + + + +Iterator to the beginning of the range of values to be searched for. + + + +Number of elements in the range of values to be searched for. + + + +Iterator to the beginning of the output range. + + + +Comparison function object which returns true if its first argument is ordered before the second in the [Strict Weak Ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order) of the range to be searched. + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + diff --git a/fern/cudapages/cub/cub/cub/DeviceFor.mdx b/fern/cudapages/cub/cub/cub/DeviceFor.mdx new file mode 100644 index 0000000..06c81f1 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DeviceFor.mdx @@ -0,0 +1,728 @@ +--- +title: cub::DeviceFor +description: "" +--- + +--- + +## Methods + +### for_each_n inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceFor::for_each_n( + RandomAccessOrContiguousIteratorT first, + OffsetT num_items, + OpT op, + cudaStream_t stream +) +``` + + +### ForEachNNoNVTX inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceFor::ForEachNNoNVTX( + RandomAccessIteratorT first, + NumItemsT num_items, + OpT op, + cudaStream_t stream = {} +) +``` + + +### ForEachCopyNNoNVTX inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceFor::ForEachCopyNNoNVTX( + RandomAccessIteratorT first, + NumItemsT num_items, + OpT op, + cudaStream_t stream = {} +) +``` + + +--- + +## Static methods + +### Bulk inline static + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceFor::Bulk( + void *d_temp_storage, + size_t &temp_storage_bytes, + ShapeT shape, + OpT op, + cudaStream_t stream = {} +) +``` + + +*Added in v2.4.0. First appears in CUDA Toolkit 12.5.* + +**Template parameters** + + +Is an integral type + + + +Is a model of [Unary Function](https://en.cppreference.com/w/cpp/utility/functional/unary_function) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Shape of the index space to iterate over + + + +Function object to apply to each index in the index space + + + +CUDA stream to launch kernels within. Default stream is `0`. + + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceFor::Bulk( + ShapeT shape, + OpT op, + cudaStream_t stream = {} +) +``` + + +*Added in v2.4.0. First appears in CUDA Toolkit 12.5.* + +**Template parameters** + + +Is an integral type + + + +Is a model of [Unary Function](https://en.cppreference.com/w/cpp/utility/functional/unary_function) + + +**Parameters** + + +Shape of the index space to iterate over + + + +Function object to apply to each index in the index space + + + +CUDA stream to launch kernels within. Default stream is `0`. + + + + + +### ForEachN inline static + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceFor::ForEachN( + void *d_temp_storage, + size_t &temp_storage_bytes, + RandomAccessIteratorT first, + NumItemsT num_items, + OpT op, + cudaStream_t stream = {} +) +``` + + +*Added in v2.4.0. First appears in CUDA Toolkit 12.5.* + +**Template parameters** + + +Is a model of Random Access Iterator whose value type is convertible to `op`'s argument type. + + + +Is an integral type representing the number of elements to iterate over + + + +Is a model of [Unary Function](https://en.cppreference.com/w/cpp/utility/functional/unary_function) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +The beginning of the sequence + + + +Number of elements to iterate over + + + +Function object to apply to each element in the range + + + +CUDA stream to launch kernels within. Default stream is `0`. + + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceFor::ForEachN( + RandomAccessIteratorT first, + NumItemsT num_items, + OpT op, + cudaStream_t stream = {} +) +``` + + +*Added in v2.4.0. First appears in CUDA Toolkit 12.5.* + +**Template parameters** + + +Is a model of Random Access Iterator whose value type is convertible to `op`'s argument type. + + + +Is an integral type representing the number of elements to iterate over + + + +Is a model of [Unary Function](https://en.cppreference.com/w/cpp/utility/functional/unary_function) + + +**Parameters** + + +The beginning of the sequence + + + +Number of elements to iterate over + + + +Function object to apply to each element in the range + + + +CUDA stream to launch kernels within. Default stream is `0`. + + + + + +### ForEach inline static + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceFor::ForEach( + void *d_temp_storage, + size_t &temp_storage_bytes, + RandomAccessIteratorT first, + RandomAccessIteratorT last, + OpT op, + cudaStream_t stream = {} +) +``` + + +*Added in v2.4.0. First appears in CUDA Toolkit 12.5.* + +**Template parameters** + + +Is a model of Random Access Iterator whose value type is convertible to `op`'s argument type. + + + +Is a model of [Unary Function](https://en.cppreference.com/w/cpp/utility/functional/unary_function) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +The beginning of the sequence + + + +The end of the sequence + + + +Function object to apply to each element in the range + + + +CUDA stream to launch kernels within. Default stream is `0`. + + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceFor::ForEach( + RandomAccessIteratorT first, + RandomAccessIteratorT last, + OpT op, + cudaStream_t stream = {} +) +``` + + +*Added in v2.4.0. First appears in CUDA Toolkit 12.5.* + +**Template parameters** + + +Is a model of Random Access Iterator whose value type is convertible to `op`'s argument type. + + + +Is a model of [Unary Function](https://en.cppreference.com/w/cpp/utility/functional/unary_function) + + +**Parameters** + + +The beginning of the sequence + + + +The end of the sequence + + + +Function object to apply to each element in the range + + + +CUDA stream to launch kernels within. Default stream is `0`. + + + + + +### ForEachCopyN inline static + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceFor::ForEachCopyN( + void *d_temp_storage, + size_t &temp_storage_bytes, + RandomAccessIteratorT first, + NumItemsT num_items, + OpT op, + cudaStream_t stream = {} +) +``` + + +*Added in v2.4.0. First appears in CUDA Toolkit 12.5.* + +**Template parameters** + + +Is a model of Random Access Iterator whose value type is convertible to `op`'s argument type. + + + +Is an integral type representing the number of elements to iterate over + + + +Is a model of [Unary Function](https://en.cppreference.com/w/cpp/utility/functional/unary_function) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +The beginning of the sequence + + + +Number of elements to iterate over + + + +Function object to apply to a copy of each element in the range + + + +CUDA stream to launch kernels within. Default stream is `0`. + + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceFor::ForEachCopyN( + RandomAccessIteratorT first, + NumItemsT num_items, + OpT op, + cudaStream_t stream = {} +) +``` + + +*Added in v2.4.0. First appears in CUDA Toolkit 12.5.* + +**Template parameters** + + +Is a model of Random Access Iterator whose value type is convertible to `op`'s argument type. + + + +Is an integral type representing the number of elements to iterate over + + + +Is a model of [Unary Function](https://en.cppreference.com/w/cpp/utility/functional/unary_function) + + +**Parameters** + + +The beginning of the sequence + + + +Number of elements to iterate over + + + +Function object to apply to a copy of each element in the range + + + +CUDA stream to launch kernels within. Default stream is `0`. + + + + + +### ForEachCopy inline static + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceFor::ForEachCopy( + void *d_temp_storage, + size_t &temp_storage_bytes, + RandomAccessIteratorT first, + RandomAccessIteratorT last, + OpT op, + cudaStream_t stream = {} +) +``` + + +*Added in v2.4.0. First appears in CUDA Toolkit 12.5.* + +**Template parameters** + + +Is a model of Random Access Iterator whose value type is convertible to `op`'s argument type. + + + +Is a model of [Unary Function](https://en.cppreference.com/w/cpp/utility/functional/unary_function) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +The beginning of the sequence + + + +The end of the sequence + + + +Function object to apply to a copy of each element in the range + + + +CUDA stream to launch kernels within. Default stream is `0`. + + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceFor::ForEachCopy( + RandomAccessIteratorT first, + RandomAccessIteratorT last, + OpT op, + cudaStream_t stream = {} +) +``` + + +*Added in v2.4.0. First appears in CUDA Toolkit 12.5.* + +**Template parameters** + + +Is a model of Random Access Iterator whose value type is convertible to `op`'s argument type. + + + +Is a model of [Unary Function](https://en.cppreference.com/w/cpp/utility/functional/unary_function) + + +**Parameters** + + +The beginning of the sequence + + + +The end of the sequence + + + +Function object to apply to a copy of each element in the range + + + +CUDA stream to launch kernels within. Default stream is `0`. + + + + + +### ForEachInExtents inline static + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceFor::ForEachInExtents( + void *d_temp_storage, + size_t &temp_storage_bytes, + const ::cuda::std::extents &extents, + OpType op, + cudaStream_t stream = {} +) +``` + + +*Added in v2.4.0. First appears in CUDA Toolkit 12.5.* + +**Returns:** cudaError_t error status + +**Template parameters** + + +Is an integral type that represents the extent index space (automatically deduced) + + + +Are the extent sizes for each rank index (automatically deduced) + + + +Is a function object with arity equal to the number of extents + 1 for the linear index (iteration) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Extents object that represents a multi-dimensional index space + + + +Function object to apply to each linear index (iteration) and multi-dimensional coordinates + + + +CUDA stream to launch kernels within. Default stream is `NULL` + + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceFor::ForEachInExtents( + const ::cuda::std::extents &extents, + OpType op, + cudaStream_t stream = {} +) +``` + + +*Added in v2.4.0. First appears in CUDA Toolkit 12.5.* + +**Returns:** cudaError_t error status + +**Template parameters** + + +Is an integral type that represents the extent index space (automatically deduced) + + + +Are the extent sizes for each rank index (automatically deduced) + + + +Is a function object with arity equal to the number of extents + 1 for the linear index (iteration) + + +**Parameters** + + +Extents object that represents a multi-dimensional index space + + + +Function object to apply to each linear index (iteration) and multi-dimensional coordinates + + + +CUDA stream to launch kernels within. Default stream is `NULL` + + + + + +### ForEachInLayout inline static nodiscard + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceFor::ForEachInLayout( + const LayoutMapping &layout_mapping, + OpType op, + cudaStream_t stream = {} +) +``` + + +*Added in v2.4.0. First appears in CUDA Toolkit 12.5.* + +**Returns:** cudaError_t error status + +**Template parameters** + + +**[inferred]** A function object with arity equal to the number of extents + 1 for the linear index (iteration). The first parameter is the linear index, followed by one parameter for each dimension coordinate. + + +**Parameters** + + +Function object to apply to each linear index (iteration) and multi-dimensional coordinates. Called as `op(linear_index, coord_0, coord_1, ..., coord_n)` + + + +CUDA stream to launch kernels within. Default stream is `nullptr` + diff --git a/fern/cudapages/cub/cub/cub/DeviceHistogram.mdx b/fern/cudapages/cub/cub/cub/DeviceHistogram.mdx new file mode 100644 index 0000000..60fde64 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DeviceHistogram.mdx @@ -0,0 +1,1313 @@ +--- +title: cub::DeviceHistogram +description: "" +--- + +DeviceHistogram provides device-wide parallel operations for constructing histogram(s) from a sequence of samples data residing within device-accessible memory. + +--- + +## Evenly-segmented bin ranges + +### HistogramEven inline static + + + + +Computes an intensity histogram from a sequence of data samples using equal-width bins. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceHistogram::HistogramEven( + void *d_temp_storage, + size_t &temp_storage_bytes, + SampleIteratorT d_samples, + CounterT *d_histogram, + int num_levels, + LevelT lower_level, + LevelT upper_level, + OffsetT num_samples, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The number of histogram bins is (`num_levels - 1`) +All bins comprise the same width of sample values: `(upper_level - lower_level) / (num_levels - 1)`. +If the common type of `SampleT` and `LevelT` is of integral type, the bin for a sample is computed as `(sample - lower_level) * (num_levels - 1) / (upper_level - lower_level)`, round down to the nearest whole number. To protect against potential overflows, if the product `(upper_level - lower_level) * (num_levels - 1)` exceeds the number representable by an `uint64_t`, the cuda error `cudaErrorInvalidValue` is returned. If the common type is 128 bits wide, bin computation will use 128-bit arithmetic and `cudaErrorInvalidValue` will only be returned if bin computation would overflow for 128-bit arithmetic. +The ranges `[d_samples, d_samples + num_samples)` and `[d_histogram, d_histogram + num_levels - 1)` shall not overlap in any way. +`cuda::std::common_type` must be valid, and both LevelT and SampleT must be valid arithmetic types. The common type must be convertible to `int` and trivially copyable. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input samples (may be a simple pointer type) + + + +**[inferred]** Integer type for histogram bin counters + + + +**[inferred]** Type for specifying boundaries (levels) + + + +**[inferred]** Signed integer type for sequence offsets, list lengths, pointer differences, etc. (Consider using 32-bit values as offsets/lengths/etc. For example, `int` will typically yield better performance than `size_t` in 64-bit memory mode.) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +The pointer to the input sequence of data samples. + + + +The pointer to the histogram counter output array of length `num_levels - 1`. + + + +The number of boundaries (levels) for delineating histogram samples. Implies that the number of bins is `num_levels - 1`. + + + +The lower sample value bound (inclusive) for the lowest histogram bin. + + + +The upper sample value bound (exclusive) for the highest histogram bin. + + + +The number of input samples (i.e., the length of `d_samples`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the computation of a six-bin histogram from a sequence of float samples + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers for +// input samples and output histogram +int num_samples; // e.g., 10 +float* d_samples; // e.g., [2.2, 6.1, 7.1, 2.9, 3.5, 0.3, 2.9, 2.1, 6.1, 999.5] +int* d_histogram; // e.g., [ -, -, -, -, -, -] +int num_levels; // e.g., 7 (seven level boundaries for six bins) +float lower_level; // e.g., 0.0 (lower sample value boundary of lowest bin) +float upper_level; // e.g., 12.0 (upper sample value boundary of upper bin) +... + +// Determine temporary device storage requirements +void* d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceHistogram::HistogramEven( + d_temp_storage, temp_storage_bytes, + d_samples, d_histogram, num_levels, + lower_level, upper_level, num_samples); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Compute histograms +cub::DeviceHistogram::HistogramEven( + d_temp_storage, temp_storage_bytes, + d_samples, d_histogram, num_levels, + lower_level, upper_level, num_samples); + +// d_histogram <-- [1, 5, 0, 3, 0, 0]; +``` + + + + +Computes an intensity histogram from a sequence of data samples using equal-width bins. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceHistogram::HistogramEven( + void *d_temp_storage, + size_t &temp_storage_bytes, + SampleIteratorT d_samples, + CounterT *d_histogram, + int num_levels, + LevelT lower_level, + LevelT upper_level, + OffsetT num_row_samples, + OffsetT num_rows, + size_t row_stride_bytes, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +A two-dimensional *region of interest* within `d_samples` can be specified using the `num_row_samples`, `num_rows`, and `row_stride_bytes` parameters. +The row stride must be a whole multiple of the sample data type size, i.e., `(row_stride_bytes % sizeof(SampleT)) == 0`. +The number of histogram bins is (`num_levels - 1`) +All bins comprise the same width of sample values: `(upper_level - lower_level) / (num_levels - 1)` +If the common type of `SampleT` and `LevelT` is of integral type, the bin for a sample is computed as `(sample - lower_level) * (num_levels - 1) / (upper_level - lower_level)`, round down to the nearest whole number. To protect against potential overflows, if the product `(upper_level - lower_level) * (num_levels - 1)` exceeds the number representable by an `uint64_t`, the cuda error `cudaErrorInvalidValue` is returned. If the common type is 128 bits wide, bin computation will use 128-bit arithmetic and `cudaErrorInvalidValue` will only be returned if bin computation would overflow for 128-bit arithmetic. +For a given row `r` in `[0, num_rows)`, let `row_begin = d_samples + r * row_stride_bytes / sizeof(SampleT)` and `row_end = row_begin + num_row_samples`. The ranges `[row_begin, row_end)` and `[d_histogram, d_histogram + num_levels - 1)` shall not overlap in any way. +`cuda::std::common_type` must be valid, and both LevelT and SampleT must be valid arithmetic types. The common type must be convertible to `int` and trivially copyable. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input samples. (may be a simple pointer type) + + + +**[inferred]** Integer type for histogram bin counters + + + +**[inferred]** Type for specifying boundaries (levels) + + + +**[inferred]** Signed integer type for sequence offsets, list lengths, pointer differences, etc. (Consider using 32-bit values as offsets/lengths/etc. For example, `int` will typically yield better performance than `size_t` in 64-bit memory mode.) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +The pointer to the input sequence of data samples. + + + +The pointer to the histogram counter output array of length `num_levels - 1`. + + + +The number of boundaries (levels) for delineating histogram samples. Implies that the number of bins is `num_levels - 1`. + + + +The lower sample value bound (inclusive) for the lowest histogram bin. + + + +The upper sample value bound (exclusive) for the highest histogram bin. + + + +The number of data samples per row in the region of interest + + + +The number of rows in the region of interest + + + +The number of bytes between starts of consecutive rows in the region of interest + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the computation of a six-bin histogram from a 2x5 region of interest within a flattened 2x7 array of float samples. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers for +// input samples and output histogram +int num_row_samples; // e.g., 5 +int num_rows; // e.g., 2; +size_t row_stride_bytes; // e.g., 7 * sizeof(float) +float* d_samples; // e.g., [2.2, 6.1, 7.1, 2.9, 3.5, -, -, + // 0.3, 2.9, 2.1, 6.1, 999.5, -, -] +int* d_histogram; // e.g., [ -, -, -, -, -, -] +int num_levels; // e.g., 7 (seven level boundaries for six bins) +float lower_level; // e.g., 0.0 (lower sample value boundary of lowest bin) +float upper_level; // e.g., 12.0 (upper sample value boundary of upper bin) +... + +// Determine temporary device storage requirements +void* d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceHistogram::HistogramEven( + d_temp_storage, temp_storage_bytes, + d_samples, d_histogram, num_levels, lower_level, upper_level, + num_row_samples, num_rows, row_stride_bytes); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Compute histograms +cub::DeviceHistogram::HistogramEven( + d_temp_storage, temp_storage_bytes, d_samples, d_histogram, + d_samples, d_histogram, num_levels, lower_level, upper_level, + num_row_samples, num_rows, row_stride_bytes); + +// d_histogram <-- [1, 5, 0, 3, 0, 0]; +``` + + + + +### MultiHistogramEven inline static + + + + +Computes per-channel intensity histograms from a sequence of multi-channel "pixel" data samples using equal-width bins. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceHistogram::MultiHistogramEven( + void *d_temp_storage, + size_t &temp_storage_bytes, + SampleIteratorT d_samples, + ::cuda::std::array d_histogram, + ::cuda::std::array num_levels, + ::cuda::std::array lower_level, + ::cuda::std::array upper_level, + OffsetT num_pixels, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The input is a sequence of *pixel* structures, where each pixel comprises a record of `NUM_CHANNELS` consecutive data samples (e.g., an *RGBA* pixel). +`NUM_CHANNELS` can be up to 4. +Of the `NUM_CHANNELS` specified, the function will only compute histograms for the first `NUM_ACTIVE_CHANNELS` (e.g., only *RGB* histograms from *RGBA* pixel samples). +The number of histogram bins for channeli is `num_levels[i] - 1`. +For channeli, the range of values for all histogram bins have the same width: `(upper_level[i] - lower_level[i]) / (num_levels[i] - 1)` +If the common type of sample and level is of integral type, the bin for a sample is computed as `(sample - lower_level[i]) * (num_levels - 1) / (upper_level[i] - lower_level[i])`, round down to the nearest whole number. To protect against potential overflows, if, for any channel `i`, the product `(upper_level[i] - lower_level[i]) * (num_levels[i] - 1)` exceeds the number representable by an `uint64_t`, the cuda error `cudaErrorInvalidValue` is returned. If the common type is 128 bits wide, bin computation will use 128-bit arithmetic and `cudaErrorInvalidValue` will only be returned if bin computation would overflow for 128-bit arithmetic. +For a given channel `c` in `[0, NUM_ACTIVE_CHANNELS)`, the ranges `[d_samples, d_samples + NUM_CHANNELS * num_pixels)` and `[d_histogram[c], d_histogram[c] + num_levels[c] - 1)` shall not overlap in any way. +`cuda::std::common_type` must be valid, and both LevelT and SampleT must be valid arithmetic types. The common type must be convertible to `int` and trivially copyable. +@devicestorage + + +**Template parameters** + + +Number of channels interleaved in the input data (may be greater than the number of channels being actively histogrammed) + + + +**[inferred]** Number of channels actively being histogrammed + + + +**[inferred]** Random-access input iterator type for reading input samples. (may be a simple pointer type) + + + +**[inferred]** Integer type for histogram bin counters + + + +**[inferred]** Type for specifying boundaries (levels) + + + +**[inferred]** Signed integer type for sequence offsets, list lengths, pointer differences, etc. (Consider using 32-bit values as offsets/lengths/etc. For example, `int` will typically yield better performance than `size_t` in 64-bit memory mode.) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +The pointer to the multi-channel input sequence of data samples. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four *RGBA* 8-bit samples). + + + +Embed:rst:leading-asterisk +//! The pointers to the histogram counter output arrays, one for each active +//! channel. For channel\ :sub:`i`, the allocation length of +//! ``d_histogram[i]`` should be `num_levels[i] - 1``. +//! + + + +Embed:rst:leading-asterisk +//! The number of boundaries (levels) for delineating histogram samples in each active channel. +//! Implies that the number of bins for channel\ :sub:`i` is ``num_levels[i] - 1``. +//! + + + +The lower sample value bound (inclusive) for the lowest histogram bin in each active channel. + + + +The upper sample value bound (exclusive) for the highest histogram bin in each active channel. + + + +The number of multi-channel pixels (i.e., the length of `d_samples / NUM_CHANNELS`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the computation of three 256-bin *RGB* histograms from a quad-channel sequence of *RGBA* pixels (8 bits per channel per pixel) + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers for +// input samples and output histograms +int num_pixels; // e.g., 5 +unsigned char* d_samples; // e.g., [(2, 6, 7, 5), (3, 0, 2, 1), (7, 0, 6, 2), + // (0, 6, 7, 5), (3, 0, 2, 6)] +int* d_histogram[3]; // e.g., three device pointers to three device buffers, + // each allocated with 256 integer counters +int num_levels[3]; // e.g., {257, 257, 257}; +unsigned int lower_level[3]; // e.g., {0, 0, 0}; +unsigned int upper_level[3]; // e.g., {256, 256, 256}; +... + +// Determine temporary device storage requirements +void* d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceHistogram::MultiHistogramEven<4, 3>( + d_temp_storage, temp_storage_bytes, + d_samples, d_histogram, num_levels, + lower_level, upper_level, num_pixels); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Compute histograms +cub::DeviceHistogram::MultiHistogramEven<4, 3>( + d_temp_storage, temp_storage_bytes, + d_samples, d_histogram, num_levels, + lower_level, upper_level, num_pixels); + +// d_histogram <-- [ [1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0, ..., 0], +// [0, 3, 0, 0, 0, 0, 2, 0, 0, 0, 0, ..., 0], +// [0, 0, 2, 0, 0, 0, 1, 2, 0, 0, 0, ..., 0] ] +``` + + + + +Deprecate [Since 3.0]. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceHistogram::MultiHistogramEven( + void *d_temp_storage, + size_t &temp_storage_bytes, + SampleIteratorT d_samples, + CounterT *d_histogram[NUM_ACTIVE_CHANNELS], + const int num_levels[NUM_ACTIVE_CHANNELS], + const LevelT lower_level[NUM_ACTIVE_CHANNELS], + const LevelT upper_level[NUM_ACTIVE_CHANNELS], + OffsetT num_pixels, + cudaStream_t stream = 0 +) +``` + + + + + +Computes per-channel intensity histograms from a sequence of multi-channel "pixel" data samples using equal-width bins. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceHistogram::MultiHistogramEven( + void *d_temp_storage, + size_t &temp_storage_bytes, + SampleIteratorT d_samples, + ::cuda::std::array d_histogram, + ::cuda::std::array num_levels, + ::cuda::std::array lower_level, + ::cuda::std::array upper_level, + OffsetT num_row_pixels, + OffsetT num_rows, + size_t row_stride_bytes, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The input is a sequence of *pixel* structures, where each pixel comprises a record of `NUM_CHANNELS` consecutive data samples (e.g., an *RGBA* pixel). +`NUM_CHANNELS` can be up to 4. +Of the `NUM_CHANNELS` specified, the function will only compute histograms for the first `NUM_ACTIVE_CHANNELS` (e.g., only *RGB* histograms from *RGBA* pixel samples). +A two-dimensional *region of interest* within `d_samples` can be specified using the `num_row_samples`, `num_rows`, and `row_stride_bytes` parameters. +The row stride must be a whole multiple of the sample data type size, i.e., `(row_stride_bytes % sizeof(SampleT)) == 0`. +The number of histogram bins for channeli is `num_levels[i] - 1`. +For channeli, the range of values for all histogram bins have the same width: `(upper_level[i] - lower_level[i]) / (num_levels[i] - 1)` +If the common type of sample and level is of integral type, the bin for a sample is computed as `(sample - lower_level[i]) * (num_levels - 1) / (upper_level[i] - lower_level[i])`, round down to the nearest whole number. To protect against potential overflows, if, for any channel `i`, the product `(upper_level[i] - lower_level[i]) * (num_levels[i] - 1)` exceeds the number representable by an `uint64_t`, the cuda error `cudaErrorInvalidValue` is returned. If the common type is 128 bits wide, bin computation will use 128-bit arithmetic and `cudaErrorInvalidValue` will only be returned if bin computation would overflow for 128-bit arithmetic. +For a given row `r` in `[0, num_rows)`, and sample `s` in `[0, num_row_pixels)`, let `row_begin = d_samples + r * row_stride_bytes / sizeof(SampleT)`, `sample_begin = row_begin + s * NUM_CHANNELS`, and `sample_end = sample_begin + NUM_ACTIVE_CHANNELS`. For a given channel `c` in `[0, NUM_ACTIVE_CHANNELS)`, the ranges `[sample_begin, sample_end)` and `[d_histogram[c], d_histogram[c] + num_levels[c] - 1)` shall not overlap in any way. +`cuda::std::common_type` must be valid, and both LevelT and SampleT must be valid arithmetic types. The common type must be convertible to `int` and trivially copyable. +@devicestorage + + +**Template parameters** + + +Number of channels interleaved in the input data (may be greater than the number of channels being actively histogrammed) + + + +**[inferred]** Number of channels actively being histogrammed + + + +**[inferred]** Random-access input iterator type for reading input samples. (may be a simple pointer type) + + + +**[inferred]** Integer type for histogram bin counters + + + +**[inferred]** Type for specifying boundaries (levels) + + + +**[inferred]** Signed integer type for sequence offsets, list lengths, pointer differences, etc. (Consider using 32-bit values as offsets/lengths/etc. For example, `int` will typically yield better performance than `size_t` in 64-bit memory mode.) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +The pointer to the multi-channel input sequence of data samples. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four *RGBA* 8-bit samples). + + + +Embed:rst:leading-asterisk +//! The pointers to the histogram counter output arrays, one for each +//! active channel. For channel\ :sub:`i`, the allocation length +//! of ``d_histogram[i]`` should be ``num_levels[i] - 1``. +//! + + + +Embed:rst:leading-asterisk +//! The number of boundaries (levels) for delineating histogram samples in each active channel. +//! Implies that the number of bins for channel\ :sub:`i` is ``num_levels[i] - 1``. +//! + + + +The lower sample value bound (inclusive) for the lowest histogram bin in each active channel. + + + +The upper sample value bound (exclusive) for the highest histogram bin in each active channel. + + + +The number of multi-channel pixels per row in the region of interest + + + +The number of rows in the region of interest + + + +The number of bytes between starts of consecutive rows in the region of interest + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the computation of three 256-bin *RGB* histograms from a 2x3 region of interest of within a flattened 2x4 array of quad-channel *RGBA* pixels (8 bits per channel per pixel). + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers for input +// samples and output histograms +int num_row_pixels; // e.g., 3 +int num_rows; // e.g., 2 +size_t row_stride_bytes; // e.g., 4 * sizeof(unsigned char) * NUM_CHANNELS +unsigned char* d_samples; // e.g., [(2, 6, 7, 5), (3, 0, 2, 1), (7, 0, 6, 2), (-, -, -, -), + // (0, 6, 7, 5), (3, 0, 2, 6), (1, 1, 1, 1), (-, -, -, -)] +int* d_histogram[3]; // e.g., three device pointers to three device buffers, + // each allocated with 256 integer counters +int num_levels[3]; // e.g., {257, 257, 257}; +unsigned int lower_level[3]; // e.g., {0, 0, 0}; +unsigned int upper_level[3]; // e.g., {256, 256, 256}; +... + +// Determine temporary device storage requirements +void* d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceHistogram::MultiHistogramEven<4, 3>( + d_temp_storage, temp_storage_bytes, + d_samples, d_histogram, num_levels, lower_level, upper_level, + num_row_pixels, num_rows, row_stride_bytes); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Compute histograms +cub::DeviceHistogram::MultiHistogramEven<4, 3>( + d_temp_storage, temp_storage_bytes, + d_samples, d_histogram, num_levels, lower_level, upper_level, + num_row_pixels, num_rows, row_stride_bytes); + +// d_histogram <-- [ [1, 1, 1, 2, 0, 0, 0, 1, 0, 0, 0, ..., 0], +// [0, 4, 0, 0, 0, 0, 2, 0, 0, 0, 0, ..., 0], +// [0, 1, 2, 0, 0, 0, 1, 2, 0, 0, 0, ..., 0] ] +``` + + + + +Deprecate [Since 3.0]. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceHistogram::MultiHistogramEven( + void *d_temp_storage, + size_t &temp_storage_bytes, + SampleIteratorT d_samples, + CounterT *d_histogram[NUM_ACTIVE_CHANNELS], + const int num_levels[NUM_ACTIVE_CHANNELS], + const LevelT lower_level[NUM_ACTIVE_CHANNELS], + const LevelT upper_level[NUM_ACTIVE_CHANNELS], + OffsetT num_row_pixels, + OffsetT num_rows, + size_t row_stride_bytes, + cudaStream_t stream = 0 +) +``` + + + + + +### to_array inline static + + +```cpp showLineNumbers={false} +template +static auto cub::DeviceHistogram::to_array( + T *ptr +) +``` + + +--- + +## Custom bin ranges + +### HistogramRange inline static + + + + +Computes an intensity histogram from a sequence of data samples using the specified bin boundary levels. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceHistogram::HistogramRange( + void *d_temp_storage, + size_t &temp_storage_bytes, + SampleIteratorT d_samples, + CounterT *d_histogram, + int num_levels, + const LevelT *d_levels, + OffsetT num_samples, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The number of histogram bins is (`num_levels - 1`) +The value range for bini is `[level[i], level[i+1])` +The range `[d_histogram, d_histogram + num_levels - 1)` shall not overlap `[d_samples, d_samples + num_samples)` nor `[d_levels, d_levels + num_levels)` in any way. The ranges `[d_levels, d_levels + num_levels)` and `[d_samples, d_samples + num_samples)` may overlap. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input samples. (may be a simple pointer type) + + + +**[inferred]** Integer type for histogram bin counters + + + +**[inferred]** Type for specifying boundaries (levels) + + + +**[inferred]** Signed integer type for sequence offsets, list lengths, pointer differences, etc. (Consider using 32-bit values as offsets/lengths/etc. For example, `int` will typically yield better performance than `size_t` in 64-bit memory mode.) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +The pointer to the input sequence of data samples. + + + +The pointer to the histogram counter output array of length `num_levels - 1`. + + + +The number of boundaries (levels) for delineating histogram samples. Implies that the number of bins is `num_levels - 1`. + + + +The pointer to the array of boundaries (levels). Bin ranges are defined by consecutive boundary pairings: lower sample value boundaries are inclusive and upper sample value boundaries are exclusive. + + + +The number of data samples per row in the region of interest + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the computation of an six-bin histogram from a sequence of float samples + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers for input +// samples and output histogram +int num_samples; // e.g., 10 +float* d_samples; // e.g., [2.2, 6.0, 7.1, 2.9, 3.5, 0.3, 2.9, 2.0, 6.1, 999.5] +int* d_histogram; // e.g., [ -, -, -, -, -, -] +int num_levels // e.g., 7 (seven level boundaries for six bins) +float* d_levels; // e.g., [0.0, 2.0, 4.0, 6.0, 8.0, 12.0, 16.0] +... + +// Determine temporary device storage requirements +void* d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceHistogram::HistogramRange( + d_temp_storage, temp_storage_bytes, + d_samples, d_histogram, num_levels, d_levels, num_samples); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Compute histograms +cub::DeviceHistogram::HistogramRange( + d_temp_storage, temp_storage_bytes, + d_samples, d_histogram, num_levels, d_levels, num_samples); + +// d_histogram <-- [1, 5, 0, 3, 0, 0]; +``` + + + + +Computes an intensity histogram from a sequence of data samples using the specified bin boundary levels. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceHistogram::HistogramRange( + void *d_temp_storage, + size_t &temp_storage_bytes, + SampleIteratorT d_samples, + CounterT *d_histogram, + int num_levels, + const LevelT *d_levels, + OffsetT num_row_samples, + OffsetT num_rows, + size_t row_stride_bytes, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +A two-dimensional *region of interest* within `d_samples` can be specified using the `num_row_samples`, `num_rows`, and `row_stride_bytes` parameters. +The row stride must be a whole multiple of the sample data type size, i.e., `(row_stride_bytes % sizeof(SampleT)) == 0`. +The number of histogram bins is (`num_levels - 1`) +The value range for bini is `[level[i], level[i+1])` +For a given row `r` in `[0, num_rows)`, let `row_begin = d_samples + r * row_stride_bytes / sizeof(SampleT)` and `row_end = row_begin + num_row_samples`. The range `[d_histogram, d_histogram + num_levels - 1)` shall not overlap `[row_begin, row_end)` nor `[d_levels, d_levels + num_levels)`. The ranges `[d_levels, d_levels + num_levels)` and `[row_begin, row_end)` may overlap. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input samples. (may be a simple pointer type) + + + +**[inferred]** Integer type for histogram bin counters + + + +**[inferred]** Type for specifying boundaries (levels) + + + +**[inferred]** Signed integer type for sequence offsets, list lengths, pointer differences, etc. (Consider using 32-bit values as offsets/lengths/etc. For example, `int` will typically yield better performance than `size_t` in 64-bit memory mode.) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +The pointer to the input sequence of data samples. + + + +The pointer to the histogram counter output array of length `num_levels - 1`. + + + +The number of boundaries (levels) for delineating histogram samples. Implies that the number of bins is `num_levels - 1`. + + + +The pointer to the array of boundaries (levels). Bin ranges are defined by consecutive boundary pairings: lower sample value boundaries are inclusive and upper sample value boundaries are exclusive. + + + +The number of data samples per row in the region of interest + + + +The number of rows in the region of interest + + + +The number of bytes between starts of consecutive rows in the region of interest + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the computation of a six-bin histogram from a 2x5 region of interest within a flattened 2x7 array of float samples. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers for input samples and +// output histogram +int num_row_samples; // e.g., 5 +int num_rows; // e.g., 2; +int row_stride_bytes; // e.g., 7 * sizeof(float) +float* d_samples; // e.g., [2.2, 6.0, 7.1, 2.9, 3.5, -, -, + // 0.3, 2.9, 2.0, 6.1, 999.5, -, -] +int* d_histogram; // e.g., [ -, -, -, -, -, -] +int num_levels // e.g., 7 (seven level boundaries for six bins) +float *d_levels; // e.g., [0.0, 2.0, 4.0, 6.0, 8.0, 12.0, 16.0] +... + +// Determine temporary device storage requirements +void* d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceHistogram::HistogramRange( + d_temp_storage, temp_storage_bytes, + d_samples, d_histogram, num_levels, d_levels, + num_row_samples, num_rows, row_stride_bytes); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Compute histograms +cub::DeviceHistogram::HistogramRange( + d_temp_storage, temp_storage_bytes, + d_samples, d_histogram, num_levels, d_levels, + num_row_samples, num_rows, row_stride_bytes); + +// d_histogram <-- [1, 5, 0, 3, 0, 0]; +``` + + + + +### MultiHistogramRange inline static + + + + +Computes per-channel intensity histograms from a sequence of multi-channel "pixel" data samples using the specified bin boundary levels. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceHistogram::MultiHistogramRange( + void *d_temp_storage, + size_t &temp_storage_bytes, + SampleIteratorT d_samples, + ::cuda::std::array d_histogram, + ::cuda::std::array num_levels, + ::cuda::std::array d_levels, + OffsetT num_pixels, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The input is a sequence of *pixel* structures, where each pixel comprises a record of `NUM_CHANNELS` consecutive data samples (e.g., an *RGBA* pixel). +`NUM_CHANNELS` can be up to 4. +Of the `NUM_CHANNELS` specified, the function will only compute histograms for the first `NUM_ACTIVE_CHANNELS` (e.g., *RGB* histograms from *RGBA* pixel samples). +The number of histogram bins for channeli is `num_levels[i] - 1`. +For channeli, the range of values for all histogram bins have the same width: `(upper_level[i] - lower_level[i]) / (num_levels[i] - 1)` +For given channels `c1` and `c2` in `[0, NUM_ACTIVE_CHANNELS)`, the range `[d_histogram[c1], d_histogram[c1] + num_levels[c1] - 1)` shall not overlap `[d_samples, d_samples + NUM_CHANNELS * num_pixels)` nor `[d_levels[c2], d_levels[c2] + num_levels[c2])` in any way. The ranges `[d_levels[c2], d_levels[c2] + num_levels[c2])` and `[d_samples, d_samples + NUM_CHANNELS * num_pixels)` may overlap. +@devicestorage + + +**Template parameters** + + +Number of channels interleaved in the input data (may be greater than the number of channels being actively histogrammed) + + + +**[inferred]** Number of channels actively being histogrammed + + + +**[inferred]** Random-access input iterator type for reading input samples. (may be a simple pointer type) + + + +**[inferred]** Integer type for histogram bin counters + + + +**[inferred]** Type for specifying boundaries (levels) + + + +**[inferred]** Signed integer type for sequence offsets, list lengths, pointer differences, etc. (Consider using 32-bit values as offsets/lengths/etc. For example, `int` will typically yield better performance than `size_t` in 64-bit memory mode.) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +The pointer to the multi-channel input sequence of data samples. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four *RGBA* 8-bit samples). + + + +Embed:rst:leading-asterisk +//! The pointers to the histogram counter output arrays, one for each active +//! channel. For channel\ :sub:`i`, the allocation length of +//! ``d_histogram[i]`` should be ``num_levels[i] - 1``. +//! + + + +Embed:rst:leading-asterisk +//! The number of boundaries (levels) for delineating histogram samples in +//! each active channel. Implies that the number of bins for +//! channel\ :sub:`i` is ``num_levels[i] - 1``. +//! + + + +The pointers to the arrays of boundaries (levels), one for each active channel. Bin ranges are defined by consecutive boundary pairings: lower sample value boundaries are inclusive and upper sample value boundaries are exclusive. + + + +The number of multi-channel pixels (i.e., the length of `d_samples / NUM_CHANNELS`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the computation of three 4-bin *RGB* histograms from a quad-channel sequence of *RGBA* pixels (8 bits per channel per pixel) + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers for +// input samples and output histograms +int num_pixels; // e.g., 5 +unsigned char *d_samples; // e.g., [(2, 6, 7, 5),(3, 0, 2, 1),(7, 0, 6, 2), + // (0, 6, 7, 5),(3, 0, 2, 6)] +unsigned int *d_histogram[3]; // e.g., [[ -, -, -, -],[ -, -, -, -],[ -, -, -, -]]; +int num_levels[3]; // e.g., {5, 5, 5}; +unsigned int *d_levels[3]; // e.g., [ [0, 2, 4, 6, 8], + // [0, 2, 4, 6, 8], + // [0, 2, 4, 6, 8] ]; +... + +// Determine temporary device storage requirements +void* d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceHistogram::MultiHistogramRange<4, 3>( + d_temp_storage, temp_storage_bytes, + d_samples, d_histogram, num_levels, d_levels, num_pixels); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Compute histograms +cub::DeviceHistogram::MultiHistogramRange<4, 3>( + d_temp_storage, temp_storage_bytes, + d_samples, d_histogram, num_levels, d_levels, num_pixels); + +// d_histogram <-- [ [1, 3, 0, 1], +// [3, 0, 0, 2], +// [0, 2, 0, 3] ] +``` + + + + +Deprecate [Since 3.0]. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceHistogram::MultiHistogramRange( + void *d_temp_storage, + size_t &temp_storage_bytes, + SampleIteratorT d_samples, + CounterT *d_histogram[NUM_ACTIVE_CHANNELS], + const int num_levels[NUM_ACTIVE_CHANNELS], + const LevelT *const d_levels[NUM_ACTIVE_CHANNELS], + OffsetT num_pixels, + cudaStream_t stream = 0 +) +``` + + + + + +Computes per-channel intensity histograms from a sequence of multi-channel "pixel" data samples using the specified bin boundary levels. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceHistogram::MultiHistogramRange( + void *d_temp_storage, + size_t &temp_storage_bytes, + SampleIteratorT d_samples, + ::cuda::std::array d_histogram, + ::cuda::std::array num_levels, + ::cuda::std::array d_levels, + OffsetT num_row_pixels, + OffsetT num_rows, + size_t row_stride_bytes, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The input is a sequence of *pixel* structures, where each pixel comprises a record of `NUM_CHANNELS` consecutive data samples (e.g., an *RGBA* pixel). +`NUM_CHANNELS` can be up to 4. +Of the `NUM_CHANNELS` specified, the function will only compute histograms for the first `NUM_ACTIVE_CHANNELS` (e.g., *RGB* histograms from *RGBA* pixel samples). +A two-dimensional *region of interest* within `d_samples` can be specified using the `num_row_samples`, `num_rows`, and `row_stride_bytes` parameters. +The row stride must be a whole multiple of the sample data type size, i.e., `(row_stride_bytes % sizeof(SampleT)) == 0`. +The number of histogram bins for channeli is `num_levels[i] - 1`. +For channeli, the range of values for all histogram bins have the same width: `(upper_level[i] - lower_level[i]) / (num_levels[i] - 1)` +For a given row `r` in `[0, num_rows)`, and sample `s` in `[0, num_row_pixels)`, let `row_begin = d_samples + r * row_stride_bytes / sizeof(SampleT)`, `sample_begin = row_begin + s * NUM_CHANNELS`, and `sample_end = sample_begin + NUM_ACTIVE_CHANNELS`. For given channels `c1` and `c2` in `[0, NUM_ACTIVE_CHANNELS)`, the range `[d_histogram[c1], d_histogram[c1] + num_levels[c1] - 1)` shall not overlap `[sample_begin, sample_end)` nor `[d_levels[c2], d_levels[c2] + num_levels[c2])` in any way. The ranges `[d_levels[c2], d_levels[c2] + num_levels[c2])` and `[sample_begin, sample_end)` may overlap. +@devicestorage + + +**Template parameters** + + +Number of channels interleaved in the input data (may be greater than the number of channels being actively histogrammed) + + + +**[inferred]** Number of channels actively being histogrammed + + + +**[inferred]** Random-access input iterator type for reading input samples. (may be a simple pointer type) + + + +**[inferred]** Integer type for histogram bin counters + + + +**[inferred]** Type for specifying boundaries (levels) + + + +**[inferred]** Signed integer type for sequence offsets, list lengths, pointer differences, etc. (Consider using 32-bit values as offsets/lengths/etc. For example, `int` will typically yield better performance than `size_t` in 64-bit memory mode.) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +The pointer to the multi-channel input sequence of data samples. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four *RGBA* 8-bit samples). + + + +Embed:rst:leading-asterisk +//! The pointers to the histogram counter output arrays, one for each active +//! channel. For channel\ :sub:`i`, the allocation length of +//! ``d_histogram[i]`` should be ``num_levels[i] - 1``. +//! + + + +Embed:rst:leading-asterisk +//! The number of boundaries (levels) for delineating histogram samples in +//! each active channel. Implies that the number of bins for +//! channel\ :sub:`i` is ``num_levels[i] - 1``. +//! + + + +The pointers to the arrays of boundaries (levels), one for each active channel. Bin ranges are defined by consecutive boundary pairings: lower sample value boundaries are inclusive and upper sample value boundaries are exclusive. + + + +The number of multi-channel pixels per row in the region of interest + + + +The number of rows in the region of interest + + + +The number of bytes between starts of consecutive rows in the region of interest + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the computation of three 4-bin *RGB* histograms from a 2x3 region of interest of within a flattened 2x4 array of quad-channel *RGBA* pixels (8 bits per channel per pixel). + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers for input +// samples and output histograms +int num_row_pixels; // e.g., 3 +int num_rows; // e.g., 2 +size_t row_stride_bytes; // e.g., 4 * sizeof(unsigned char) * NUM_CHANNELS +unsigned char* d_samples; // e.g., [(2, 6, 7, 5),(3, 0, 2, 1),(1, 1, 1, 1),(-, -, -, -), + // (7, 0, 6, 2),(0, 6, 7, 5),(3, 0, 2, 6),(-, -, -, -)] +int* d_histogram[3]; // e.g., [[ -, -, -, -],[ -, -, -, -],[ -, -, -, -]]; +int num_levels[3]; // e.g., {5, 5, 5}; +unsigned int* d_levels[3]; // e.g., [ [0, 2, 4, 6, 8], + // [0, 2, 4, 6, 8], + // [0, 2, 4, 6, 8] ]; +... + +// Determine temporary device storage requirements +void* d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceHistogram::MultiHistogramRange<4, 3>( + d_temp_storage, temp_storage_bytes, + d_samples, d_histogram, num_levels, d_levels, + num_row_pixels, num_rows, row_stride_bytes); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Compute histograms +cub::DeviceHistogram::MultiHistogramRange<4, 3>( + d_temp_storage, temp_storage_bytes, + d_samples, d_histogram, num_levels, + d_levels, num_row_pixels, num_rows, row_stride_bytes); + +// d_histogram <-- [ [2, 3, 0, 1], +// [3, 0, 0, 2], +// [1, 2, 0, 3] ] +``` + + + + +Deprecate [Since 3.0]. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceHistogram::MultiHistogramRange( + void *d_temp_storage, + size_t &temp_storage_bytes, + SampleIteratorT d_samples, + CounterT *d_histogram[NUM_ACTIVE_CHANNELS], + const int num_levels[NUM_ACTIVE_CHANNELS], + const LevelT *const d_levels[NUM_ACTIVE_CHANNELS], + OffsetT num_row_pixels, + OffsetT num_rows, + size_t row_stride_bytes, + cudaStream_t stream = 0 +) +``` + + + + diff --git a/fern/cudapages/cub/cub/cub/DeviceMemcpy.mdx b/fern/cudapages/cub/cub/cub/DeviceMemcpy.mdx new file mode 100644 index 0000000..58db367 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DeviceMemcpy.mdx @@ -0,0 +1,146 @@ +--- +title: cub::DeviceMemcpy +description: "[cub::DeviceMemcpy](/library/api/cub::_device_memcpy) provides device-wide, parallel operations for copying data." +--- + +`cub::DeviceMemcpy` provides device-wide, parallel operations for copying data. + +--- + +## Static methods + +### Batched inline static + +Copies data from a batch of given source buffers to their corresponding destination buffer. + +.. note:: + +If any input buffer aliases memory from any output buffer the behavior is undefined. If any output buffer aliases memory of another output buffer the behavior is undefined. Input buffers can alias one another. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceMemcpy::Batched( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputBufferIt input_buffer_it, + OutputBufferIt output_buffer_it, + BufferSizeIteratorT buffer_sizes, + ::cuda::std::int64_t num_buffers, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +**[inferred]** Device-accessible random-access input iterator type providing the pointers to the source memory buffers + + + +**[inferred]** Device-accessible random-access input iterator type providing the pointers to the destination memory buffers + + + +**[inferred]** Device-accessible random-access input iterator type providing the number of bytes to be copied for each pair of buffers + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Device-accessible iterator providing the pointers to the source memory buffers + + + +Device-accessible iterator providing the pointers to the destination memory buffers + + + +Device-accessible iterator providing the number of bytes to be copied for each pair of buffers + + + +The total number of buffer pairs + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates usage of DeviceMemcpy::Batched for mutating strings withing a single string buffer. + +```cpp showLineNumbers={false} +struct GetPtrToStringItem +{ + __host__ __device__ __forceinline__ void *operator()(uint32_t index) + { + return &d_string_data_in[d_string_offsets[index]]; + } + char *d_string_data_in; + uint32_t *d_string_offsets; +}; + +struct GetStringItemSize +{ + __host__ __device__ __forceinline__ uint32_t operator()(uint32_t index) + { + return d_string_offsets[index + 1] - d_string_offsets[index]; + } + uint32_t *d_string_offsets; +}; + +uint32_t num_strings = 5; +char *d_string_data_in; // e.g., "TomatoesBananasApplesOrangesGrapes" +char *d_string_data_out; // e.g., " ... " +uint32_t *d_string_offsets_old; // e.g., [0, 8, 15, 21, 28, 34] +uint32_t *d_string_offsets_new; // e.g., [0, 6, 13, 19, 26, 34] +uint32_t *d_gather_index; // e.g., [2, 1, 4, 3, 0] + +// Initialize an iterator that returns d_gather_index[i] when the i-th item is dereferenced +auto gather_iterator = thrust::make_permutation_iterator(thrust::make_counting_iterator(0), +d_gather_index); + +// Returns pointers to the input buffer for each string +auto str_ptrs_in = thrust::make_transform_iterator(gather_iterator, + GetPtrToStringItem{d_string_data_in, +d_string_offsets_old}); + +// Returns the string size of the i-th string +auto str_sizes = thrust::make_transform_iterator(gather_iterator, +GetStringItemSize{d_string_offsets_old}); + +// Returns pointers to the output buffer for each string +auto str_ptrs_out = thrust::make_transform_iterator(thrust::make_counting_iterator(0), + GetPtrToStringItem{d_string_data_out, +d_string_offsets_new}); + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceMemcpy::Batched(d_temp_storage, temp_storage_bytes, str_ptrs_in, str_ptrs_out, +str_sizes, num_strings); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run batched copy algorithm (used to permute strings) +cub::DeviceMemcpy::Batched(d_temp_storage, temp_storage_bytes, str_ptrs_in, str_ptrs_out, +str_sizes, num_strings); + +// d_string_data_out <-- "ApplesBananasGrapesOrangesTomatoe" +``` diff --git a/fern/cudapages/cub/cub/cub/DeviceMerge.mdx b/fern/cudapages/cub/cub/cub/DeviceMerge.mdx new file mode 100644 index 0000000..496401d --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DeviceMerge.mdx @@ -0,0 +1,202 @@ +--- +title: cub::DeviceMerge +description: "[DeviceMerge](/library/api/cub::_device_merge) provides device-wide, parallel operations for merging two sorted sequences of values (called keys) or key-value pairs in device-accessible memory." +--- + +`DeviceMerge` provides device-wide, parallel operations for merging two sorted sequences of values (called keys) or key-value pairs in device-accessible memory. + +The sorting order is determined by a comparison functor (default: less-than), which has to establish a [strict weak ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order). + +--- + +## Static methods + +### MergeKeys inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceMerge::MergeKeys( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyIteratorIn1 keys_in1, + ::cuda::std::int64_t num_keys1, + KeyIteratorIn2 keys_in2, + ::cuda::std::int64_t num_keys2, + KeyIteratorOut keys_out, + CompareOp compare_op = {}, + cudaStream_t stream = nullptr +) +``` + + +*Added in v2.7.0. First appears in CUDA Toolkit 12.8.* + +**Template parameters** + + +**[deduced]** Random access iterator to the first sorted input sequence. Must have the same value type as KeyIteratorIn2. + + + +**[deduced]** Random access iterator to the second sorted input sequence. Must have the same value type as KeyIteratorIn1. + + + +**[deduced]** Random access iterator to the output sequence. + + + +**[deduced]** Binary predicate to compare the input iterator's value types. Must have a signature equivalent to `bool operator()(Key lhs, Key rhs)` and establish a [strict weak ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order). + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation. + + + +Iterator to the beginning of the first sorted input sequence. + + + +Number of keys in the first input sequence. + + + +Iterator to the beginning of the second sorted input sequence. + + + +Number of keys in the second input sequence. + + + +Iterator to the beginning of the output sequence. + + + +Comparison function object, returning true if the first argument is ordered before the second. Must establish a [strict weak ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order). + + + +**[optional]** CUDA stream to launch kernels into. Default is stream0. + + +### MergePairs inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceMerge::MergePairs( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyIteratorIn1 keys_in1, + ValueIteratorIn1 values_in1, + ::cuda::std::int64_t num_pairs1, + KeyIteratorIn2 keys_in2, + ValueIteratorIn2 values_in2, + ::cuda::std::int64_t num_pairs2, + KeyIteratorOut keys_out, + ValueIteratorOut values_out, + CompareOp compare_op = {}, + cudaStream_t stream = nullptr +) +``` + + +*Added in v2.7.0. First appears in CUDA Toolkit 12.8.* + +**Template parameters** + + +**[deduced]** Random access iterator to the keys of the first sorted input sequence. Must have the same value type as KeyIteratorIn2. + + + +**[deduced]** Random access iterator to the values of the first sorted input sequence. Must have the same value type as ValueIteratorIn2. + + + +**[deduced]** Random access iterator to the second sorted input sequence. Must have the same value type as KeyIteratorIn1. + + + +**[deduced]** Random access iterator to the values of the second sorted input sequence. Must have the same value type as ValueIteratorIn1. + + + +**[deduced]** Random access iterator to the keys of the output sequence. + + + +**[deduced]** Random access iterator to the values of the output sequence. + + + +**[deduced]** Binary predicate to compare the key input iterator's value types. Must have a signature equivalent to `bool operator()(Key lhs, Key rhs)` and establish a [strict weak ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order). + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation. + + + +Iterator to the beginning of the keys of the first sorted input sequence. + + + +Iterator to the beginning of the values of the first sorted input sequence. + + + +Number of key-value pairs in the first input sequence. + + + +Iterator to the beginning of the keys of the second sorted input sequence. + + + +Iterator to the beginning of the values of the second sorted input sequence. + + + +Number of key-value pairs in the second input sequence. + + + +Iterator to the beginning of the keys of the output sequence. + + + +Iterator to the beginning of the values of the output sequence. + + + +Comparison function object, returning true if the first argument is ordered before the second. Must establish a [strict weak ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order). + + + +**[optional]** CUDA stream to launch kernels into. Default is stream0. + diff --git a/fern/cudapages/cub/cub/cub/DeviceMergeSort.mdx b/fern/cudapages/cub/cub/cub/DeviceMergeSort.mdx new file mode 100644 index 0000000..af35af2 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DeviceMergeSort.mdx @@ -0,0 +1,633 @@ +--- +title: cub::DeviceMergeSort +description: "[DeviceMergeSort](/library/api/cub::_device_merge_sort) provides device-wide, parallel operations for computing a merge sort across a sequence of data items residing within device-accessible memory." +--- + +`DeviceMergeSort` provides device-wide, parallel operations for computing a merge sort across a sequence of data items residing within device-accessible memory. + +**Overview** + +- `DeviceMergeSort` arranges items into ascending order using a comparison functor with less-than semantics. Merge sort can handle arbitrary types (as long as a value of these types is a model of [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable)) and comparison functors, but is slower than [DeviceRadixSort](/library/api/cub::_device_radix_sort) when sorting arithmetic types into ascending/descending order. +- Another difference from RadixSort is the fact that `DeviceMergeSort` can handle arbitrary random-access iterators, as shown below. + +**A Simple Example** + + +The code snippet below illustrates a thrust reverse iterator usage. + +```cpp showLineNumbers={false} +#include // or equivalently + +struct CustomLess +{ + template + __device__ bool operator()(const DataType &lhs, const DataType &rhs) + { + return lhs < rhs; + } +}; + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +thrust::device_vector d_keys(num_items); +thrust::device_vector d_values(num_items); +// ... + +// Initialize iterator +using KeyIterator = typename thrust::device_vector::iterator; +cuda::std::reverse_iterator reverse_iter(d_keys.end()); + +// Determine temporary device storage requirements +size_t temp_storage_bytes = 0; +cub::DeviceMergeSort::SortPairs( + nullptr, + temp_storage_bytes, + reverse_iter, + thrust::raw_pointer_cast(d_values.data()), + num_items, + CustomLess()); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceMergeSort::SortPairs( + d_temp_storage, + temp_storage_bytes, + reverse_iter, + thrust::raw_pointer_cast(d_values.data()), + num_items, + CustomLess()); +``` + +--- + +## Methods + +### GetName inline static constexpr + + +```cpp showLineNumbers={false} +static constexpr const char * cub::DeviceMergeSort::GetName() +``` + + +### SortPairsNoNVTX inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceMergeSort::SortPairsNoNVTX( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyIteratorT d_keys, + ValueIteratorT d_items, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream = 0 +) +``` + + +### SortKeysNoNVTX inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceMergeSort::SortKeysNoNVTX( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyIteratorT d_keys, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream = 0 +) +``` + + +### SortKeysCopyNoNVTX inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceMergeSort::SortKeysCopyNoNVTX( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyInputIteratorT d_input_keys, + KeyIteratorT d_output_keys, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream = 0 +) +``` + + +--- + +## Static methods + +### SortPairs inline static + +Sorts items using a merge sorting method. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceMergeSort::SortPairs( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyIteratorT d_keys, + ValueIteratorT d_items, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +Is a model of [Random Access Iterator](https://en.cppreference.com/w/cpp/iterator/random_access_iterator). `KeyIteratorT` is mutable, and its `value_type` is a model of [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable). This `value_type`'s ordering relation is a *strict weak ordering* as defined in the [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable) requirements. + + + +Is a model of [Random Access Iterator](https://en.cppreference.com/w/cpp/iterator/random_access_iterator), and `ValueIteratorT` is mutable. + + + +Is an integer type for global offsets. + + + +Is a type of callable object with the signature `bool operator()(KeyT lhs, KeyT rhs)` that models the [Strict Weak Ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order) concept. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of unsorted input keys + + + +Pointer to the input sequence of unsorted input values + + + +Number of items to sort + + + +Comparison function object which returns true if the first argument is ordered before the second + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + +### SortPairsCopy inline static + +Sorts items using a merge sorting method. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceMergeSort::SortPairsCopy( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyInputIteratorT d_input_keys, + ValueInputIteratorT d_input_items, + KeyIteratorT d_output_keys, + ValueIteratorT d_output_items, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +Is a model of [Random Access Iterator](https://en.cppreference.com/w/cpp/iterator/random_access_iterator). Its `value_type` is a model of [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable). This `value_type`'s ordering relation is a *strict weak ordering* as defined in the [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable) requirements. + + + +Is a model of [Random Access Iterator](https://en.cppreference.com/w/cpp/iterator/random_access_iterator). + + + +Is a model of [Random Access Iterator](https://en.cppreference.com/w/cpp/iterator/random_access_iterator). `KeyIteratorT` is mutable, and its `value_type` is a model of [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable). This `value_type`'s ordering relation is a *strict weak ordering* as defined in the [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable) requirements. + + + +Is a model of [Random Access Iterator](https://en.cppreference.com/w/cpp/iterator/random_access_iterator), and `ValueIteratorT` is mutable. + + + +Is an integer type for global offsets. + + + +Is a type of callable object with the signature `bool operator()(KeyT lhs, KeyT rhs)` that models the [Strict Weak Ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order) concept. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of unsorted input keys + + + +Pointer to the input sequence of unsorted input values + + + +Pointer to the output sequence of sorted input keys + + + +Pointer to the output sequence of sorted input values + + + +Number of items to sort + + + +Comparison function object which returns `true` if the first argument is ordered before the second + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + +### SortKeys inline static + +Sorts items using a merge sorting method. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceMergeSort::SortKeys( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyIteratorT d_keys, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +Is a model of [Random Access Iterator](https://en.cppreference.com/w/cpp/iterator/random_access_iterator). `KeyIteratorT` is mutable, and its `value_type` is a model of [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable). This `value_type`'s ordering relation is a *strict weak ordering* as defined in the [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable) requirements. + + + +Is an integer type for global offsets. + + + +Is a type of callable object with the signature `bool operator()(KeyT lhs, KeyT rhs)` that models the [Strict Weak Ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order) concept. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of unsorted input keys + + + +Number of items to sort + + + +Comparison function object which returns true if the first argument is ordered before the second + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + +### SortKeysCopy inline static + +Sorts items using a merge sorting method. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceMergeSort::SortKeysCopy( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyInputIteratorT d_input_keys, + KeyIteratorT d_output_keys, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +Is a model of [Random Access Iterator](https://en.cppreference.com/w/cpp/iterator/random_access_iterator). Its `value_type` is a model of [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable). This `value_type`'s ordering relation is a *strict weak ordering* as defined in the [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable) requirements. + + + +Is a model of [Random Access Iterator](https://en.cppreference.com/w/cpp/iterator/random_access_iterator). `KeyIteratorT` is mutable, and its `value_type` is a model of [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable). This `value_type`'s ordering relation is a *strict weak ordering* as defined in the [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable) requirements. + + + +Is an integer type for global offsets. + + + +Is a type of callable object with the signature `bool operator()(KeyT lhs, KeyT rhs)` that models the [Strict Weak Ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order) concept. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of unsorted input keys + + + +Pointer to the output sequence of sorted input keys + + + +Number of items to sort + + + +Comparison function object which returns true if the first argument is ordered before the second + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + +### StableSortPairs inline static + +Sorts items using a merge sorting method. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceMergeSort::StableSortPairs( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyIteratorT d_keys, + ValueIteratorT d_items, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +Is a model of [Random Access Iterator](https://en.cppreference.com/w/cpp/iterator/random_access_iterator). `KeyIteratorT` is mutable, and its `value_type` is a model of [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable). This `value_type`'s ordering relation is a *strict weak ordering* as defined in the [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable) requirements. + + + +Is a model of [Random Access Iterator](https://en.cppreference.com/w/cpp/iterator/random_access_iterator), and `ValueIteratorT` is mutable. + + + +Is an integer type for global offsets. + + + +Is a type of callable object with the signature `bool operator()(KeyT lhs, KeyT rhs)` that models the [Strict Weak Ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order) concept. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of unsorted input keys + + + +Pointer to the input sequence of unsorted input values + + + +Number of items to sort + + + +Comparison function object which returns true if the first argument is ordered before the second + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + +### StableSortKeys inline static + +Sorts items using a merge sorting method. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceMergeSort::StableSortKeys( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyIteratorT d_keys, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +Is a model of [Random Access Iterator](https://en.cppreference.com/w/cpp/iterator/random_access_iterator). `KeyIteratorT` is mutable, and its `value_type` is a model of [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable). This `value_type`'s ordering relation is a *strict weak ordering* as defined in the [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable) requirements. + + + +Is an integer type for global offsets. + + + +Is a type of callable object with the signature `bool operator()(KeyT lhs, KeyT rhs)` that models the [Strict Weak Ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order) concept. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of unsorted input keys + + + +Number of items to sort + + + +Comparison function object which returns true if the first argument is ordered before the second + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + +### StableSortKeysCopy inline static + +Sorts items using a merge sorting method. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceMergeSort::StableSortKeysCopy( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyInputIteratorT d_input_keys, + KeyIteratorT d_output_keys, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +Is a model of [Random Access Iterator](https://en.cppreference.com/w/cpp/iterator/random_access_iterator). Its `value_type` is a model of [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable). This `value_type`'s ordering relation is a *strict weak ordering* as defined in the [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable) requirements. + + + +Is a model of [Random Access Iterator](https://en.cppreference.com/w/cpp/iterator/random_access_iterator). `KeyIteratorT` is mutable, and its `value_type` is a model of [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable). This `value_type`'s ordering relation is a *strict weak ordering* as defined in the [LessThan Comparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable) requirements. + + + +Is an integer type for global offsets. + + + +Is a type of callable object with the signature `bool operator()(KeyT lhs, KeyT rhs)` that models the [Strict Weak Ordering](https://en.cppreference.com/w/cpp/concepts/strict_weak_order) concept. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of unsorted input keys + + + +Pointer to the output sequence of sorted input keys + + + +Number of elements in d_input_keys to sort + + + +Comparison function object which returns true if the first argument is ordered before the second + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + diff --git a/fern/cudapages/cub/cub/cub/DevicePartition.mdx b/fern/cudapages/cub/cub/cub/DevicePartition.mdx new file mode 100644 index 0000000..5f4c5f3 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DevicePartition.mdx @@ -0,0 +1,730 @@ +--- +title: cub::DevicePartition +description: "" +--- + +DevicePartition provides device-wide, parallel operations for partitioning sequences of data items residing within device-accessible memory. + +## Performance considerations + +@linear_performance{partition} + +--- + +## Methods + +### partition_impl inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DevicePartition::partition_impl( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + FlagIteratorT d_flags, + OutputIteratorT d_out, + NumSelectedIteratorT d_num_selected_out, + OffsetT num_items, + SelectOpT select_op, + cudaStream_t stream +) +``` + + +### IfNoNVTX inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DevicePartition::IfNoNVTX( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + FirstOutputIteratorT d_first_part_out, + SecondOutputIteratorT d_second_part_out, + UnselectedOutputIteratorT d_unselected_out, + NumSelectedIteratorT d_num_selected_out, + NumItemsT num_items, + SelectFirstPartOp select_first_part_op, + SelectSecondPartOp select_second_part_op, + cudaStream_t stream = 0 +) +``` + + +--- + +## Static methods + +### Flagged inline static + + + + +Uses the `d_flags` sequence to split the corresponding items from `d_in` into a partitioned sequence `d_out`. The total number of items copied into the first partition is written to `d_num_selected_out`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DevicePartition::Flagged( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + FlagIterator d_flags, + OutputIteratorT d_out, + NumSelectedIteratorT d_num_selected_out, + NumItemsT num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The value type of `d_flags` must be castable to `bool` (e.g., `bool`, `char`, `int`, etc.). +Copies of the selected items are compacted into `d_out` and maintain their original relative ordering, however copies of the unselected items are compacted into the rear of `d_out` in reverse order. +The range `[d_out, d_out + num_items)` shall not overlap `[d_in, d_in + num_items)` nor `[d_flags, d_flags + num_items)` in any way. The range `[d_in, d_in + num_items)` may overlap `[d_flags, d_flags + num_items)`. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading selection flags (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing output items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the number of items selected (may be a simple pointer type) + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the input sequence of selection flags + + + +Pointer to the output sequence of partitioned data items + + + +Pointer to the output total number of items selected (i.e., the offset of the unselected partition) + + + +Total number of items to select from + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the compaction of items selected from an `int` device vector. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers for +// input, flags, and output +int num_items; // e.g., 8 +int *d_in; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +char *d_flags; // e.g., [1, 0, 0, 1, 0, 1, 1, 0] +int *d_out; // e.g., [ , , , , , , , ] +int *d_num_selected_out; // e.g., [ ] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DevicePartition::Flagged( + d_temp_storage, temp_storage_bytes, + d_in, d_flags, d_out, d_num_selected_out, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run selection +cub::DevicePartition::Flagged( + d_temp_storage, temp_storage_bytes, + d_in, d_flags, d_out, d_num_selected_out, num_items); + +// d_out <-- [1, 4, 6, 7, 8, 5, 3, 2] +// d_num_selected_out <-- [4] +``` + + + + +nodiscard + +Uses the `d_flags` sequence to split the corresponding items from `d_in` into a partitioned sequence `d_out`. The total number of items copied into the first partition is written to `d_num_selected_out`. + +This is an environment-based API that allows customization of: + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DevicePartition::Flagged( + InputIteratorT d_in, + FlagIterator d_flags, + OutputIteratorT d_out, + NumSelectedIteratorT d_num_selected_out, + NumItemsT num_items, + EnvT env = {} +) +``` + + + +Stream: Query via `cuda::get_stream` +Memory resource: Query via `cuda::mr::get_memory_resource` +The value type of `d_flags` must be castable to `bool` (e.g., `bool`, `char`, `int`, etc.). +Copies of the selected items are compacted into `d_out` and maintain their original relative ordering, however copies of the unselected items are compacted into the rear of `d_out` in reverse order. +The range `[d_out, d_out + num_items)` shall not overlap `[d_in, d_in + num_items)` nor `[d_flags, d_flags + num_items)` in any way. The range `[d_in, d_in + num_items)` may overlap `[d_flags, d_flags + num_items)`. + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading selection flags (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing output items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the number of items selected (may be a simple pointer type) + + + +**[inferred]** Type of num_items + + + +**[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`) + + +**Parameters** + + +Pointer to the input sequence of data items + + + +Pointer to the input sequence of selection flags + + + +Pointer to the output sequence of partitioned data items + + + +Pointer to the output total number of items selected (i.e., the offset of the unselected partition) + + + +Total number of items to select from + + + +**[optional]** Execution environment. Default is `cuda::std::execution::env{}`. + + + + + +### If inline static + + + + +Uses the `select_op` functor to split the corresponding items from `d_in` into a partitioned sequence `d_out`. The total number of items copied into the first partition is written to `d_num_selected_out`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DevicePartition::If( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + NumSelectedIteratorT d_num_selected_out, + NumItemsT num_items, + SelectOp select_op, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Copies of the selected items are compacted into `d_out` and maintain their original relative ordering, however copies of the unselected items are compacted into the rear of `d_out` in reverse order. +The range `[d_out, d_out + num_items)` shall not overlap `[d_in, d_in + num_items)` in any way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing output items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the number of items selected (may be a simple pointer type) + + + +**[inferred]** Selection functor type having member `bool operator()(const T &a)` + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output sequence of partitioned data items + + + +Pointer to the output total number of items selected (i.e., the offset of the unselected partition) + + + +Total number of items to select from + + + +Unary selection operator + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the compaction of items selected from an `int` device vector. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Functor type for selecting values less than some criteria +struct LessThan +{ + int compare; + + CUB_RUNTIME_FUNCTION __forceinline__ + explicit LessThan(int compare) : compare(compare) {} + + CUB_RUNTIME_FUNCTION __forceinline__ + bool operator()(const int &a) const + { + return (a < compare); + } +}; + +// Declare, allocate, and initialize device-accessible pointers for +// input and output +int num_items; // e.g., 8 +int *d_in; // e.g., [0, 2, 3, 9, 5, 2, 81, 8] +int *d_out; // e.g., [ , , , , , , , ] +int *d_num_selected_out; // e.g., [ ] +LessThan select_op(7); +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DevicePartition::If( +d_temp_storage, temp_storage_bytes, +d_in, d_out, d_num_selected_out, num_items, select_op); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run selection +cub::DevicePartition::If( + d_temp_storage, temp_storage_bytes, + d_in, d_out, d_num_selected_out, num_items, select_op); + +// d_out <-- [0, 2, 3, 5, 2, 8, 81, 9] +// d_num_selected_out <-- [5] +``` + + + + +nodiscard + +Uses the `select_op` functor to split the corresponding items from `d_in` into a partitioned sequence `d_out`. The total number of items copied into the first partition is written to `d_num_selected_out`. + +This is an environment-based API that allows customization of: + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DevicePartition::If( + InputIteratorT d_in, + OutputIteratorT d_out, + NumSelectedIteratorT d_num_selected_out, + NumItemsT num_items, + SelectOp select_op, + EnvT env = {} +) +``` + + + +Stream: Query via `cuda::get_stream` +Memory resource: Query via `cuda::mr::get_memory_resource` +Copies of the selected items are compacted into `d_out` and maintain their original relative ordering, however copies of the unselected items are compacted into the rear of `d_out` in reverse order. +The range `[d_out, d_out + num_items)` shall not overlap `[d_in, d_in + num_items)` in any way. + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing output items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the number of items selected (may be a simple pointer type) + + + +**[inferred]** Selection functor type having member `bool operator()(const T &a)` + + + +**[inferred]** Type of num_items + + + +**[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`) + + +**Parameters** + + +Pointer to the input sequence of data items + + + +Pointer to the output sequence of partitioned data items + + + +Pointer to the output total number of items selected (i.e., the offset of the unselected partition) + + + +Total number of items to select from + + + +Unary selection operator + + + +**[optional]** Execution environment. Default is `cuda::std::execution::env{}`. + + + + + +Uses two functors to split the corresponding items from `d_in` into a three partitioned sequences `d_first_part_out`, `d_second_part_out`, and `d_unselected_out`. The total number of items copied into the first partition is written to `d_num_selected_out[0]`, while the total number of items copied into the second partition is written to `d_num_selected_out[1]`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DevicePartition::If( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + FirstOutputIteratorT d_first_part_out, + SecondOutputIteratorT d_second_part_out, + UnselectedOutputIteratorT d_unselected_out, + NumSelectedIteratorT d_num_selected_out, + NumItemsT num_items, + SelectFirstPartOp select_first_part_op, + SelectSecondPartOp select_second_part_op, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Copies of the items selected by `select_first_part_op` are compacted into `d_first_part_out` and maintain their original relative ordering. +Copies of the items selected by `select_second_part_op` are compacted into `d_second_part_out` and maintain their original relative ordering. +Copies of the unselected items are compacted into the `d_unselected_out` in reverse order. +The ranges `[d_out, d_out + num_items)`, `[d_first_part_out, d_first_part_out + d_num_selected_out[0])`, `[d_second_part_out, d_second_part_out + d_num_selected_out[1])`, `[d_unselected_out, d_unselected_out + num_items - d_num_selected_out[0] - d_num_selected_out[1])`, shall not overlap in any way. + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing output items selected by first operator (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing output items selected by second operator (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing unselected items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the number of items selected (may be a simple pointer type) + + + +**[inferred]** Selection functor type having member `bool operator()(const T &a)` + + + +**[inferred]** Selection functor type having member `bool operator()(const T &a)` + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output sequence of data items selected by `select_first_part_op` + + + +Pointer to the output sequence of data items selected by `select_second_part_op` + + + +Pointer to the output sequence of unselected data items + + + +Pointer to the output array with two elements, where total number of items selected by `select_first_part_op` is stored as `d_num_selected_out[0]` and total number of items selected by `select_second_part_op` is stored as `d_num_selected_out[1]`, respectively + + + +Total number of items to select from + + + +Unary selection operator to select `d_first_part_out` + + + +Unary selection operator to select `d_second_part_out` + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates how this algorithm can partition an input vector into small, medium, and large items so that the relative order of items remain deterministic. + +Let's consider any value that doesn't exceed six a small one. On the other hand, any value that exceeds 50 will be considered a large one. Since the value used to define a small part doesn't match one that defines the large part, the intermediate segment is implied. + +These definitions partition a value space into three categories. We want to preserve the order of items in which they appear in the input vector. Since the algorithm provides stable partitioning, this is possible. + +Since the number of items in each category is unknown beforehand, we need three output arrays of num_items elements each. To reduce the memory requirements, we can combine the output storage for two categories. + +Since each value falls precisely in one category, it's safe to add "large" values into the head of the shared output vector and the "middle" values into its tail. To add items into the tail of the output array, we can use `cuda::std::reverse_iterator`. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Functor type for selecting values less than some criteria +struct LessThan +{ + int compare; + + __host__ __device__ __forceinline__ + explicit LessThan(int compare) : compare(compare) {} + + __host__ __device__ __forceinline__ + bool operator()(const int &a) const + { + return a < compare; + } +}; + +// Functor type for selecting values greater than some criteria +struct GreaterThan +{ + int compare; + + __host__ __device__ __forceinline__ + explicit GreaterThan(int compare) : compare(compare) {} + + __host__ __device__ __forceinline__ + bool operator()(const int &a) const + { + return a > compare; + } +}; + +// Declare, allocate, and initialize device-accessible pointers for +// input and output +int num_items; // e.g., 8 +int *d_in; // e.g., [0, 2, 3, 9, 5, 2, 81, 8] +int *d_large_and_unselected_out; // e.g., [ , , , , , , , ] +int *d_small_out; // e.g., [ , , , , , , , ] +int *d_num_selected_out; // e.g., [ , ] +cud::std::reverse_iterator unselected_out(d_large_and_unselected_out + num_items); +LessThan small_items_selector(7); +GreaterThan large_items_selector(50); +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DevicePartition::If( + d_temp_storage, temp_storage_bytes, + d_in, d_large_and_medium_out, d_small_out, unselected_out, + d_num_selected_out, num_items, + large_items_selector, small_items_selector); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run selection +cub::DevicePartition::If( + d_temp_storage, temp_storage_bytes, + d_in, d_large_and_medium_out, d_small_out, unselected_out, + d_num_selected_out, num_items, + large_items_selector, small_items_selector); + +// d_large_and_unselected_out <-- [ 81, , , , , , 8, 9 ] +// d_small_out <-- [ 0, 2, 3, 5, 2, , , ] +// d_num_selected_out <-- [ 1, 5 ] +``` + + + diff --git a/fern/cudapages/cub/cub/cub/DeviceRadixSort.mdx b/fern/cudapages/cub/cub/cub/DeviceRadixSort.mdx new file mode 100644 index 0000000..8ae872a --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DeviceRadixSort.mdx @@ -0,0 +1,2874 @@ +--- +title: cub::DeviceRadixSort +description: "" +--- + +DeviceRadixSort provides device-wide, parallel operations for computing a radix sort across a sequence of data items residing within device-accessible memory. + +![](../../img/sorting_logo.png) + +The [radix sorting method](http://en.wikipedia.org/wiki/Radix_sort) arranges items into ascending (or descending) order. The algorithm relies upon a positional representation for keys, i.e., each key is comprised of an ordered sequence of symbols (e.g., digits, characters, etc.) specified from least-significant to most-significant. For a given input sequence of keys and a set of rules specifying a total ordering of the symbolic alphabet, the radix sorting method produces a lexicographic ordering of those keys. + +Assumes threads are in row-major order. + + +DeviceRadixSort can sort all of the built-in C++ numeric primitive types (`unsigned char`, `int`, `double`, etc.) as well as CUDA's `__half` and `__nv_bfloat16` 16-bit floating-point types. User-defined types are supported as long as a decomposer object is provided. + + +- Positive and negative zeros are considered equivalent, and will be treated + as such in the output. +- No special handling is implemented for NaN values; these are sorted + according to their bit representations after any transformations. + + +Although the direct radix sorting method can only be applied to unsigned integral types, DeviceRadixSort is able to sort signed and floating-point types via simple bit-wise transformations that ensure lexicographic key ordering. Additional transformations occur for descending sorts. These transformations must be considered when restricting the `[begin_bit, end_bit)` range, as the bitwise transformations will occur before the bit-range truncation. + +Any transformations applied to the keys prior to sorting are reversed while writing to the final output buffer. + + +To convert the input values into a radix-sortable bitwise representation, the following transformations take place prior to sorting: + +- For unsigned integral values, the keys are used directly. +- For signed integral values, the sign bit is inverted. +- For positive floating point values, the sign bit is inverted. +- For negative floating point values, the full key is inverted. + +For floating point types, positive and negative zero are a special case and will be considered equivalent during sorting. + + +If descending sort is used, the keys are inverted after performing any type-specific transformations, and the resulting keys are sorted in ascending order. + + +DeviceRadixSort is stable. For floating-point types, `-0.0` and `+0.0` are considered equal and appear in the result in the same order as they appear in the input. + + +@cdp_class{DeviceRadixSort} + + +@linear_performance{radix sort} + +--- + +## KeyT-value pairs + +### SortPairs inline static + + + + +Sorts key-value pairs into ascending order using :math:`\approx 2N` auxiliary storage. + +The code snippet below illustrates the sorting of a device vector of `int` keys with associated vector of `int` values. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceRadixSort::SortPairs( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + NumItemsT num_items, + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The contents of the input data are not altered by the sorting operation. +Pointers to contiguous memory must be used; iterators are not currently supported. +In-place operations are not supported. There must be no overlap between any of the provided ranges: +`[d_keys_in, d_keys_in + num_items)` +`[d_keys_out, d_keys_out + num_items)` +`[d_values_in, d_values_in + num_items)` +`[d_values_out, d_values_out + num_items)` +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. +@devicestorageNP For sorting using only `O(P)` temporary storage, see the sorting interface using DoubleBuffer wrappers below. +@devicestorage + + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** ValueT type + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input data of key data to sort + + + +Pointer to the sorted output sequence of key data + + + +Pointer to the corresponding input sequence of associated value items + + + +Pointer to the correspondingly-reordered output sequence of associated value items + + + +Number of items to sort + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `sizeof(unsigned int) * 8`) + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + +**Example** + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_keys_out; // e.g., [ ... ] +int *d_values_in; // e.g., [0, 1, 2, 3, 4, 5, 6] +int *d_values_out; // e.g., [ ... ] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, num_items); + +// d_keys_out <-- [0, 3, 5, 6, 7, 8, 9] +// d_values_out <-- [5, 4, 3, 1, 2, 0, 6] +``` + + + + +nodiscard + +Sorts key-value pairs into ascending order using :math:`\approx 2N` auxiliary storage. + +This is an environment-based API that allows customization of: + +The code snippet below illustrates the env-based sorting of key-value pairs: + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_env_api.cu :language: c++ :dedent: :start-after: example-begin radix-sort-pairs-env :end-before: example-end radix-sort-pairs-env + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceRadixSort::SortPairs( + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + NumItemsT num_items, + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8, + EnvT env = {} +) +``` + + +*Added in v3.4.0. First appears in CUDA Toolkit 13.4.* + + +Stream: Query via `cuda::get_stream` +Memory resource: Query via `cuda::mr::get_memory_resource` +The contents of the input data are not altered by the sorting operation. +Pointers to contiguous memory must be used; iterators are not currently supported. +In-place operations are not supported. There must be no overlap between any of the provided ranges: +`[d_keys_in, d_keys_in + num_items)` +`[d_keys_out, d_keys_out + num_items)` +`[d_values_in, d_values_in + num_items)` +`[d_values_out, d_values_out + num_items)` +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. + + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** ValueT type + + + +**[inferred]** Type of num_items + + + +**[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`) + + +**Parameters** + + +Pointer to the input data of key data to sort + + + +Pointer to the sorted output sequence of key data + + + +Pointer to the corresponding input sequence of associated value items + + + +Pointer to the correspondingly-reordered output sequence of associated value items + + + +Number of items to sort + + + +The least-significant bit index (inclusive) needed for key comparison + + + +The most-significant bit index (exclusive) needed for key comparison (e.g., `sizeof(unsigned int) * 8`) + + + +**[optional]** Execution environment. Default is `cuda::std::execution::env{}`. + + + + + +Sorts key-value pairs into ascending order using :math:`\approx 2N` auxiliary storage. + +* The contents of the input data are not altered by the sorting operation. * Pointers to contiguous memory must be used; iterators are not currently supported. * In-place operations are not supported. There must be no overlap between any of the provided ranges: + +* `[d_keys_in, d_keys_in + num_items)` * `[d_keys_out, d_keys_out + num_items)` * `[d_values_in, d_values_in + num_items)` * `[d_values_out, d_values_out + num_items)` + +* A bit subrange `[begin_bit, end_bit)` is provided to specify differentiating key bits. This can reduce overall sorting overhead and yield a corresponding performance improvement. * @devicestorageNP For sorting using only :math:`O(P)` temporary storage, see the sorting interface using DoubleBuffer wrappers below. * @devicestorage + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The following snippet shows how to sort an array of `custom_t` objects using `cub::DeviceRadixSort::SortPairs`: + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin pairs-bits :end-before: example-end pairs-bits + + +```cpp showLineNumbers={false} +template +static ::cuda::std::enable_if_t, cudaError_t> cub::DeviceRadixSort::SortPairs( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + NumItemsT num_items, + DecomposerT decomposer, + int begin_bit, + int end_bit, + cudaStream_t stream = 0 +) +``` + + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** ValueT type + + + +**[inferred]** Type of num_items + + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input data of key data to sort + + + +Pointer to the sorted output sequence of key data + + + +Pointer to the corresponding input sequence of associated value items + + + +Pointer to the correspondingly-reordered output sequence of associated value items + + + +Number of items to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `(sizeof(float) + sizeof(long long int)) * 8`) + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + + + + +Sorts key-value pairs into ascending order using :math:`\approx 2N` auxiliary storage. + +* The contents of the input data are not altered by the sorting operation. * Pointers to contiguous memory must be used; iterators are not currently supported. * In-place operations are not supported. There must be no overlap between any of the provided ranges: + +* `[d_keys_in, d_keys_in + num_items)` * `[d_keys_out, d_keys_out + num_items)` * `[d_values_in, d_values_in + num_items)` * `[d_values_out, d_values_out + num_items)` + +* @devicestorageNP For sorting using only :math:`O(P)` temporary storage, see the sorting interface using DoubleBuffer wrappers below. * @devicestorage + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The following snippet shows how to sort an array of `custom_t` objects using `cub::DeviceRadixSort::SortPairs`: + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin pairs :end-before: example-end pairs + + +```cpp showLineNumbers={false} +template +static ::cuda::std::enable_if_t, cudaError_t> cub::DeviceRadixSort::SortPairs( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + NumItemsT num_items, + DecomposerT decomposer, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** ValueT type + + + +**[inferred]** Type of num_items + + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input data of key data to sort + + + +Pointer to the sorted output sequence of key data + + + +Pointer to the corresponding input sequence of associated value items + + + +Pointer to the correspondingly-reordered output sequence of associated value items + + + +Number of items to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + + + + +Sorts key-value pairs into ascending order using :math:`\approx N` auxiliary storage. + +The code snippet below illustrates the sorting of a device vector of `int` keys with associated vector of `int` values. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceRadixSort::SortPairs( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + NumItemsT num_items, + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The sorting operation is given a pair of key buffers and a corresponding pair of associated value buffers. Each pair is managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). +The contents of both buffers within each pair may be altered by the sorting operation. +In-place operations are not supported. There must be no overlap between any of the provided ranges: +`[d_keys.Current(), d_keys.Current() + num_items)` +`[d_keys.Alternate(), d_keys.Alternate() + num_items)` +`[d_values.Current(), d_values.Current() + num_items)` +`[d_values.Alternate(), d_values.Alternate() + num_items)` +Upon completion, the sorting operation will update the "current" indicator within each DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. +@devicestorageP +@devicestorage + + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** ValueT type + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Double-buffer of values whose "current" device-accessible buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + + + +Number of items to sort + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `sizeof(unsigned int) * 8`) + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + +**Example** + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers for +// sorting data +int num_items; // e.g., 7 +int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_key_alt_buf; // e.g., [ ... ] +int *d_value_buf; // e.g., [0, 1, 2, 3, 4, 5, 6] +int *d_value_alt_buf; // e.g., [ ... ] +... + +// Create a set of DoubleBuffers to wrap pairs of device pointers +cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); +cub::DoubleBuffer d_values(d_value_buf, d_value_alt_buf); + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceRadixSort::SortPairs( + d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceRadixSort::SortPairs( + d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items); + +// d_keys.Current() <-- [0, 3, 5, 6, 7, 8, 9] +// d_values.Current() <-- [5, 4, 3, 1, 2, 0, 6] +``` + + + + +nodiscard + +Sorts key-value pairs into ascending order using :math:`\approx N` auxiliary storage. + +This is an environment-based API that allows customization of: + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceRadixSort::SortPairs( + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + NumItemsT num_items, + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8, + EnvT env = {} +) +``` + + +*Added in v3.4.0. First appears in CUDA Toolkit 13.4.* + + +Stream: Query via `cuda::get_stream` +Memory resource: Query via `cuda::mr::get_memory_resource` +The sorting operation is given a pair of key buffers and a corresponding pair of associated value buffers. Each pair is managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). +The contents of both buffers within each pair may be altered by the sorting operation. +In-place operations are not supported. There must be no overlap between any of the provided ranges: +`[d_keys.Current(), d_keys.Current() + num_items)` +`[d_keys.Alternate(), d_keys.Alternate() + num_items)` +`[d_values.Current(), d_values.Current() + num_items)` +`[d_values.Alternate(), d_values.Alternate() + num_items)` +Upon completion, the sorting operation will update the "current" indicator within each DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. + + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** ValueT type + + + +**[inferred]** Type of num_items + + + +**[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`) + + +**Parameters** + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Double-buffer of values whose "current" device-accessible buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + + + +Number of items to sort + + + +The least-significant bit index (inclusive) needed for key comparison + + + +The most-significant bit index (exclusive) needed for key comparison (e.g., `sizeof(unsigned int) * 8`) + + + +**[optional]** Execution environment. Default is `cuda::std::execution::env{}`. + + + + + +Sorts key-value pairs into ascending order using :math:`\approx N` auxiliary storage. + +* The sorting operation is given a pair of key buffers and a corresponding pair of associated value buffers. Each pair is managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). * The contents of both buffers within each pair may be altered by the sorting operation. * In-place operations are not supported. There must be no overlap between any of the provided ranges: + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The following snippet shows how to sort an array of `custom_t` objects using `cub::DeviceRadixSort::SortPairs`: + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin pairs-db :end-before: example-end pairs-db + + +```cpp showLineNumbers={false} +template +static ::cuda::std::enable_if_t, cudaError_t> cub::DeviceRadixSort::SortPairs( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + NumItemsT num_items, + DecomposerT decomposer, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +`[d_keys.Current(), d_keys.Current() + num_items)` +`[d_keys.Alternate(), d_keys.Alternate() + num_items)` +`[d_values.Current(), d_values.Current() + num_items)` +`[d_values.Alternate(), d_values.Alternate() + num_items)` +Upon completion, the sorting operation will update the "current" indicator within each DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). +@devicestorageP +@devicestorage + + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** ValueT type + + + +**[inferred]** Type of num_items + + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Double-buffer of values whose "current" device-accessible buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + + + +Number of items to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + + + + +Sorts key-value pairs into ascending order using :math:`\approx N` auxiliary storage. + +* The sorting operation is given a pair of key buffers and a corresponding pair of associated value buffers. Each pair is managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). * The contents of both buffers within each pair may be altered by the sorting operation. * In-place operations are not supported. There must be no overlap between any of the provided ranges: + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The following snippet shows how to sort an array of `custom_t` objects using `cub::DeviceRadixSort::SortPairs`: + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin pairs-bits-db :end-before: example-end pairs-bits-db + + +```cpp showLineNumbers={false} +template +static ::cuda::std::enable_if_t, cudaError_t> cub::DeviceRadixSort::SortPairs( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + NumItemsT num_items, + DecomposerT decomposer, + int begin_bit, + int end_bit, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +`[d_keys.Current(), d_keys.Current() + num_items)` +`[d_keys.Alternate(), d_keys.Alternate() + num_items)` +`[d_values.Current(), d_values.Current() + num_items)` +`[d_values.Alternate(), d_values.Alternate() + num_items)` +Upon completion, the sorting operation will update the "current" indicator within each DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. +@devicestorageP +@devicestorage + + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** ValueT type + + + +**[inferred]** Type of num_items + + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Double-buffer of values whose "current" device-accessible buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + + + +Number of items to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `(sizeof(float) + sizeof(long long int)) * 8`) + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + + + + +### SortPairsDescending inline static + + + + +Sorts key-value pairs into descending order using :math:`\approx 2N` auxiliary storage. + +The code snippet below illustrates the sorting of a device vector of `int` keys with associated vector of `int` values. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceRadixSort::SortPairsDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + NumItemsT num_items, + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The contents of the input data are not altered by the sorting operation. +Pointers to contiguous memory must be used; iterators are not currently supported. +In-place operations are not supported. There must be no overlap between any of the provided ranges: +`[d_keys_in, d_keys_in + num_items)` +`[d_keys_out, d_keys_out + num_items)` +`[d_values_in, d_values_in + num_items)` +`[d_values_out, d_values_out + num_items)` +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. +@devicestorageNP For sorting using only `O(P)` temporary storage, see the sorting interface using DoubleBuffer wrappers below. +@devicestorage + + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** ValueT type + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input data of key data to sort + + + +Pointer to the sorted output sequence of key data + + + +Pointer to the corresponding input sequence of associated value items + + + +Pointer to the correspondingly-reordered output sequence of associated value items + + + +Number of items to sort + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `sizeof(unsigned int) * 8`) + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + +**Example** + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_keys_out; // e.g., [ ... ] +int *d_values_in; // e.g., [0, 1, 2, 3, 4, 5, 6] +int *d_values_out; // e.g., [ ... ] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceRadixSort::SortPairsDescending( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceRadixSort::SortPairsDescending( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, num_items); + +// d_keys_out <-- [9, 8, 7, 6, 5, 3, 0] +// d_values_out <-- [6, 0, 2, 1, 3, 4, 5] +``` + + + + +nodiscard + +Sorts key-value pairs into descending order using :math:`\approx 2N` auxiliary storage. + +This is an environment-based API that allows customization of: + +The code snippet below illustrates the env-based descending sort of key-value pairs: + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_env_api.cu :language: c++ :dedent: :start-after: example-begin radix-sort-pairs-descending-env :end-before: example-end radix-sort-pairs-descending-env + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceRadixSort::SortPairsDescending( + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + NumItemsT num_items, + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8, + EnvT env = {} +) +``` + + +*Added in v3.4.0. First appears in CUDA Toolkit 13.4.* + + +Stream: Query via `cuda::get_stream` +Memory resource: Query via `cuda::mr::get_memory_resource` +The contents of the input data are not altered by the sorting operation. +Pointers to contiguous memory must be used; iterators are not currently supported. +In-place operations are not supported. There must be no overlap between any of the provided ranges: +`[d_keys_in, d_keys_in + num_items)` +`[d_keys_out, d_keys_out + num_items)` +`[d_values_in, d_values_in + num_items)` +`[d_values_out, d_values_out + num_items)` +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. + + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** ValueT type + + + +**[inferred]** Type of num_items + + + +**[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`) + + +**Parameters** + + +Pointer to the input data of key data to sort + + + +Pointer to the sorted output sequence of key data + + + +Pointer to the corresponding input sequence of associated value items + + + +Pointer to the correspondingly-reordered output sequence of associated value items + + + +Number of items to sort + + + +The least-significant bit index (inclusive) needed for key comparison + + + +The most-significant bit index (exclusive) needed for key comparison (e.g., `sizeof(unsigned int) * 8`) + + + +**[optional]** Execution environment. Default is `cuda::std::execution::env{}`. + + + + + +Sorts key-value pairs into descending order using :math:`\approx 2N` auxiliary storage. + +* The contents of the input data are not altered by the sorting operation. * Pointers to contiguous memory must be used; iterators are not currently supported. * In-place operations are not supported. There must be no overlap between any of the provided ranges: + +* `[d_keys_in, d_keys_in + num_items)` * `[d_keys_out, d_keys_out + num_items)` * `[d_values_in, d_values_in + num_items)` * `[d_values_out, d_values_out + num_items)` + +* A bit subrange `[begin_bit, end_bit)` is provided to specify differentiating key bits. This can reduce overall sorting overhead and yield a corresponding performance improvement. * @devicestorageNP For sorting using only :math:`O(P)` temporary storage, see the sorting interface using DoubleBuffer wrappers below. * @devicestorage + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The following snippet shows how to sort an array of `custom_t` objects using `cub::DeviceRadixSort::SortPairsDescending`: + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin pairs-descending-bits :end-before: example-end pairs-descending-bits + + +```cpp showLineNumbers={false} +template +static ::cuda::std::enable_if_t, cudaError_t> cub::DeviceRadixSort::SortPairsDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + NumItemsT num_items, + DecomposerT decomposer, + int begin_bit, + int end_bit, + cudaStream_t stream = 0 +) +``` + + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** ValueT type + + + +**[inferred]** Type of num_items + + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input data of key data to sort + + + +Pointer to the sorted output sequence of key data + + + +Pointer to the corresponding input sequence of associated value items + + + +Pointer to the correspondingly-reordered output sequence of associated value items + + + +Number of items to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `(sizeof(float) + sizeof(long long int)) * 8`) + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + + + + +Sorts key-value pairs into descending order using :math:`\approx 2N` auxiliary storage. + +* The contents of the input data are not altered by the sorting operation. * Pointers to contiguous memory must be used; iterators are not currently supported. * In-place operations are not supported. There must be no overlap between any of the provided ranges: + +* `[d_keys_in, d_keys_in + num_items)` * `[d_keys_out, d_keys_out + num_items)` * `[d_values_in, d_values_in + num_items)` * `[d_values_out, d_values_out + num_items)` + +* @devicestorageNP For sorting using only :math:`O(P)` temporary storage, see the sorting interface using DoubleBuffer wrappers below. * @devicestorage + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The following snippet shows how to sort an array of `custom_t` objects using `cub::DeviceRadixSort::SortPairsDescending`: + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin pairs-descending :end-before: example-end pairs-descending + + +```cpp showLineNumbers={false} +template +static ::cuda::std::enable_if_t, cudaError_t> cub::DeviceRadixSort::SortPairsDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + NumItemsT num_items, + DecomposerT decomposer, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** ValueT type + + + +**[inferred]** Type of num_items + + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input data of key data to sort + + + +Pointer to the sorted output sequence of key data + + + +Pointer to the corresponding input sequence of associated value items + + + +Pointer to the correspondingly-reordered output sequence of associated value items + + + +Number of items to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + + + + +Sorts key-value pairs into descending order using :math:`\approx N` auxiliary storage. + +The code snippet below illustrates the sorting of a device vector of `int` keys with associated vector of `int` values. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceRadixSort::SortPairsDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + NumItemsT num_items, + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The sorting operation is given a pair of key buffers and a corresponding pair of associated value buffers. Each pair is managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). +The contents of both buffers within each pair may be altered by the sorting operation. +In-place operations are not supported. There must be no overlap between any of the provided ranges: +`[d_keys.Current(), d_keys.Current() + num_items)` +`[d_keys.Alternate(), d_keys.Alternate() + num_items)` +`[d_values.Current(), d_values.Current() + num_items)` +`[d_values.Alternate(), d_values.Alternate() + num_items)` +Upon completion, the sorting operation will update the "current" indicator within each DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. +@devicestorageP +@devicestorage + + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** ValueT type + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Double-buffer of values whose "current" device-accessible buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + + + +Number of items to sort + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `sizeof(unsigned int) * 8`) + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + +**Example** + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_key_alt_buf; // e.g., [ ... ] +int *d_value_buf; // e.g., [0, 1, 2, 3, 4, 5, 6] +int *d_value_alt_buf; // e.g., [ ... ] +... + +// Create a set of DoubleBuffers to wrap pairs of device pointers +cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); +cub::DoubleBuffer d_values(d_value_buf, d_value_alt_buf); + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceRadixSort::SortPairsDescending( + d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceRadixSort::SortPairsDescending( + d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items); + +// d_keys.Current() <-- [9, 8, 7, 6, 5, 3, 0] +// d_values.Current() <-- [6, 0, 2, 1, 3, 4, 5] +``` + + + + +nodiscard + +Sorts key-value pairs into descending order using :math:`\approx N` auxiliary storage. + +This is an environment-based API that allows customization of: + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceRadixSort::SortPairsDescending( + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + NumItemsT num_items, + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8, + EnvT env = {} +) +``` + + +*Added in v3.4.0. First appears in CUDA Toolkit 13.4.* + + +Stream: Query via `cuda::get_stream` +Memory resource: Query via `cuda::mr::get_memory_resource` +The sorting operation is given a pair of key buffers and a corresponding pair of associated value buffers. Each pair is managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). +The contents of both buffers within each pair may be altered by the sorting operation. +In-place operations are not supported. There must be no overlap between any of the provided ranges: +`[d_keys.Current(), d_keys.Current() + num_items)` +`[d_keys.Alternate(), d_keys.Alternate() + num_items)` +`[d_values.Current(), d_values.Current() + num_items)` +`[d_values.Alternate(), d_values.Alternate() + num_items)` +Upon completion, the sorting operation will update the "current" indicator within each DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. + + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** ValueT type + + + +**[inferred]** Type of num_items + + + +**[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`) + + +**Parameters** + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Double-buffer of values whose "current" device-accessible buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + + + +Number of items to sort + + + +The least-significant bit index (inclusive) needed for key comparison + + + +The most-significant bit index (exclusive) needed for key comparison (e.g., `sizeof(unsigned int) * 8`) + + + +**[optional]** Execution environment. Default is `cuda::std::execution::env{}`. + + + + + +Sorts key-value pairs into descending order using :math:`\approx N` auxiliary storage. + +* The sorting operation is given a pair of key buffers and a corresponding pair of associated value buffers. Each pair is managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). * The contents of both buffers within each pair may be altered by the sorting operation. * In-place operations are not supported. There must be no overlap between any of the provided ranges: + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The following snippet shows how to sort an array of `custom_t` objects using `cub::DeviceRadixSort::SortPairsDescending`: + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin pairs-descending-db :end-before: example-end pairs-descending-db + + +```cpp showLineNumbers={false} +template +static ::cuda::std::enable_if_t, cudaError_t> cub::DeviceRadixSort::SortPairsDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + NumItemsT num_items, + DecomposerT decomposer, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +`[d_keys.Current(), d_keys.Current() + num_items)` +`[d_keys.Alternate(), d_keys.Alternate() + num_items)` +`[d_values.Current(), d_values.Current() + num_items)` +`[d_values.Alternate(), d_values.Alternate() + num_items)` +Upon completion, the sorting operation will update the "current" indicator within each DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). +@devicestorageP +@devicestorage + + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** ValueT type + + + +**[inferred]** Type of num_items + + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Double-buffer of values whose "current" device-accessible buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + + + +Number of items to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + + + + +Sorts key-value pairs into descending order using :math:`\approx N` auxiliary storage. + +* The sorting operation is given a pair of key buffers and a corresponding pair of associated value buffers. Each pair is managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). * The contents of both buffers within each pair may be altered by the sorting operation. * In-place operations are not supported. There must be no overlap between any of the provided ranges: + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The following snippet shows how to sort an array of `custom_t` objects using `cub::DeviceRadixSort::SortPairsDescending`: + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin pairs-descending-bits-db :end-before: example-end pairs-descending-bits-db + + +```cpp showLineNumbers={false} +template +static ::cuda::std::enable_if_t, cudaError_t> cub::DeviceRadixSort::SortPairsDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + NumItemsT num_items, + DecomposerT decomposer, + int begin_bit, + int end_bit, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +`[d_keys.Current(), d_keys.Current() + num_items)` +`[d_keys.Alternate(), d_keys.Alternate() + num_items)` +`[d_values.Current(), d_values.Current() + num_items)` +`[d_values.Alternate(), d_values.Alternate() + num_items)` +Upon completion, the sorting operation will update the "current" indicator within each DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. +@devicestorageP +@devicestorage + + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** ValueT type + + + +**[inferred]** Type of num_items + + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Double-buffer of values whose "current" device-accessible buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + + + +Number of items to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `(sizeof(float) + sizeof(long long int)) * 8`) + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + + + + +--- + +## Keys-only + +### SortKeys inline static + + + + +Sorts keys into ascending order using :math:`\approx 2N` auxiliary storage. + +The code snippet below illustrates the sorting of a device vector of `int` keys. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceRadixSort::SortKeys( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + NumItemsT num_items, + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The contents of the input data are not altered by the sorting operation. +Pointers to contiguous memory must be used; iterators are not currently supported. +In-place operations are not supported. There must be no overlap between any of the provided ranges: +`[d_keys_in, d_keys_in + num_items)` +`[d_keys_out, d_keys_out + num_items)` +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. +@devicestorageNP For sorting using only `O(P)` temporary storage, see the sorting interface using DoubleBuffer wrappers below. +@devicestorage + + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input data of key data to sort + + + +Pointer to the sorted output sequence of key data + + + +Number of items to sort + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `sizeof(unsigned int) * 8`) + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + +**Example** + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_keys_out; // e.g., [ ... ] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceRadixSort::SortKeys( + d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceRadixSort::SortKeys( + d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, num_items); + +// d_keys_out <-- [0, 3, 5, 6, 7, 8, 9] +``` + + + + +Sorts keys into ascending order using :math:`\approx 2N` auxiliary storage. + +* The contents of the input data are not altered by the sorting operation. * Pointers to contiguous memory must be used; iterators are not currently supported. * In-place operations are not supported. There must be no overlap between any of the provided ranges: + +* `[d_keys_in, d_keys_in + num_items)` * `[d_keys_out, d_keys_out + num_items)` + +* A bit subrange `[begin_bit, end_bit)` is provided to specify differentiating key bits. This can reduce overall sorting overhead and yield a corresponding performance improvement. * @devicestorageNP For sorting using only :math:`O(P)` temporary storage, see the sorting interface using DoubleBuffer wrappers below. * @devicestorage + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The following snippet shows how to sort an array of `custom_t` objects using `cub::DeviceRadixSort::SortKeys`: + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin keys-bits :end-before: example-end keys-bits + + +```cpp showLineNumbers={false} +template +static ::cuda::std::enable_if_t, cudaError_t> cub::DeviceRadixSort::SortKeys( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + NumItemsT num_items, + DecomposerT decomposer, + int begin_bit, + int end_bit, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** Type of num_items + + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input data of key data to sort + + + +Pointer to the sorted output sequence of key data + + + +Number of items to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `(sizeof(float) + sizeof(long long int)) * 8`) + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + + + + +Sorts keys into ascending order using :math:`\approx 2N` auxiliary storage. + +* The contents of the input data are not altered by the sorting operation. * Pointers to contiguous memory must be used; iterators are not currently supported. * In-place operations are not supported. There must be no overlap between any of the provided ranges: + +* `[d_keys_in, d_keys_in + num_items)` * `[d_keys_out, d_keys_out + num_items)` + +* An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. * @devicestorageNP For sorting using only :math:`O(P)` temporary storage, see the sorting interface using DoubleBuffer wrappers below. * @devicestorage + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The following snippet shows how to sort an array of `custom_t` objects using `cub::DeviceRadixSort::SortKeys`: + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin keys :end-before: example-end keys + + +```cpp showLineNumbers={false} +template +static ::cuda::std::enable_if_t, cudaError_t> cub::DeviceRadixSort::SortKeys( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + NumItemsT num_items, + DecomposerT decomposer, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** Type of num_items + + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input data of key data to sort + + + +Pointer to the sorted output sequence of key data + + + +Number of items to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + + + + +Sorts keys into ascending order using :math:`\approx N` auxiliary storage. + +The code snippet below illustrates the sorting of a device vector of `int` keys. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceRadixSort::SortKeys( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + NumItemsT num_items, + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The sorting operation is given a pair of key buffers managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). +The contents of both buffers may be altered by the sorting operation. +In-place operations are not supported. There must be no overlap between any of the provided ranges: +`[d_keys.Current(), d_keys.Current() + num_items)` +`[d_keys.Alternate(), d_keys.Alternate() + num_items)` +Upon completion, the sorting operation will update the "current" indicator within the DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. +@devicestorageP +@devicestorage + + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Number of items to sort + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `sizeof(unsigned int) * 8`) + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + +**Example** + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_key_alt_buf; // e.g., [ ... ] +... + +// Create a DoubleBuffer to wrap the pair of device pointers +cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceRadixSort::SortKeys( + d_temp_storage, temp_storage_bytes, d_keys, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceRadixSort::SortKeys( + d_temp_storage, temp_storage_bytes, d_keys, num_items); + +// d_keys.Current() <-- [0, 3, 5, 6, 7, 8, 9] +``` + + + + +Sorts keys into ascending order using :math:`\approx N` auxiliary storage. + +* The sorting operation is given a pair of key buffers managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). * The contents of both buffers may be altered by the sorting operation. * In-place operations are not supported. There must be no overlap between any of the provided ranges: + +* `[d_keys.Current(), d_keys.Current() + num_items)` * `[d_keys.Alternate(), d_keys.Alternate() + num_items)` + +* Upon completion, the sorting operation will update the "current" indicator within the DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). * @devicestorageP * @devicestorage + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The following snippet shows how to sort an array of `custom_t` objects using `cub::DeviceRadixSort::SortKeys`: + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin keys-db :end-before: example-end keys-db + + +```cpp showLineNumbers={false} +template +static ::cuda::std::enable_if_t, cudaError_t> cub::DeviceRadixSort::SortKeys( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + NumItemsT num_items, + DecomposerT decomposer, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** Type of num_items + + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Number of items to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + + + + +Sorts keys into ascending order using :math:`\approx N` auxiliary storage. + +* The sorting operation is given a pair of key buffers managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). * The contents of both buffers may be altered by the sorting operation. * In-place operations are not supported. There must be no overlap between any of the provided ranges: + +* `[d_keys.Current(), d_keys.Current() + num_items)` * `[d_keys.Alternate(), d_keys.Alternate() + num_items)` + +* A bit subrange `[begin_bit, end_bit)` is provided to specify differentiating key bits. This can reduce overall sorting overhead and yield a corresponding performance improvement. * Upon completion, the sorting operation will update the "current" indicator within the DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). * @devicestorageP * @devicestorage + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The following snippet shows how to sort an array of `custom_t` objects using `cub::DeviceRadixSort::SortKeys`: + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin keys-bits-db :end-before: example-end keys-bits-db + + +```cpp showLineNumbers={false} +template +static ::cuda::std::enable_if_t, cudaError_t> cub::DeviceRadixSort::SortKeys( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + NumItemsT num_items, + DecomposerT decomposer, + int begin_bit, + int end_bit, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** Type of num_items + + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Number of items to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `(sizeof(float) + sizeof(long long int)) * 8`) + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + + + + +### SortKeysDescending inline static + + + + +Sorts keys into descending order using :math:`\approx 2N` auxiliary storage. + +The code snippet below illustrates the sorting of a device vector of `int` keys. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceRadixSort::SortKeysDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + NumItemsT num_items, + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The contents of the input data are not altered by the sorting operation. +Pointers to contiguous memory must be used; iterators are not currently supported. +In-place operations are not supported. There must be no overlap between any of the provided ranges: +`[d_keys_in, d_keys_in + num_items)` +`[d_keys_out, d_keys_out + num_items)` +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. +@devicestorageNP For sorting using only `O(P)` temporary storage, see the sorting interface using DoubleBuffer wrappers below. +@devicestorage + + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input data of key data to sort + + + +Pointer to the sorted output sequence of key data + + + +Number of items to sort + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `sizeof(unsigned int) * 8`) + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + +**Example** + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_keys_out; // e.g., [ ... ] +... + +// Create a DoubleBuffer to wrap the pair of device pointers +cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceRadixSort::SortKeysDescending( + d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceRadixSort::SortKeysDescending( + d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, num_items); + +// d_keys_out <-- [9, 8, 7, 6, 5, 3, 0]s +``` + + + + +Sorts keys into descending order using :math:`\approx 2N` auxiliary storage. + +* The contents of the input data are not altered by the sorting operation. * Pointers to contiguous memory must be used; iterators are not currently supported. * In-place operations are not supported. There must be no overlap between any of the provided ranges: + +* `[d_keys_in, d_keys_in + num_items)` * `[d_keys_out, d_keys_out + num_items)` + +* An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. * @devicestorageNP For sorting using only :math:`O(P)` temporary storage, see the sorting interface using DoubleBuffer wrappers below. * @devicestorage + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The following snippet shows how to sort an array of `custom_t` objects using `cub::DeviceRadixSort::SortKeysDescending`: + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin keys-descending-bits :end-before: example-end keys-descending-bits + + +```cpp showLineNumbers={false} +template +static ::cuda::std::enable_if_t, cudaError_t> cub::DeviceRadixSort::SortKeysDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + NumItemsT num_items, + DecomposerT decomposer, + int begin_bit, + int end_bit, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** Type of num_items + + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input data of key data to sort + + + +Pointer to the sorted output sequence of key data + + + +Number of items to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `(sizeof(float) + sizeof(long long int)) * 8`) + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + + + + +Sorts keys into descending order using :math:`\approx 2N` auxiliary storage. + +* The contents of the input data are not altered by the sorting operation. * Pointers to contiguous memory must be used; iterators are not currently supported. * In-place operations are not supported. There must be no overlap between any of the provided ranges: + +* `[d_keys_in, d_keys_in + num_items)` * `[d_keys_out, d_keys_out + num_items)` + +* @devicestorageNP For sorting using only :math:`O(P)` temporary storage, see the sorting interface using DoubleBuffer wrappers below. * @devicestorage + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The following snippet shows how to sort an array of `custom_t` objects using `cub::DeviceRadixSort::SortKeysDescending`: + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin keys-descending :end-before: example-end keys-descending + + +```cpp showLineNumbers={false} +template +static ::cuda::std::enable_if_t, cudaError_t> cub::DeviceRadixSort::SortKeysDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + NumItemsT num_items, + DecomposerT decomposer, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** Type of num_items + + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input data of key data to sort + + + +Pointer to the sorted output sequence of key data + + + +Number of items to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + + + + +Sorts keys into descending order using :math:`\approx N` auxiliary storage. + +The code snippet below illustrates the sorting of a device vector of `int` keys. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceRadixSort::SortKeysDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + NumItemsT num_items, + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The sorting operation is given a pair of key buffers managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). +The contents of both buffers may be altered by the sorting operation. +In-place operations are not supported. There must be no overlap between any of the provided ranges: +`[d_keys.Current(), d_keys.Current() + num_items)` +`[d_keys.Alternate(), d_keys.Alternate() + num_items)` +Upon completion, the sorting operation will update the "current" indicator within the DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. +@devicestorageP +@devicestorage + + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Number of items to sort + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `sizeof(unsigned int) * 8`) + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + +**Example** + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_key_alt_buf; // e.g., [ ... ] +... + +// Create a DoubleBuffer to wrap the pair of device pointers +cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceRadixSort::SortKeysDescending( + d_temp_storage, temp_storage_bytes, d_keys, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceRadixSort::SortKeysDescending( + d_temp_storage, temp_storage_bytes, d_keys, num_items); + +// d_keys.Current() <-- [9, 8, 7, 6, 5, 3, 0] +``` + + + + +Sorts keys into descending order using :math:`\approx N` auxiliary storage. + +* The sorting operation is given a pair of key buffers managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). * The contents of both buffers may be altered by the sorting operation. * In-place operations are not supported. There must be no overlap between any of the provided ranges: + +* `[d_keys.Current(), d_keys.Current() + num_items)` * `[d_keys.Alternate(), d_keys.Alternate() + num_items)` + +* Upon completion, the sorting operation will update the "current" indicator within the DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). * @devicestorageP * @devicestorage + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The following snippet shows how to sort an array of `custom_t` objects using `cub::DeviceRadixSort::SortKeysDescending`: + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin keys-descending-db :end-before: example-end keys-descending-db + + +```cpp showLineNumbers={false} +template +static ::cuda::std::enable_if_t, cudaError_t> cub::DeviceRadixSort::SortKeysDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + NumItemsT num_items, + DecomposerT decomposer, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** Type of num_items + + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Number of items to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + + + + +Sorts keys into descending order using :math:`\approx N` auxiliary storage. + +* The sorting operation is given a pair of key buffers managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). * The contents of both buffers may be altered by the sorting operation. * In-place operations are not supported. There must be no overlap between any of the provided ranges: + +* `[d_keys.Current(), d_keys.Current() + num_items)` * `[d_keys.Alternate(), d_keys.Alternate() + num_items)` + +* A bit subrange `[begin_bit, end_bit)` is provided to specify differentiating key bits. This can reduce overall sorting overhead and yield a corresponding performance improvement. * Upon completion, the sorting operation will update the "current" indicator within the DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). * @devicestorageP * @devicestorage + +Let's consider a user-defined `custom_t` type below. To sort an array of `custom_t` objects, we have to tell CUB about relevant members of the `custom_t` type. We do this by providing a decomposer that returns a tuple of references to relevant members of the key. + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin custom-type :end-before: example-end custom-type + +The following snippet shows how to sort an array of `custom_t` objects using `cub::DeviceRadixSort::SortKeysDescending`: + +.. literalinclude:: ../../../cub/test/catch2_test_device_radix_sort_custom.cu :language: c++ :dedent: :start-after: example-begin keys-descending-bits-db :end-before: example-end keys-descending-bits-db + + +```cpp showLineNumbers={false} +template +static ::cuda::std::enable_if_t, cudaError_t> cub::DeviceRadixSort::SortKeysDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + NumItemsT num_items, + DecomposerT decomposer, + int begin_bit, + int end_bit, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +**[inferred]** KeyT type + + + +**[inferred]** Type of num_items + + + +**[inferred]** Type of a callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types: `::cuda::std::tuple operator()(KeyT &key)`. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Number of items to sort + + + +Callable object responsible for decomposing a `KeyT` into a tuple of references to its constituent arithmetic types. The leftmost element of the tuple is considered the most significant. The call operator must not modify members of the key. + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `(sizeof(float) + sizeof(long long int)) * 8`) + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + + + + +--- + +## Utility methods + +### custom_radix_sort inline static + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceRadixSort::custom_radix_sort( + void *d_temp_storage, + size_t &temp_storage_bytes, + bool is_overwrite_okay, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + OffsetT num_items, + DecomposerT decomposer, + int begin_bit, + int end_bit, + cudaStream_t stream +) +``` + + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceRadixSort::custom_radix_sort( + void *d_temp_storage, + size_t &temp_storage_bytes, + bool is_overwrite_okay, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + OffsetT num_items, + DecomposerT decomposer, + cudaStream_t stream +) +``` + + + + + +### GetName inline static constexpr + + +```cpp showLineNumbers={false} +static constexpr const char * cub::DeviceRadixSort::GetName() +``` + diff --git a/fern/cudapages/cub/cub/cub/DeviceReduce.mdx b/fern/cudapages/cub/cub/cub/DeviceReduce.mdx new file mode 100644 index 0000000..edc92fc --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DeviceReduce.mdx @@ -0,0 +1,1750 @@ +--- +title: cub::DeviceReduce +description: "" +--- + +DeviceReduce provides device-wide, parallel operations for computing a reduction across a sequence of data items residing within device-accessible memory. + +![](../../img/reduce_logo.png) + +A [reduction](http://en.wikipedia.org/wiki/Reduce_(higher-order_function)) (or *fold*) uses a binary combining operator to compute a single aggregate from a sequence of input elements. + + +@cdp_class{DeviceReduce} + + +@linear_performance{reduction, reduce-by-key, and run-length encode} + +--- + +## Methods + +### reduce_impl inline static + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceReduce::reduce_impl( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + NumItemsT num_items, + ReductionOpT reduction_op, + TransformOpT transform_op, + T init, + ::cuda::execution::determinism::__determinism_holder_t, + cudaStream_t stream +) +``` + + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceReduce::reduce_impl( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + NumItemsT num_items, + ReductionOpT, + TransformOpT transform_op, + T init, + ::cuda::execution::determinism::gpu_to_gpu_t, + cudaStream_t stream +) +``` + + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceReduce::reduce_impl( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + NumItemsT num_items, + ReductionOpT reduction_op, + TransformOpT transform_op, + T init, + ::cuda::execution::determinism::not_guaranteed_t, + cudaStream_t stream +) +``` + + + + + +--- + +## Static methods + +### Reduce inline static + + + + +Computes a device-wide reduction using the specified binary `reduction_op` functor and initial value `init`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceReduce::Reduce( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + NumItemsT num_items, + ReductionOpT reduction_op, + T init, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Does not support binary reduction operators that are non-commutative. +Provides "run-to-run" determinism for pseudo-associative reduction (e.g., addition of floating point types) on the same GPU device. However, results for pseudo-associative reduction may be inconsistent from one device to a another device of a different compute-capability because CUB can employ different tile-sizing for different architectures. +The range `[d_in, d_in + num_items)` shall not overlap `d_out`. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (may be a simple pointer type) + + + +**[inferred]** Binary reduction functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Data element type that is convertible to the `value` type of `InputIteratorT` + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +Total number of input items (i.e., length of `d_in`) + + + +Binary reduction functor + + + +Initial value of the reduction + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates a user-defined min-reduction of a device vector of `int` data elements. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// CustomMin functor +struct CustomMin +{ + template + __device__ __forceinline__ + T operator()(const T &a, const T &b) const { + return (b < a) ? b : a; + } +}; + +// Declare, allocate, and initialize device-accessible pointers for +// input and output +int num_items; // e.g., 7 +int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_out; // e.g., [-] +CustomMin min_op; +int init; // e.g., INT_MAX +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceReduce::Reduce( + d_temp_storage, temp_storage_bytes, + d_in, d_out, num_items, min_op, init); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run reduction +cub::DeviceReduce::Reduce( + d_temp_storage, temp_storage_bytes, + d_in, d_out, num_items, min_op, init); + +// d_out <-- [0] +``` + + + + +nodiscard + +Computes a device-wide reduction using the specified binary `reduction_op` functor and initial value `init`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceReduce::Reduce( + InputIteratorT d_in, + OutputIteratorT d_out, + NumItemsT num_items, + ReductionOpT reduction_op, + T init, + EnvT env = {} +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Does not support binary reduction operators that are non-commutative. +By default, provides "run-to-run" determinism for pseudo-associative reduction (e.g., addition of floating point types) on the same GPU device. However, results for pseudo-associative reduction may be inconsistent from one device to a another device of a different compute-capability because CUB can employ different tile-sizing for different architectures. To request "gpu-to-gpu" determinism, pass `cuda::execution::require(cuda::execution::determinism::gpu_to_gpu)` as the `env` parameter. To request "not-guaranteed" determinism, pass `cuda::execution::require(cuda::execution::determinism::not_guaranteed)` as the `env` parameter. +The range `[d_in, d_in + num_items)` shall not overlap `d_out`. + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (may be a simple pointer type) + + + +**[inferred]** Binary reduction functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Data element type that is convertible to the `value` type of `InputIteratorT` + + + +**[inferred]** Type of num_items + + + +**[inferred]** Execution environment type. Default is `cuda::std::execution::env<>`. + + +**Parameters** + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +Total number of input items (i.e., length of `d_in`) + + + +Binary reduction functor + + + +Initial value of the reduction + + + +Embed:rst:leading-asterisk +//! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``. +//! + + + + + +### Sum inline static + + + + +nodiscard + +Computes a device-wide sum using the addition (`+`) operator. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceReduce::Sum( + InputIteratorT d_in, + OutputIteratorT d_out, + NumItemsT num_items, + EnvT env = {} +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Uses `0` as the initial value of the reduction. +Does not support `+` operators that are non-commutative. +Provides "run-to-run" determinism for pseudo-associative reduction (e.g., addition of floating point types) on the same GPU device. However, results for pseudo-associative reduction may be inconsistent from one device to a another device of a different compute-capability because CUB can employ different tile-sizing for different architectures. To request "gpu-to-gpu" determinism, pass `cuda::execution::require(cuda::execution::determinism::gpu_to_gpu)` as the `env` parameter. To request "not-guaranteed" determinism, pass `cuda::execution::require(cuda::execution::determinism::not_guaranteed)` as the `env` parameter. +The range `[d_in, d_in + num_items)` shall not overlap `d_out`. + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (may be a simple pointer type) + + + +**[inferred]** Type of num_items + + + +**[inferred]** Execution environment type. Default is `cuda::std::execution::env<>`. + + +**Parameters** + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +Total number of input items (i.e., length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** Execution environment. Default is `cuda::std::execution::env{}`. +//! + + + + + +Computes a device-wide sum using the addition (`+`) operator. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceReduce::Sum( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + NumItemsT num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Uses `0` as the initial value of the reduction. +Does not support `+` operators that are non-commutative. +Provides "run-to-run" determinism for pseudo-associative reduction (e.g., addition of floating point types) on the same GPU device. However, results for pseudo-associative reduction may be inconsistent from one device to a another device of a different compute-capability because CUB can employ different tile-sizing for different architectures. +The range `[d_in, d_in + num_items)` shall not overlap `d_out`. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (may be a simple pointer type) + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +Total number of input items (i.e., length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the sum-reduction of a device vector of `int` data elements. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for input and output +int num_items; // e.g., 7 +int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_out; // e.g., [-] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceReduce::Sum( + d_temp_storage, temp_storage_bytes, d_in, d_out, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sum-reduction +cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items); + +// d_out <-- [38] +``` + + + + +### Min inline static + + + + +Computes a device-wide minimum using the less-than (`<`) operator. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceReduce::Min( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + NumItemsT num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Uses `cuda::std::numeric_limits::max()` as the initial value of the reduction. +Does not support `<` operators that are non-commutative. +Provides "run-to-run" determinism for pseudo-associative reduction (e.g., addition of floating point types) on the same GPU device. However, results for pseudo-associative reduction may be inconsistent from one device to a another device of a different compute-capability because CUB can employ different tile-sizing for different architectures. +The range `[d_in, d_in + num_items)` shall not overlap `d_out`. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (may be a simple pointer type) + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +Total number of input items (i.e., length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the min-reduction of a device vector of `int` data elements. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for input and output +int num_items; // e.g., 7 +int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_out; // e.g., [-] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceReduce::Min( + d_temp_storage, temp_storage_bytes, d_in, d_out, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run min-reduction +cub::DeviceReduce::Min( + d_temp_storage, temp_storage_bytes, d_in, d_out, num_items); + +// d_out <-- [0] +``` + + + + +nodiscard + +Computes a device-wide minimum using the less-than (`<`) operator. The result is written to the output iterator. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceReduce::Min( + InputIteratorT d_in, + OutputIteratorT d_out, + NumItemsT num_items, + EnvT env = {} +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Uses `cuda::std::numeric_limits::max()` as the initial value of the reduction. +Provides determinism based on the environment's determinism requirements. To request "run-to-run" determinism, pass `cuda::execution::require(cuda::execution::determinism::run_to_run)` as the `env` parameter. +The range `[d_in, d_in + num_items)` shall not overlap `d_out`. + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (may be a simple pointer type) + + + +**[inferred]** Type of num_items + + + +**[inferred]** Execution environment type. Default is `cuda::std::execution::env<>`. + + +**Parameters** + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +Total number of input items (i.e., length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** Execution environment. Default is `cuda::std::execution::env{}`. +//! + + + + + +### ArgMin inline static + + + + +Finds the first device-wide minimum using the less-than (`<`) operator and also returns the index of that item. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceReduce::ArgMin( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + ExtremumOutIteratorT d_min_out, + IndexOutIteratorT d_index_out, + ::cuda::std::int64_t num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The minimum is written to `d_min_out` +The offset of the returned item is written to `d_index_out`, the offset type being written is of type `cuda::std::int64_t`. +For zero-length inputs, `cuda::std::numeric_limits::max()}` is written to `d_min_out` and the index `1` is written to `d_index_out`. +Does not support `<` operators that are non-commutative. +Provides "run-to-run" determinism for pseudo-associative reduction (e.g., addition of floating point types) on the same GPU device. However, results for pseudo-associative reduction may be inconsistent from one device to a another device of a different compute-capability because CUB can employ different tile-sizing for different architectures. +The range `[d_in, d_in + num_items)` shall not overlap `d_min_out` nor `d_index_out`. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (of some type `T`) (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording minimum value + + + +**[inferred]** Output iterator type for recording index of the returned value + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Iterator to the input sequence of data items + + + +Iterator to which the minimum value is written + + + +Iterator to which the index of the returned value is written + + + +Total number of input items (i.e., length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the argmin-reduction of a device vector of `int` data elements. + +```cpp showLineNumbers={false} +#include // or equivalently +#include + +// Declare, allocate, and initialize device-accessible pointers +// for input and output +int num_items; // e.g., 7 +int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_min_out; // memory for the minimum value +cuda::std::int64_t *d_index_out; // memory for the index of the returned value +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceReduce::ArgMin(d_temp_storage, temp_storage_bytes, d_in, d_min_out, d_index_out, +num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run argmin-reduction +cub::DeviceReduce::ArgMin(d_temp_storage, temp_storage_bytes, d_in, d_min_out, d_index_out, +num_items); + +// d_min_out <-- 0 +// d_index_out <-- 5 +``` + + + + +nodiscard + +Finds the first device-wide minimum using the less-than (`<`) operator and also returns the index of that item. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceReduce::ArgMin( + InputIteratorT d_in, + ExtremumOutIteratorT d_min_out, + IndexOutIteratorT d_index_out, + ::cuda::std::int64_t num_items, + EnvT env = {} +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The minimum is written to `d_min_out` +The offset of the returned item is written to `d_index_out`, the offset type being written is of type `cuda::std::int64_t`. +For zero-length inputs, `cuda::std::numeric_limits::max()}` is written to `d_min_out` and the index `1` is written to `d_index_out`. +Does not support `<` operators that are non-commutative. +Provides determinism based on the environment's determinism requirements. To request "run-to-run" determinism, pass `cuda::execution::require(cuda::execution::determinism::run_to_run)` as the `env` parameter. +The range `[d_in, d_in + num_items)` shall not overlap `d_min_out` nor `d_index_out`. + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (of some type `T`) (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording minimum value + + + +**[inferred]** Output iterator type for recording index of the returned value + + + +**[inferred]** Execution environment type. Default is `cuda::std::execution::env<>`. + + +**Parameters** + + +Iterator to the input sequence of data items + + + +Iterator to which the minimum value is written + + + +Iterator to which the index of the returned value is written + + + +Total number of input items (i.e., length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``. +//! + + + + + +Finds the first device-wide minimum using the less-than (`<`) operator, also returning the index of that item. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceReduce::ArgMin( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + int num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The output value type of `d_out` is `cub::KeyValuePair` (assuming the value type of `d_in` is `T`) +The minimum is written to `d_out.value` and its offset in the input array is written to `d_out.key`. +The `{1, cuda::std::numeric_limits::max()}` tuple is produced for zero-length inputs +Does not support `<` operators that are non-commutative. +Provides "run-to-run" determinism for pseudo-associative reduction (e.g., addition of floating point types) on the same GPU device. However, results for pseudo-associative reduction may be inconsistent from one device to a another device of a different compute-capability because CUB can employ different tile-sizing for different architectures. +The range `[d_in, d_in + num_items)` shall not overlap `d_out`. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (of some type `T`) (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (having value type `cub::KeyValuePair`) (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +Total number of input items (i.e., length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the argmin-reduction of a device vector of `int` data elements. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for input and output +int num_items; // e.g., 7 +int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +KeyValuePair *d_argmin; // e.g., [{-,-}] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceReduce::ArgMin(d_temp_storage, temp_storage_bytes, d_in, d_argmin, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run argmin-reduction +cub::DeviceReduce::ArgMin(d_temp_storage, temp_storage_bytes, d_in, d_argmin, num_items); + +// d_argmin <-- [{5, 0}] +``` + + + + +### Max inline static + + + + +Computes a device-wide maximum using the greater-than (`>`) operator. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceReduce::Max( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + NumItemsT num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Uses `cuda::std::numeric_limits::lowest()` as the initial value of the reduction. +Does not support `>` operators that are non-commutative. +Provides "run-to-run" determinism for pseudo-associative reduction (e.g., addition of floating point types) on the same GPU device. However, results for pseudo-associative reduction may be inconsistent from one device to a another device of a different compute-capability because CUB can employ different tile-sizing for different architectures. +The range `[d_in, d_in + num_items)` shall not overlap `d_out`. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (may be a simple pointer type) + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +Total number of input items (i.e., length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the max-reduction of a device vector of `int` data elements. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for input and output +int num_items; // e.g., 7 +int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_max; // e.g., [-] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_max, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run max-reduction +cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_max, num_items); + +// d_max <-- [9] +``` + + + + +nodiscard + +Computes a device-wide maximum using the greater-than (`>`) operator. The result is written to the output iterator. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceReduce::Max( + InputIteratorT d_in, + OutputIteratorT d_out, + NumItemsT num_items, + EnvT env = {} +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Uses `cuda::std::numeric_limits::lowest()` as the initial value of the reduction. +Provides determinism based on the environment's determinism requirements. To request "run-to-run" determinism, pass `cuda::execution::require(cuda::execution::determinism::run_to_run)` as the `env` parameter. +The range `[d_in, d_in + num_items)` shall not overlap `d_out`. + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (may be a simple pointer type) + + + +**[inferred]** Type of num_items + + + +**[inferred]** Execution environment type. Default is `cuda::std::execution::env<>`. + + +**Parameters** + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +Total number of input items (i.e., length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``. +//! + + + + + +### ArgMax inline static + + + + +Finds the first device-wide maximum using the greater-than (`>`) operator and also returns the index of that item. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceReduce::ArgMax( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + ExtremumOutIteratorT d_max_out, + IndexOutIteratorT d_index_out, + ::cuda::std::int64_t num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The maximum is written to `d_max_out` +The offset of the returned item is written to `d_index_out`, the offset type being written is of type `cuda::std::int64_t`. +For zero-length inputs, `cuda::std::numeric_limits::max()}` is written to `d_max_out` and the index `1` is written to `d_index_out`. +Does not support `>` operators that are non-commutative. +Provides "run-to-run" determinism for pseudo-associative reduction (e.g., addition of floating point types) on the same GPU device. However, results for pseudo-associative reduction may be inconsistent from one device to a another device of a different compute-capability because CUB can employ different tile-sizing for different architectures. +The range `[d_in, d_in + num_items)` shall not overlap `d_out`. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (of some type `T`) (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording maximum value + + + +**[inferred]** Output iterator type for recording index of the returned value + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Iterator to which the maximum value is written + + + +Iterator to which the index of the returned value is written + + + +Total number of input items (i.e., length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the argmax-reduction of a device vector of `int` data elements. + +```cpp showLineNumbers={false} +#include // or equivalently +#include + +// Declare, allocate, and initialize device-accessible pointers +// for input and output +int num_items; // e.g., 7 +int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_max_out; // memory for the maximum value +cuda::std::int64_t *d_index_out; // memory for the index of the returned value +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceReduce::ArgMax( + d_temp_storage, temp_storage_bytes, d_in, d_max_out, d_index_out, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run argmax-reduction +cub::DeviceReduce::ArgMax( + d_temp_storage, temp_storage_bytes, d_in, d_max_out, d_index_out, num_items); + +// d_max_out <-- 9 +// d_index_out <-- 6 +``` + + + + +Finds the first device-wide maximum using the greater-than (`>`) operator, also returning the index of that item + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceReduce::ArgMax( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + int num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The output value type of `d_out` is `cub::KeyValuePair` (assuming the value type of `d_in` is `T`) +The maximum is written to `d_out.value` and its offset in the input array is written to `d_out.key`. +The `{1, cuda::std::numeric_limits::lowest()}` tuple is produced for zero-length inputs +Does not support `>` operators that are non-commutative. +Provides "run-to-run" determinism for pseudo-associative reduction (e.g., addition of floating point types) on the same GPU device. However, results for pseudo-associative reduction may be inconsistent from one device to a another device of a different compute-capability because CUB can employ different tile-sizing for different architectures. +The range `[d_in, d_in + num_items)` shall not overlap `d_out`. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (of some type `T`) (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (having value type `cub::KeyValuePair`) (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +Total number of input items (i.e., length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the argmax-reduction of a device vector of `int` data elements. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for input and output +int num_items; // e.g., 7 +int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +KeyValuePair *d_argmax; // e.g., [{-,-}] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceReduce::ArgMax( + d_temp_storage, temp_storage_bytes, d_in, d_argmax, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run argmax-reduction +cub::DeviceReduce::ArgMax( + d_temp_storage, temp_storage_bytes, d_in, d_argmax, num_items); + +// d_argmax <-- [{6, 9}] +``` + + + + +nodiscard + +Finds the first device-wide maximum using the greater-than (`>`) operator and also returns the index of that item. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceReduce::ArgMax( + InputIteratorT d_in, + ExtremumOutIteratorT d_max_out, + IndexOutIteratorT d_index_out, + ::cuda::std::int64_t num_items, + EnvT env = {} +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The maximum is written to `d_max_out` +The offset of the returned item is written to `d_index_out`, the offset type being written is of type `cuda::std::int64_t`. +For zero-length inputs, `cuda::std::numeric_limits::lowest()}` is written to `d_max_out` and the index `1` is written to `d_index_out`. +Does not support `>` operators that are non-commutative. +Provides determinism based on the environment's determinism requirements. To request "run-to-run" determinism, pass `cuda::execution::require(cuda::execution::determinism::run_to_run)` as the `env` parameter. +The range `[d_in, d_in + num_items)` shall not overlap `d_max_out` nor `d_index_out`. + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (of some type `T`) (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording maximum value + + + +**[inferred]** Output iterator type for recording index of the returned value + + + +**[inferred]** Execution environment type. Default is `cuda::std::execution::env<>`. + + +**Parameters** + + +Iterator to the input sequence of data items + + + +Iterator to which the maximum value is written + + + +Iterator to which the index of the returned value is written + + + +Total number of input items (i.e., length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``. +//! + + + + + +### TransformReduce inline static + +Fuses transform and reduce operations + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceReduce::TransformReduce( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + NumItemsT num_items, + ReductionOpT reduction_op, + TransformOpT transform_op, + T init, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Does not support binary reduction operators that are non-commutative. +Provides "run-to-run" determinism for pseudo-associative reduction (e.g., addition of floating point types) on the same GPU device. However, results for pseudo-associative reduction may be inconsistent from one device to a another device of a different compute-capability because CUB can employ different tile-sizing for different architectures. +The range `[d_in, d_in + num_items)` shall not overlap `d_out`. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (may be a simple pointer type) + + + +**[inferred]** Binary reduction functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Unary reduction functor type having member `auto operator()(const T &a)` + + + +**[inferred]** Data element type that is convertible to the `value` type of `InputIteratorT` + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +Total number of input items (i.e., length of `d_in`) + + + +Binary reduction functor + + + +Unary transform functor + + + +Initial value of the reduction + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates a user-defined min-reduction of a device vector of `int` data elements. + +```cpp showLineNumbers={false} +#include +// or equivalently + +thrust::device_vector in = { 1, 2, 3, 4 }; +thrust::device_vector out(1); + +size_t temp_storage_bytes = 0; +uint8_t *d_temp_storage = nullptr; + +const int init = 42; + +cub::DeviceReduce::TransformReduce( + d_temp_storage, + temp_storage_bytes, + in.begin(), + out.begin(), + in.size(), + cuda::std::plus<>{}, + square_t{}, + init); + +thrust::device_vector temp_storage(temp_storage_bytes); +d_temp_storage = temp_storage.data().get(); + +cub::DeviceReduce::TransformReduce( + d_temp_storage, + temp_storage_bytes, + in.begin(), + out.begin(), + in.size(), + cuda::std::plus<>{}, + square_t{}, + init); + +// out[0] <-- 72 +``` + +### ReduceByKey inline static + +Reduces segments of values, where segments are demarcated by corresponding runs of identical keys. + +This operation computes segmented reductions within `d_values_in` using the specified binary `reduction_op` functor. The segments are identified by "runs" of corresponding keys in `d_keys_in`, where runs are maximal ranges of consecutive, identical keys. For the *i*th run encountered, the last key of the run and the corresponding value aggregate of that run are written to `d_unique_out[i]` and `d_aggregates_out[i]`, respectively. The total number of runs encountered is written to `d_num_runs_out`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceReduce::ReduceByKey( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + UniqueOutputIteratorT d_unique_out, + ValuesInputIteratorT d_values_in, + AggregatesOutputIteratorT d_aggregates_out, + NumRunsOutputIteratorT d_num_runs_out, + ReductionOpT reduction_op, + NumItemsT num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The `==` equality operator is used to determine whether keys are equivalent +Provides "run-to-run" determinism for pseudo-associative reduction (e.g., addition of floating point types) on the same GPU device. However, results for pseudo-associative reduction may be inconsistent from one device to a another device of a different compute-capability because CUB can employ different tile-sizing for different architectures. +Let `out` be any of `[d_unique_out, d_unique_out + *d_num_runs_out)` `[d_aggregates_out, d_aggregates_out + *d_num_runs_out)` `d_num_runs_out`. The ranges represented by `out` shall not overlap `[d_keys_in, d_keys_in + num_items)`, `[d_values_in, d_values_in + num_items)` nor `out` in any way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input keys (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing unique output keys (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading input values (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the number of runs encountered (may be a simple pointer type) + + + +**[inferred]** Binary reduction functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of keys + + + +Pointer to the output sequence of unique keys (one key per run) + + + +Pointer to the input sequence of corresponding values + + + +Pointer to the output sequence of value aggregates (one aggregate per run) + + + +Pointer to total number of runs encountered (i.e., the length of `d_unique_out`) + + + +Binary reduction functor + + + +Total number of associated key+value pairs (i.e., the length of `d_in_keys` and `d_in_values`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the segmented reduction of `int` values grouped by runs of associated `int` keys. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// CustomMin functor +struct CustomMin +{ + template + __device__ __forceinline__ + T operator()(const T &a, const T &b) const { + return (b < a) ? b : a; + } +}; + +// Declare, allocate, and initialize device-accessible pointers +// for input and output +int num_items; // e.g., 8 +int *d_keys_in; // e.g., [0, 2, 2, 9, 5, 5, 5, 8] +int *d_values_in; // e.g., [0, 7, 1, 6, 2, 5, 3, 4] +int *d_unique_out; // e.g., [-, -, -, -, -, -, -, -] +int *d_aggregates_out; // e.g., [-, -, -, -, -, -, -, -] +int *d_num_runs_out; // e.g., [-] +CustomMin reduction_op; +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceReduce::ReduceByKey( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_unique_out, d_values_in, + d_aggregates_out, d_num_runs_out, reduction_op, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run reduce-by-key +cub::DeviceReduce::ReduceByKey( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_unique_out, d_values_in, + d_aggregates_out, d_num_runs_out, reduction_op, num_items); + +// d_unique_out <-- [0, 2, 9, 5, 8] +// d_aggregates_out <-- [0, 1, 6, 2, 4] +// d_num_runs_out <-- [5] +``` diff --git a/fern/cudapages/cub/cub/cub/DeviceRleDispatch.mdx b/fern/cudapages/cub/cub/cub/DeviceRleDispatch.mdx new file mode 100644 index 0000000..3efd30e --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DeviceRleDispatch.mdx @@ -0,0 +1,212 @@ +--- +title: cub::DeviceRleDispatch +description: "Utility class for dispatching the appropriately-tuned kernels for DeviceRle." +--- + +Utility class for dispatching the appropriately-tuned kernels for DeviceRle. + + + + + +Random-access input iterator type for reading input items (may be a simple pointer type) + + + +Random-access output iterator type for writing run-offset values (may be a simple pointer type) + + + +Random-access output iterator type for writing run-length values (may be a simple pointer type) + + + +Output iterator type for recording the number of runs encountered (may be a simple pointer type) + + + +T equality operator type + + + +Signed integer type for global offsets + + + +Implementation detail, do not specify directly, requirements on the content of this type are subject to breaking change. + + + + + +--- + +## Constructors + +### DeviceRleDispatch inline + + +```cpp showLineNumbers={false} +cub::DeviceRleDispatch::DeviceRleDispatch( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OffsetsOutputIteratorT d_offsets_out, + LengthsOutputIteratorT d_lengths_out, + NumRunsOutputIteratorT d_num_runs_out, + EqualityOpT equality_op, + global_offset_t num_items, + cudaStream_t stream +) +``` + + +--- + +## Methods + +### Invoke inline + + + + +Internal dispatch routine for computing a device-wide run-length-encode using the specified kernel functions. + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DeviceRleDispatch::Invoke( + DeviceScanInitKernelPtr device_scan_init_kernel, + DeviceRleSweepKernelPtr device_rle_sweep_kernel +) +``` + + +**Template parameters** + + +Function type of cub::DeviceScanInitKernel + + + +Function type of cub::DeviceRleSweepKernelPtr + + +**Parameters** + + +Kernel function pointer to parameterization of cub::DeviceScanInitKernel + + + +Kernel function pointer to parameterization of cub::DeviceRleSweepKernel + + + + + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DeviceRleDispatch::Invoke() +``` + + + + + +--- + +## Static methods + +### Dispatch inline static + +Internal dispatch routine. + + +```cpp showLineNumbers={false} +static cudaError_t cub::DeviceRleDispatch::Dispatch( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OffsetsOutputIteratorT d_offsets_out, + LengthsOutputIteratorT d_lengths_out, + NumRunsOutputIteratorT d_num_runs_out, + EqualityOpT equality_op, + OffsetT num_items, + cudaStream_t stream +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to [`temp_storage_bytes`](/library/api/cub::_device_rle_dispatch::temp_storage_bytes) and no work is done. + + + +Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_device_rle_dispatch::d_temp_storage) allocation + + + +Pointer to input sequence of data items + + + +Pointer to output sequence of run-offsets + + + +Pointer to output sequence of run-lengths + + + +Pointer to total number of runs (i.e., length of [`d_offsets_out`](/library/api/cub::_device_rle_dispatch::d_offsets_out)) + + + +Equality operator for input items + + + +Total number of input items (i.e., length of [`d_in`](/library/api/cub::_device_rle_dispatch::d_in)) + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `local_offset_t` | `::cuda::std::int32_t` | +| `global_offset_t` | `OffsetT` | +| `length_t` | `cub::detail::non_void_value_t< LengthsOutputIteratorT, global_offset_t >` | +| `streaming_context_t` | `::cuda::std::conditional_t< use_streaming_invocation, detail::rle::streaming_context< InputIteratorT, length_t, global_offset_t >, NullType >` | +| `ScanTileStateT` | `ReduceByKeyScanTileState< length_t, local_offset_t >` | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `use_streaming_invocation` static constexpr | `bool` | | +| `init_kernel_threads` static constexpr | `int` | | +| `d_temp_storage` | `void *` | | +| `temp_storage_bytes` | `size_t &` | | +| `d_in` | `InputIteratorT` | | +| `d_offsets_out` | `OffsetsOutputIteratorT` | | +| `d_lengths_out` | `LengthsOutputIteratorT` | | +| `d_num_runs_out` | `NumRunsOutputIteratorT` | | +| `equality_op` | `EqualityOpT` | | +| `num_items` | `global_offset_t` | | +| `stream` | `cudaStream_t` | | diff --git a/fern/cudapages/cub/cub/cub/DeviceRunLengthEncode.mdx b/fern/cudapages/cub/cub/cub/DeviceRunLengthEncode.mdx new file mode 100644 index 0000000..cf09535 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DeviceRunLengthEncode.mdx @@ -0,0 +1,270 @@ +--- +title: cub::DeviceRunLengthEncode +description: "" +--- + +DeviceRunLengthEncode provides device-wide, parallel operations for demarcating "runs" of same-valued items within a sequence residing within device-accessible memory. + +## Performance considerations + +@linear_performance{run-length encode} + +--- + +## Static methods + +### Encode inline static + +Computes a run-length encoding of the sequence `d_in`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceRunLengthEncode::Encode( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + UniqueOutputIteratorT d_unique_out, + LengthsOutputIteratorT d_counts_out, + NumRunsOutputIteratorT d_num_runs_out, + NumItemsT num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +For the *i*th run encountered, the first key of the run and its length are written to `d_unique_out[i]` and `d_counts_out[i]`, respectively. +The total number of runs encountered is written to `d_num_runs_out`. +The `==` equality operator is used to determine whether values are equivalent +In-place operations are not supported. There must be no overlap between any of the provided ranges: +`[d_unique_out, d_unique_out + *d_num_runs_out)` +`[d_counts_out, d_counts_out + *d_num_runs_out)` +`[d_num_runs_out, d_num_runs_out + 1)` +`[d_in, d_in + num_items)` +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing unique output items (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing output counts (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the number of runs encountered (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of keys + + + +Pointer to the output sequence of unique keys (one key per run) + + + +Pointer to the output sequence of run-lengths (one count per run) + + + +Pointer to total number of runs + + + +Total number of associated key+value pairs (i.e., the length of `d_in_keys` and `d_in_values`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the run-length encoding of a sequence of `int` values. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers for +// input and output +int num_items; // e.g., 8 +int *d_in; // e.g., [0, 2, 2, 9, 5, 5, 5, 8] +int *d_unique_out; // e.g., [ , , , , , , , ] +int *d_counts_out; // e.g., [ , , , , , , , ] +int *d_num_runs_out; // e.g., [ ] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceRunLengthEncode::Encode( + d_temp_storage, temp_storage_bytes, + d_in, d_unique_out, d_counts_out, d_num_runs_out, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run encoding +cub::DeviceRunLengthEncode::Encode( + d_temp_storage, temp_storage_bytes, + d_in, d_unique_out, d_counts_out, d_num_runs_out, num_items); + +// d_unique_out <-- [0, 2, 9, 5, 8] +// d_counts_out <-- [1, 2, 1, 3, 1] +// d_num_runs_out <-- [5] +``` + +### NonTrivialRuns inline static + +Enumerates the starting offsets and lengths of all non-trivial runs (of `length > 1`) of same-valued keys in the sequence `d_in`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceRunLengthEncode::NonTrivialRuns( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OffsetsOutputIteratorT d_offsets_out, + LengthsOutputIteratorT d_lengths_out, + NumRunsOutputIteratorT d_num_runs_out, + NumItemsT num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +For the *i*th non-trivial run, the run's starting offset and its length are written to `d_offsets_out[i]` and `d_lengths_out[i]`, respectively. +The total number of runs encountered is written to `d_num_runs_out`. +The `==` equality operator is used to determine whether values are equivalent +In-place operations are not supported. There must be no overlap between any of the provided ranges: +`[d_offsets_out, d_offsets_out + *d_num_runs_out)` +`[d_lengths_out, d_lengths_out + *d_num_runs_out)` +`[d_num_runs_out, d_num_runs_out + 1)` +`[d_in, d_in + num_items)` +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing run-offset values (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing run-length values (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the number of runs encountered (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to input sequence of data items + + + +Pointer to output sequence of run-offsets (one offset per non-trivial run) + + + +Pointer to output sequence of run-lengths (one count per non-trivial run) + + + +Pointer to total number of runs (i.e., length of `d_offsets_out`) + + + +Total number of associated key+value pairs (i.e., the length of `d_in_keys` and `d_in_values`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the identification of non-trivial runs within a sequence of `int` values. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for input and output +int num_items; // e.g., 8 +int *d_in; // e.g., [0, 2, 2, 9, 5, 5, 5, 8] +int *d_offsets_out; // e.g., [ , , , , , , , ] +int *d_lengths_out; // e.g., [ , , , , , , , ] +int *d_num_runs_out; // e.g., [ ] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceRunLengthEncode::NonTrivialRuns( + d_temp_storage, temp_storage_bytes, + d_in, d_offsets_out, d_lengths_out, d_num_runs_out, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run encoding +cub::DeviceRunLengthEncode::NonTrivialRuns( + d_temp_storage, temp_storage_bytes, + d_in, d_offsets_out, d_lengths_out, d_num_runs_out, num_items); + +// d_offsets_out <-- [1, 4] +// d_lengths_out <-- [2, 3] +// d_num_runs_out <-- [2] +``` diff --git a/fern/cudapages/cub/cub/cub/DeviceScan.mdx b/fern/cudapages/cub/cub/cub/DeviceScan.mdx new file mode 100644 index 0000000..e0c1815 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DeviceScan.mdx @@ -0,0 +1,2197 @@ +--- +title: cub::DeviceScan +description: "" +--- + +DeviceScan provides device-wide, parallel operations for computing a prefix scan across a sequence of data items residing within device-accessible memory. + +## Performance considerations + +@linear_performance{prefix scan} + +--- + +## Exclusive scans + +### ExclusiveSum inline static + + + + +Computes a device-wide exclusive prefix sum. The value of `0` is applied as the initial value, and is assigned to `*d_out`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceScan::ExclusiveSum( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + NumItemsT num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative sum operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. Additional details can be found in the @lookback description. +When `d_in` and `d_out` are equal, the scan is performed in-place. The range `[d_in, d_in + num_items)` and `[d_out, d_out + num_items)` shall not overlap in any other way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading scan inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing scan outputs (may be a simple pointer type) + + + +**[inferred]** An integral type representing the number of input elements + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the input sequence of data items + + + +Random-access iterator to the output sequence of data items + + + +Total number of input items (i.e., the length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the exclusive prefix sum of an `int` device vector. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers for +// input and output +int num_items; // e.g., 7 +int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_out; // e.g., [ , , , , , , ] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceScan::ExclusiveSum( + d_temp_storage, temp_storage_bytes, + d_in, d_out, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run exclusive prefix sum +cub::DeviceScan::ExclusiveSum( + d_temp_storage, temp_storage_bytes, + d_in, d_out, num_items); + +// d_out <-- [0, 8, 14, 21, 26, 29, 29] +``` + + + + +nodiscard + +Computes a device-wide exclusive prefix sum. The value of `0` is applied as the initial value, and is assigned to `*d_out`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceScan::ExclusiveSum( + InputIteratorT d_in, + OutputIteratorT d_out, + NumItemsT num_items, + EnvT env = {} +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative sum operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. Additional details can be found in the @lookback description. +When `d_in` and `d_out` are equal, the scan is performed in-place. The range `[d_in, d_in + num_items)` and `[d_out, d_out + num_items)` shall not overlap in any other way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading scan inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing scan outputs (may be a simple pointer type) + + + +**[inferred]** An integral type representing the number of input elements + + + +**[inferred]** Execution environment type. Default is `::cuda::std::execution::env<>`. + + +**Parameters** + + +Random-access iterator to the input sequence of data items + + + +Random-access iterator to the output sequence of data items + + + +Total number of input items (i.e., the length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** Execution environment. Default is `::cuda::std::execution::env{}`. +//! + + + + + +Computes a device-wide exclusive prefix sum in-place. The value of `0` is applied as the initial value, and is assigned to `*d_data`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceScan::ExclusiveSum( + void *d_temp_storage, + size_t &temp_storage_bytes, + IteratorT d_data, + NumItemsT num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative sum operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. Additional details can be found in the @lookback description. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access iterator type for reading scan inputs and wrigin scan outputs + + + +**[inferred]** An integral type representing the number of input elements + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the sequence of data items + + + +Total number of input items (i.e., the length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the exclusive prefix sum of an `int` device vector. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers for +// input and output +int num_items; // e.g., 7 +int *d_data; // e.g., [8, 6, 7, 5, 3, 0, 9] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceScan::ExclusiveSum( + d_temp_storage, temp_storage_bytes, + d_data, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run exclusive prefix sum +cub::DeviceScan::ExclusiveSum( + d_temp_storage, temp_storage_bytes, + d_data, num_items); + +// d_data <-- [0, 8, 14, 21, 26, 29, 29] +``` + + + + +### ExclusiveScan inline static + + + + +Computes a device-wide exclusive prefix scan using the specified binary associative `scan_op` functor. The `init_value` value is applied as the initial value, and is assigned to `*d_out`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceScan::ExclusiveScan( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ScanOpT scan_op, + InitValueT init_value, + NumItemsT num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative scan operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. Additional details can be found in the @lookback description. +When `d_in` and `d_out` are equal, the scan is performed in-place. The range `[d_in, d_in + num_items)` and `[d_out, d_out + num_items)` shall not overlap in any other way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading scan inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing scan outputs (may be a simple pointer type) + + + +**[inferred]** Binary associative scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Type of the `init_value` + + + +**[inferred]** An integral type representing the number of input elements + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the input sequence of data items + + + +Random-access iterator to the output sequence of data items + + + +Binary associative scan functor + + + +Initial value to seed the exclusive scan (and is assigned to `*d_out`) + + + +Total number of input items (i.e., the length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the exclusive prefix min-scan of an `int` device vector + +```cpp showLineNumbers={false} +#include // or equivalently +#include // for INT_MAX + +// CustomMin functor +struct CustomMin +{ + template + __host__ __device__ __forceinline__ + T operator()(const T &a, const T &b) const { + return (b < a) ? b : a; + } +}; + +// Declare, allocate, and initialize device-accessible pointers for +// input and output +int num_items; // e.g., 7 +int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_out; // e.g., [ , , , , , , ] +CustomMin min_op; +... + +// Determine temporary device storage requirements for exclusive +// prefix scan +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceScan::ExclusiveScan( + d_temp_storage, temp_storage_bytes, + d_in, d_out, min_op, (int) INT_MAX, num_items); + +// Allocate temporary storage for exclusive prefix scan +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run exclusive prefix min-scan +cub::DeviceScan::ExclusiveScan( + d_temp_storage, temp_storage_bytes, + d_in, d_out, min_op, (int) INT_MAX, num_items); + +// d_out <-- [2147483647, 8, 6, 6, 5, 3, 0] +``` + + + + +nodiscard + +Computes a device-wide exclusive prefix scan using the specified binary associative `scan_op` functor. The `init_value` value is applied as the initial value, and is assigned to `*d_out`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceScan::ExclusiveScan( + InputIteratorT d_in, + OutputIteratorT d_out, + ScanOpT scan_op, + InitValueT init_value, + NumItemsT num_items, + EnvT env = {} +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative scan operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. Additional details can be found in the @lookback description. +When `d_in` and `d_out` are equal, the scan is performed in-place. The range `[d_in, d_in + num_items)` and `[d_out, d_out + num_items)` shall not overlap in any other way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading scan inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing scan outputs (may be a simple pointer type) + + + +**[inferred]** Binary associative scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Type of the `init_value` + + + +**[inferred]** An integral type representing the number of input elements + + + +**[inferred]** Execution environment type. Default is `::cuda::std::execution::env<>`. + + +**Parameters** + + +Random-access iterator to the input sequence of data items + + + +Random-access iterator to the output sequence of data items + + + +Binary associative scan functor + + + +Initial value to seed the exclusive scan (and is assigned to `*d_out`) + + + +Total number of input items (i.e., the length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** Execution environment. Default is `::cuda::std::execution::env{}`. +//! + + + + + +Computes a device-wide exclusive prefix scan using the specified binary associative `scan_op` functor. The `init_value` value is applied as the initial value, and is assigned to `*d_data`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceScan::ExclusiveScan( + void *d_temp_storage, + size_t &temp_storage_bytes, + IteratorT d_data, + ScanOpT scan_op, + InitValueT init_value, + NumItemsT num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative scan operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. Additional details can be found in the @lookback description. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading scan inputs and writing scan outputs + + + +**[inferred]** Binary associative scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Type of the `init_value` + + + +**[inferred]** An integral type representing the number of input elements + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the sequence of data items + + + +Binary associative scan functor + + + +Initial value to seed the exclusive scan (and is assigned to `*d_out`) + + + +Total number of input items (i.e., the length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the exclusive prefix min-scan of an `int` device vector: + +```cpp showLineNumbers={false} +#include // or equivalently +#include // for INT_MAX + +// CustomMin functor +struct CustomMin +{ + template + __host__ __device__ __forceinline__ + T operator()(const T &a, const T &b) const { + return (b < a) ? b : a; + } +}; + +// Declare, allocate, and initialize device-accessible pointers for +// input and output +int num_items; // e.g., 7 +int *d_data; // e.g., [8, 6, 7, 5, 3, 0, 9] +CustomMin min_op; +... + +// Determine temporary device storage requirements for exclusive +// prefix scan +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceScan::ExclusiveScan( + d_temp_storage, temp_storage_bytes, + d_data, min_op, (int) INT_MAX, num_items); + +// Allocate temporary storage for exclusive prefix scan +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run exclusive prefix min-scan +cub::DeviceScan::ExclusiveScan( + d_temp_storage, temp_storage_bytes, + d_data, min_op, (int) INT_MAX, num_items); + +// d_data <-- [2147483647, 8, 6, 6, 5, 3, 0] +``` + + + + +Computes a device-wide exclusive prefix scan using the specified binary associative `scan_op` functor. The `init_value` value is provided as a future value. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceScan::ExclusiveScan( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ScanOpT scan_op, + FutureValue init_value, + NumItemsT num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative scan operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. Additional details can be found in the @lookback description. +When `d_in` and `d_out` are equal, the scan is performed in-place. The range `[d_in, d_in + num_items)` and `[d_out, d_out + num_items)` shall not overlap in any other way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading scan inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing scan outputs (may be a simple pointer type) + + + +**[inferred]** Binary associative scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Type of the `init_value` + + + +**[inferred]** An integral type representing the number of input elements + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output sequence of data items + + + +Binary associative scan functor + + + +Initial value to seed the exclusive scan (and is assigned to `*d_out`) + + + +Total number of input items (i.e., the length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the exclusive prefix min-scan of an `int` device vector + +```cpp showLineNumbers={false} +#include // or equivalently +#include // for INT_MAX + +// CustomMin functor +struct CustomMin +{ + template + __host__ __device__ __forceinline__ + T operator()(const T &a, const T &b) const { + return (b < a) ? b : a; + } +}; + +// Declare, allocate, and initialize device-accessible pointers for +// input and output +int num_items; // e.g., 7 +int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_out; // e.g., [ , , , , , , ] +int *d_init_iter; // e.g., INT_MAX +CustomMin min_op; + +auto future_init_value = + cub::FutureValue(d_init_iter); + +... + +// Determine temporary device storage requirements for exclusive +// prefix scan +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceScan::ExclusiveScan( + d_temp_storage, temp_storage_bytes, + d_in, d_out, min_op, future_init_value, num_items); + +// Allocate temporary storage for exclusive prefix scan +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run exclusive prefix min-scan +cub::DeviceScan::ExclusiveScan( + d_temp_storage, temp_storage_bytes, + d_in, d_out, min_op, future_init_value, num_items); + +// d_out <-- [2147483647, 8, 6, 6, 5, 3, 0] +``` + + + + +Computes a device-wide exclusive prefix scan using the specified binary associative `scan_op` functor. The `init_value` value is provided as a future value. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceScan::ExclusiveScan( + void *d_temp_storage, + size_t &temp_storage_bytes, + IteratorT d_data, + ScanOpT scan_op, + FutureValue init_value, + NumItemsT num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative scan operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. Additional details can be found in the @lookback description. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading scan inputs and writing scan outputs + + + +**[inferred]** Binary associative scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Type of the `init_value` + + + +**[inferred]** An integral type representing the number of input elements + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the sequence of data items + + + +Binary associative scan functor + + + +Initial value to seed the exclusive scan (and is assigned to `*d_out`) + + + +Total number of input items (i.e., the length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the exclusive prefix min-scan of an `int` device vector + +```cpp showLineNumbers={false} +#include // or equivalently +#include // for INT_MAX + +// CustomMin functor +struct CustomMin +{ + template + __host__ __device__ __forceinline__ + T operator()(const T &a, const T &b) const { + return (b < a) ? b : a; + } +}; + +// Declare, allocate, and initialize device-accessible pointers for +// input and output +int num_items; // e.g., 7 +int *d_data; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_init_iter; // e.g., INT_MAX +CustomMin min_op; + +auto future_init_value = + cub::FutureValue(d_init_iter); + +... + +// Determine temporary device storage requirements for exclusive +// prefix scan +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceScan::ExclusiveScan( + d_temp_storage, temp_storage_bytes, + d_data, min_op, future_init_value, num_items); + +// Allocate temporary storage for exclusive prefix scan +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run exclusive prefix min-scan +cub::DeviceScan::ExclusiveScan( + d_temp_storage, temp_storage_bytes, + d_data, min_op, future_init_value, num_items); + +// d_data <-- [2147483647, 8, 6, 6, 5, 3, 0] +``` + + + + +--- + +## Inclusive scans + +### InclusiveSum inline static + + + + +Computes a device-wide inclusive prefix sum. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceScan::InclusiveSum( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + NumItemsT num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative sum operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. Additional details can be found in the @lookback description. +When `d_in` and `d_out` are equal, the scan is performed in-place. The range `[d_in, d_in + num_items)` and `[d_out, d_out + num_items)` shall not overlap in any other way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading scan inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing scan outputs (may be a simple pointer type) + + + +**[inferred]** An integral type representing the number of input elements + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the input sequence of data items + + + +Random-access iterator to the output sequence of data items + + + +Total number of input items (i.e., the length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the inclusive prefix sum of an `int` device vector. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers for +// input and output +int num_items; // e.g., 7 +int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_out; // e.g., [ , , , , , , ] +... + +// Determine temporary device storage requirements for inclusive +// prefix sum +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceScan::InclusiveSum( + d_temp_storage, temp_storage_bytes, + d_in, d_out, num_items); + +// Allocate temporary storage for inclusive prefix sum +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run inclusive prefix sum +cub::DeviceScan::InclusiveSum( + d_temp_storage, temp_storage_bytes, + d_in, d_out, num_items); + +// d_out <-- [8, 14, 21, 26, 29, 29, 38] +``` + + + + +Computes a device-wide inclusive prefix sum in-place. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceScan::InclusiveSum( + void *d_temp_storage, + size_t &temp_storage_bytes, + IteratorT d_data, + NumItemsT num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative sum operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. Additional details can be found in the @lookback description. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading scan inputs and writing scan outputs + + + +**[inferred]** An integral type representing the number of input elements + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the sequence of data items + + + +Total number of input items (i.e., the length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the inclusive prefix sum of an `int` device vector. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers for +// input and output +int num_items; // e.g., 7 +int *d_data; // e.g., [8, 6, 7, 5, 3, 0, 9] +... + +// Determine temporary device storage requirements for inclusive +// prefix sum +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceScan::InclusiveSum( + d_temp_storage, temp_storage_bytes, + d_data, num_items); + +// Allocate temporary storage for inclusive prefix sum +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run inclusive prefix sum +cub::DeviceScan::InclusiveSum( + d_temp_storage, temp_storage_bytes, + d_data, num_items); + +// d_data <-- [8, 14, 21, 26, 29, 29, 38] +``` + + + + +### InclusiveScan inline static + + + + +Computes a device-wide inclusive prefix scan using the specified binary associative `scan_op` functor. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceScan::InclusiveScan( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ScanOpT scan_op, + NumItemsT num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative scan operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. Additional details can be found in the @lookback description. +When `d_in` and `d_out` are equal, the scan is performed in-place. The range `[d_in, d_in + num_items)` and `[d_out, d_out + num_items)` shall not overlap in any other way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading scan inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing scan outputs (may be a simple pointer type) + + + +**[inferred]** Binary associative scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** An integral type representing the number of input elements + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the input sequence of data items + + + +Random-access iterator to the output sequence of data items + + + +Binary associative scan functor + + + +Total number of input items (i.e., the length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the inclusive prefix min-scan of an `int` device vector. + +```cpp showLineNumbers={false} +#include // or equivalently +#include // for INT_MAX + +// CustomMin functor +struct CustomMin +{ + template + __host__ __device__ __forceinline__ + T operator()(const T &a, const T &b) const { + return (b < a) ? b : a; + } +}; + +// Declare, allocate, and initialize device-accessible pointers for +// input and output +int num_items; // e.g., 7 +int *d_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_out; // e.g., [ , , , , , , ] +CustomMin min_op; +... + +// Determine temporary device storage requirements for inclusive +// prefix scan +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceScan::InclusiveScan( + d_temp_storage, temp_storage_bytes, + d_in, d_out, min_op, num_items); + +// Allocate temporary storage for inclusive prefix scan +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run inclusive prefix min-scan +cub::DeviceScan::InclusiveScan( + d_temp_storage, temp_storage_bytes, + d_in, d_out, min_op, num_items); + +// d_out <-- [8, 6, 6, 5, 3, 0, 0] +``` + + + + +Computes a device-wide inclusive prefix scan using the specified binary associative `scan_op` functor. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceScan::InclusiveScan( + void *d_temp_storage, + size_t &temp_storage_bytes, + IteratorT d_data, + ScanOpT scan_op, + NumItemsT num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative scan operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. Additional details can be found in the @lookback description. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading scan inputs and writing scan outputs + + + +**[inferred]** Binary associative scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** An integral type representing the number of input elements + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the sequence of data items + + + +Binary associative scan functor + + + +Total number of input items (i.e., the length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the inclusive prefix min-scan of an `int` device vector. + +```cpp showLineNumbers={false} +#include // or equivalently +#include // for INT_MAX + +// CustomMin functor +struct CustomMin +{ + template + __host__ __device__ __forceinline__ + T operator()(const T &a, const T &b) const { + return (b < a) ? b : a; + } +}; + +// Declare, allocate, and initialize device-accessible pointers for +// input and output +int num_items; // e.g., 7 +int *d_data; // e.g., [8, 6, 7, 5, 3, 0, 9] +CustomMin min_op; +... + +// Determine temporary device storage requirements for inclusive +// prefix scan +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceScan::InclusiveScan( + d_temp_storage, temp_storage_bytes, + d_data, min_op, num_items); + +// Allocate temporary storage for inclusive prefix scan +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run inclusive prefix min-scan +cub::DeviceScan::InclusiveScan( + d_temp_storage, temp_storage_bytes, + d_in, d_out, min_op, num_items); + +// d_data <-- [8, 6, 6, 5, 3, 0, 0] +``` + + + + +nodiscard + +Computes a device-wide inclusive prefix scan using the specified binary associative `scan_op` functor. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceScan::InclusiveScan( + InputIteratorT d_in, + OutputIteratorT d_out, + ScanOpT scan_op, + NumItemsT num_items, + EnvT env = {} +) +``` + + + +Supports non-commutative scan operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. Additional details can be found in the @lookback description. +When `d_in` and `d_out` are equal, the scan is performed in-place. The range `[d_in, d_in + num_items)` and `[d_out, d_out + num_items)` shall not overlap in any other way. + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading scan inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing scan outputs (may be a simple pointer type) + + + +**[inferred]** Binary associative scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** An integral type representing the number of input elements + + + +**[inferred]** Execution environment type. Default is `::cuda::std::execution::env<>`. + + +**Parameters** + + +Random-access iterator to the input sequence of data items + + + +Random-access iterator to the output sequence of data items + + + +Binary associative scan functor + + + +Total number of input items (i.e., the length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** Execution environment. Default is `::cuda::std::execution::env{}`. +//! + + + + + +### InclusiveScanInit inline static + + + + +Computes a device-wide inclusive prefix scan using the specified binary associative `scan_op` functor. The result of applying the `scan_op` binary operator to `init_value` value and `*d_in` is assigned to `*d_out`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceScan::InclusiveScanInit( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ScanOpT scan_op, + InitValueT init_value, + NumItemsT num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative scan operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. Additional details can be found in the @lookback description. +When `d_in` and `d_out` are equal, the scan is performed in-place. The range `[d_in, d_in + num_items)` and `[d_out, d_out + num_items)` shall not overlap in any other way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading scan inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing scan outputs (may be a simple pointer type) + + + +**[inferred]** Binary associative scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Type of the `init_value` + + + +**[inferred]** An integral type representing the number of input elements + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to the size in bytes of the `d_temp_storage` allocation + + + +Random-access iterator to the input sequence of data items + + + +Random-access iterator to the output sequence of data items + + + +Binary associative scan functor + + + +Initial value to seed the inclusive scan (`scan_op(init_value, d_in[0])` is assigned to `*d_out`) + + + +Total number of input items (i.e., the length of `d_in`) + + + +CUDA stream to launch kernels within. + + + + + +nodiscard + +Computes a device-wide inclusive prefix scan using the specified binary associative `scan_op` functor. The result of applying the `scan_op` binary operator to `init_value` value and `*d_in` is assigned to `*d_out`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceScan::InclusiveScanInit( + InputIteratorT d_in, + OutputIteratorT d_out, + ScanOpT scan_op, + InitValueT init_value, + NumItemsT num_items, + EnvT env = {} +) +``` + + + +Supports non-commutative scan operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. Additional details can be found in the @lookback description. +When `d_in` and `d_out` are equal, the scan is performed in-place. The range `[d_in, d_in + num_items)` and `[d_out, d_out + num_items)` shall not overlap in any other way. + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading scan inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing scan outputs (may be a simple pointer type) + + + +**[inferred]** Binary associative scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Type of the `init_value` + + + +**[inferred]** An integral type representing the number of input elements + + + +**[inferred]** Execution environment type. Default is `::cuda::std::execution::env<>`. + + +**Parameters** + + +Random-access iterator to the input sequence of data items + + + +Random-access iterator to the output sequence of data items + + + +Binary associative scan functor + + + +Initial value to seed the inclusive scan (`scan_op(init_value, d_in[0])` is assigned to `*d_out`) + + + +Total number of input items (i.e., the length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** Execution environment. Default is `::cuda::std::execution::env{}`. +//! + + + + + +--- + +## Scans by key + +### ExclusiveSumByKey inline static + +Computes a device-wide exclusive prefix sum-by-key with key equality defined by `equality_op`. The value of `0` is applied as the initial value, and is assigned to the beginning of each segment in `d_values_out`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceScan::ExclusiveSumByKey( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + ValuesInputIteratorT d_values_in, + ValuesOutputIteratorT d_values_out, + NumItemsT num_items, + EqualityOpT equality_op = EqualityOpT(), + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative sum operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. Additional details can be found in the @lookback description. +`d_keys_in` may equal `d_values_out` but the range `[d_keys_in, d_keys_in + num_items)` and the range `[d_values_out, d_values_out + num_items)` shall not overlap otherwise. +`d_values_in` may equal `d_values_out` but the range `[d_values_in, d_values_in + num_items)` and the range `[d_values_out, d_values_out + num_items)` shall not overlap otherwise. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading scan keys inputs (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading scan values inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing scan values outputs (may be a simple pointer type) + + + +**[inferred]** Functor type having member `T operator()(const T &a, const T &b)` for binary operations that defines the equality of keys + + + +**[inferred]** An integral type representing the number of input elements + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access input iterator to the input sequence of key items + + + +Random-access input iterator to the input sequence of value items + + + +Random-access output iterator to the output sequence of value items + + + +Total number of input items (i.e., the length of `d_keys_in` and `d_values_in`) + + + +Binary functor that defines the equality of keys. Default is cuda::std::equal_to<>{}. + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the exclusive prefix sum-by-key of an `int` device vector. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers for +// input and output +int num_items; // e.g., 7 +int *d_keys_in; // e.g., [0, 0, 1, 1, 1, 2, 2] +int *d_values_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_values_out; // e.g., [ , , , , , , ] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceScan::ExclusiveSumByKey( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_values_in, d_values_out, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run exclusive prefix sum +cub::DeviceScan::ExclusiveSumByKey( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_values_in, d_values_out, num_items); + +// d_values_out <-- [0, 8, 0, 7, 12, 0, 0] +``` + +### ExclusiveScanByKey inline static + +Computes a device-wide exclusive prefix scan-by-key using the specified binary associative `scan_op` functor. The key equality is defined by `equality_op`. The `init_value` value is applied as the initial value, and is assigned to the beginning of each segment in `d_values_out`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceScan::ExclusiveScanByKey( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + ValuesInputIteratorT d_values_in, + ValuesOutputIteratorT d_values_out, + ScanOpT scan_op, + InitValueT init_value, + NumItemsT num_items, + EqualityOpT equality_op = EqualityOpT(), + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative scan operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. Additional details can be found in the @lookback description. +`d_keys_in` may equal `d_values_out` but the range `[d_keys_in, d_keys_in + num_items)` and the range `[d_values_out, d_values_out + num_items)` shall not overlap otherwise. +`d_values_in` may equal `d_values_out` but the range `[d_values_in, d_values_in + num_items)` and the range `[d_values_out, d_values_out + num_items)` shall not overlap otherwise. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading scan keys inputs (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading scan values inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing scan values outputs (may be a simple pointer type) + + + +**[inferred]** Binary associative scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Type of the `init_value` + + + +**[inferred]** Functor type having member `T operator()(const T &a, const T &b)` for binary operations that defines the equality of keys + + + +**[inferred]** An integral type representing the number of input elements + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access input iterator to the input sequence of key items + + + +Random-access input iterator to the input sequence of value items + + + +Random-access output iterator to the output sequence of value items + + + +Binary associative scan functor + + + +Initial value to seed the exclusive scan (and is assigned to the beginning of each segment in `d_values_out`) + + + +Total number of input items (i.e., the length of `d_keys_in` and `d_values_in`) + + + +Binary functor that defines the equality of keys. Default is cuda::std::equal_to<>{}. + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the exclusive prefix min-scan-by-key of an `int` device vector + +```cpp showLineNumbers={false} +#include // or equivalently +#include // for INT_MAX + +// CustomMin functor +struct CustomMin +{ + template + __host__ __device__ __forceinline__ + T operator()(const T &a, const T &b) const { + return (b < a) ? b : a; + } +}; + +// CustomEqual functor +struct CustomEqual +{ + template + __host__ __device__ __forceinline__ + T operator()(const T &a, const T &b) const { + return a == b; + } +}; + +// Declare, allocate, and initialize device-accessible pointers for +// input and output +int num_items; // e.g., 7 +int *d_keys_in; // e.g., [0, 0, 1, 1, 1, 2, 2] +int *d_values_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_values_out; // e.g., [ , , , , , , ] +CustomMin min_op; +CustomEqual equality_op; +... + +// Determine temporary device storage requirements for exclusive +// prefix scan +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceScan::ExclusiveScanByKey( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_values_in, d_values_out, min_op, + (int) INT_MAX, num_items, equality_op); + +// Allocate temporary storage for exclusive prefix scan +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run exclusive prefix min-scan +cub::DeviceScan::ExclusiveScanByKey( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_values_in, d_values_out, min_op, + (int) INT_MAX, num_items, equality_op); + +// d_values_out <-- [2147483647, 8, 2147483647, 7, 5, 2147483647, 0] +``` + +### InclusiveSumByKey inline static + +Computes a device-wide inclusive prefix sum-by-key with key equality defined by `equality_op`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceScan::InclusiveSumByKey( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + ValuesInputIteratorT d_values_in, + ValuesOutputIteratorT d_values_out, + NumItemsT num_items, + EqualityOpT equality_op = EqualityOpT(), + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative sum operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. Additional details can be found in the @lookback description. +`d_keys_in` may equal `d_values_out` but the range `[d_keys_in, d_keys_in + num_items)` and the range `[d_values_out, d_values_out + num_items)` shall not overlap otherwise. +`d_values_in` may equal `d_values_out` but the range `[d_values_in, d_values_in + num_items)` and the range `[d_values_out, d_values_out + num_items)` shall not overlap otherwise. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading scan keys inputs (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading scan values inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing scan values outputs (may be a simple pointer type) + + + +**[inferred]** Functor type having member `T operator()(const T &a, const T &b)` for binary operations that defines the equality of keys + + + +**[inferred]** An integral type representing the number of input elements + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access input iterator to the input sequence of key items + + + +Random-access input iterator to the input sequence of value items + + + +Random-access output iterator to the output sequence of value items + + + +Total number of input items (i.e., the length of `d_keys_in` and `d_values_in`) + + + +Binary functor that defines the equality of keys. Default is cuda::std::equal_to<>{}. + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the inclusive prefix sum-by-key of an `int` device vector. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers for +// input and output +int num_items; // e.g., 7 +int *d_keys_in; // e.g., [0, 0, 1, 1, 1, 2, 2] +int *d_values_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_values_out; // e.g., [ , , , , , , ] +... + +// Determine temporary device storage requirements for inclusive prefix sum +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceScan::InclusiveSumByKey( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_values_in, d_values_out, num_items); + +// Allocate temporary storage for inclusive prefix sum +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run inclusive prefix sum +cub::DeviceScan::InclusiveSumByKey( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_values_in, d_values_out, num_items); + +// d_out <-- [8, 14, 7, 12, 15, 0, 9] +``` + +### InclusiveScanByKey inline static + +Computes a device-wide inclusive prefix scan-by-key using the specified binary associative `scan_op` functor. The key equality is defined by `equality_op`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceScan::InclusiveScanByKey( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + ValuesInputIteratorT d_values_in, + ValuesOutputIteratorT d_values_out, + ScanOpT scan_op, + NumItemsT num_items, + EqualityOpT equality_op = EqualityOpT(), + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Supports non-commutative scan operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. Additional details can be found in the @lookback description. +`d_keys_in` may equal `d_values_out` but the range `[d_keys_in, d_keys_in + num_items)` and the range `[d_values_out, d_values_out + num_items)` shall not overlap otherwise. +`d_values_in` may equal `d_values_out` but the range `[d_values_in, d_values_in + num_items)` and the range `[d_values_out, d_values_out + num_items)` shall not overlap otherwise. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading scan keys inputs (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading scan values inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing scan values outputs (may be a simple pointer type) + + + +**[inferred]** Binary associative scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Functor type having member `T operator()(const T &a, const T &b)` for binary operations that defines the equality of keys + + + +**[inferred]** An integral type representing the number of input elements + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access input iterator to the input sequence of key items + + + +Random-access input iterator to the input sequence of value items + + + +Random-access output iterator to the output sequence of value items + + + +Binary associative scan functor + + + +Total number of input items (i.e., the length of `d_keys_in` and `d_values_in`) + + + +Binary functor that defines the equality of keys. Default is cuda::std::equal_to<>{}. + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the inclusive prefix min-scan-by-key of an `int` device vector. + +```cpp showLineNumbers={false} +#include // or equivalently +#include // for INT_MAX + +// CustomMin functor +struct CustomMin +{ + template + __host__ __device__ __forceinline__ + T operator()(const T &a, const T &b) const { + return (b < a) ? b : a; + } +}; + +// CustomEqual functor +struct CustomEqual +{ + template + __host__ __device__ __forceinline__ + T operator()(const T &a, const T &b) const { + return a == b; + } +}; + +// Declare, allocate, and initialize device-accessible pointers for +// input and output +int num_items; // e.g., 7 +int *d_keys_in; // e.g., [0, 0, 1, 1, 1, 2, 2] +int *d_values_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_values_out; // e.g., [ , , , , , , ] +CustomMin min_op; +CustomEqual equality_op; +... + +// Determine temporary device storage requirements for inclusive prefix scan +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceScan::InclusiveScanByKey( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_values_in, d_values_out, min_op, num_items, equality_op); + +// Allocate temporary storage for inclusive prefix scan +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run inclusive prefix min-scan +cub::DeviceScan::InclusiveScanByKey( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_values_in, d_values_out, min_op, num_items, equality_op); + +// d_out <-- [8, 6, 7, 5, 3, 0, 0] +``` diff --git a/fern/cudapages/cub/cub/cub/DeviceSegmentedRadixSort.mdx b/fern/cudapages/cub/cub/cub/DeviceSegmentedRadixSort.mdx new file mode 100644 index 0000000..69c94eb --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DeviceSegmentedRadixSort.mdx @@ -0,0 +1,1260 @@ +--- +title: cub::DeviceSegmentedRadixSort +description: "" +--- + +DeviceSegmentedRadixSort provides device-wide, parallel operations for computing a batched radix sort across multiple, non-overlapping sequences of data items residing within device-accessible memory. + +--- + +## Key-value pairs + +### SortPairs inline static + + + + +Sorts segments of key-value pairs into ascending order. (`~2N` auxiliary storage required) + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedRadixSort::SortPairs( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The contents of the input data are not altered by the sorting operation +When input a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. +Let `in` be one of `{d_keys_in, d_values_in}` and `out` be any of `{d_keys_out, d_values_out}`. The range `[out, out + num_items)` shall not overlap `[in, in + num_items)`, `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +@devicestorageNP For sorting using only `O(P)` temporary storage, see the sorting interface using DoubleBuffer wrappers below. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys_in[i]`, `d_values_in[i]`, `d_keys_out[i]`, `d_values_out[i]` will not be accessed nor modified. +Note, the size of any segment may not exceed `INT_MAX`. Please consider using `DeviceSegmentedSort` instead, if the size of at least one of your segments could exceed `INT_MAX`. +@devicestorage + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Value type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Device-accessible pointer to the input data of key data to sort + + + +Device-accessible pointer to the sorted output sequence of key data + + + +Device-accessible pointer to the corresponding input sequence of associated value items + + + +Device-accessible pointer to the correspondingly-reordered output sequence of associated value items + + + +The total number of items within the segmented array, including items not covered by segments. `num_items` should match the largest element within the range `[d_end_offsets, d_end_offsets + num_segments)`. + + + +The number of segments that comprise the sorting data + + + +Random-access input iterator to the sequence of beginning offsets of length `num_segments`, such that `d_begin_offsets[i]` is the first element of the *i*th data segment in `d_keys_*` and `d_values_*` + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. If +//! ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the *i*\ :sup:`th` is considered empty. +//! + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `sizeof(unsigned int) * 8`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `int` keys with associated vector of `int` values. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_keys_out; // e.g., [-, -, -, -, -, -, -] +int *d_values_in; // e.g., [0, 1, 2, 3, 4, 5, 6] +int *d_values_out; // e.g., [-, -, -, -, -, -, -] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedRadixSort::SortPairs( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedRadixSort::SortPairs( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys_out <-- [6, 7, 8, 0, 3, 5, 9] +// d_values_out <-- [1, 2, 0, 5, 4, 3, 6] +``` + + + + +Sorts segments of key-value pairs into ascending order. (`~N` auxiliary storage required) + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedRadixSort::SortPairs( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The sorting operation is given a pair of key buffers and a corresponding pair of associated value buffers. Each pair is managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). +The contents of both buffers within each pair may be altered by the sorting operation. +Upon completion, the sorting operation will update the "current" indicator within each DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). +When input a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. +Let `cur` be one of `{d_keys.Current(), d_values.Current()}` and `alt` be any of `{d_keys.Alternate(), d_values.Alternate()}`. The range `[cur, cur + num_items)` shall not overlap `[alt, alt + num_items)`. Both ranges shall not overlap `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys.Current()[i]`, `d_values.Current()[i]`, `d_keys.Alternate()[i]`, `d_values.Alternate()[i]` will not be accessed nor modified. +Note, the size of any segment may not exceed `INT_MAX`. Please consider using `DeviceSegmentedSort` instead, if the size of at least one of your segments could exceed `INT_MAX`. +@devicestorageP +@devicestorage + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Value type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Double-buffer of values whose "current" device-accessible buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + + + +The total number of items within the segmented array, including items not covered by segments. `num_items` should match the largest element within the range `[d_end_offsets, d_end_offsets + num_segments)`. + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the *i*\ :sup:`th` is considered empty. +//! + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `sizeof(unsigned int) * 8`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `int` keys with associated vector of `int` values. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_key_alt_buf; // e.g., [-, -, -, -, -, -, -] +int *d_value_buf; // e.g., [0, 1, 2, 3, 4, 5, 6] +int *d_value_alt_buf; // e.g., [-, -, -, -, -, -, -] +... + +// Create a set of DoubleBuffers to wrap pairs of device pointers +cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); +cub::DoubleBuffer d_values(d_value_buf, d_value_alt_buf); + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedRadixSort::SortPairs( + d_temp_storage, temp_storage_bytes, d_keys, d_values, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedRadixSort::SortPairs( + d_temp_storage, temp_storage_bytes, d_keys, d_values, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys.Current() <-- [6, 7, 8, 0, 3, 5, 9] +// d_values.Current() <-- [5, 4, 3, 1, 2, 0, 6] +``` + + + + +### SortPairsDescending inline static + + + + +Sorts segments of key-value pairs into descending order. (`~2N` auxiliary storage required). + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedRadixSort::SortPairsDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The contents of the input data are not altered by the sorting operation +When input a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. +Let `in` be one of `{d_keys_in, d_values_in}` and `out` be any of `{d_keys_out, d_values_out}`. The range `[out, out + num_items)` shall not overlap `[in, in + num_items)`, `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +@devicestorageNP For sorting using only `O(P)` temporary storage, see the sorting interface using DoubleBuffer wrappers below. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys_in[i]`, `d_values_in[i]`, `d_keys_out[i]`, `d_values_out[i]` will not be accessed nor modified. +Note, the size of any segment may not exceed `INT_MAX`. Please consider using `DeviceSegmentedSort` instead, if the size of at least one of your segments could exceed `INT_MAX`. +@devicestorage + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Value type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Device-accessible pointer to the input data of key data to sort + + + +Device-accessible pointer to the sorted output sequence of key data + + + +Device-accessible pointer to the corresponding input sequence of associated value items + + + +Device-accessible pointer to the correspondingly-reordered output sequence of associated value items + + + +The total number of items within the segmented array, including items not covered by segments. `num_items` should match the largest element within the range `[d_end_offsets, d_end_offsets + num_segments)`. + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the *i*\ :sup:`th` is considered empty. +//! + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `sizeof(unsigned int) * 8`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `int` keys with associated vector of `int` values. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_keys_out; // e.g., [-, -, -, -, -, -, -] +int *d_values_in; // e.g., [0, 1, 2, 3, 4, 5, 6] +int *d_values_out; // e.g., [-, -, -, -, -, -, -] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedRadixSort::SortPairsDescending( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedRadixSort::SortPairsDescending( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys_out <-- [8, 7, 6, 9, 5, 3, 0] +// d_values_out <-- [0, 2, 1, 6, 3, 4, 5] +``` + + + + +Sorts segments of key-value pairs into descending order. (`~N` auxiliary storage required). + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedRadixSort::SortPairsDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The sorting operation is given a pair of key buffers and a corresponding pair of associated value buffers. Each pair is managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). +The contents of both buffers within each pair may be altered by the sorting operation. +Upon completion, the sorting operation will update the "current" indicator within each DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). +When input a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. +Let `cur` be one of `{d_keys.Current(), d_values.Current()}` and `alt` be any of `{d_keys.Alternate(), d_values.Alternate()}`. The range `[cur, cur + num_items)` shall not overlap `[alt, alt + num_items)`. Both ranges shall not overlap `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys.Current()[i]`, `d_values.Current()[i]`, `d_keys.Alternate()[i]`, `d_values.Alternate()[i]` will not be accessed nor modified. not to be modified. +Note, the size of any segment may not exceed `INT_MAX`. Please consider using `DeviceSegmentedSort` instead, if the size of at least one of your segments could exceed `INT_MAX`. +@devicestorageP +@devicestorage + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Value type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Double-buffer of values whose "current" device-accessible buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + + + +The total number of items within the segmented array, including items not covered by segments. `num_items` should match the largest element within the range `[d_end_offsets, d_end_offsets + num_segments)`. + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the *i*\ :sup:`th` is considered empty. +//! + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `sizeof(unsigned int) * 8`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `int` keys with associated vector of `int` values. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_key_alt_buf; // e.g., [-, -, -, -, -, -, -] +int *d_value_buf; // e.g., [0, 1, 2, 3, 4, 5, 6] +int *d_value_alt_buf; // e.g., [-, -, -, -, -, -, -] +... + +// Create a set of DoubleBuffers to wrap pairs of device pointers +cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); +cub::DoubleBuffer d_values(d_value_buf, d_value_alt_buf); + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedRadixSort::SortPairsDescending( + d_temp_storage, temp_storage_bytes, d_keys, d_values, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedRadixSort::SortPairsDescending( + d_temp_storage, temp_storage_bytes, d_keys, d_values, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys.Current() <-- [8, 7, 6, 9, 5, 3, 0] +// d_values.Current() <-- [0, 2, 1, 6, 3, 4, 5] +``` + + + + +--- + +## Keys-only + +### SortKeys inline static + + + + +Sorts segments of keys into ascending order. (`~2N` auxiliary storage required) + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedRadixSort::SortKeys( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The contents of the input data are not altered by the sorting operation +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. +When input a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +The range `[d_keys_out, d_keys_out + num_items)` shall not overlap `[d_keys_in, d_keys_in + num_items)`, `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +@devicestorageNP For sorting using only `O(P)` temporary storage, see the sorting interface using DoubleBuffer wrappers below. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys_in[i]`, `d_keys_out[i]` will not be accessed nor modified. +Note, the size of any segment may not exceed `INT_MAX`. Please consider using `DeviceSegmentedSort` instead, if the size of at least one of your segments could exceed `INT_MAX`. +@devicestorage + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Device-accessible pointer to the input data of key data to sort + + + +Device-accessible pointer to the sorted output sequence of key data + + + +The total number of items within the segmented array, including items not covered by segments. `num_items` should match the largest element within the range `[d_end_offsets, d_end_offsets + num_segments)`. + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the *i*\ :sup:`th` is considered empty. +//! + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `sizeof(unsigned int) * 8`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `int` keys. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_keys_out; // e.g., [-, -, -, -, -, -, -] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedRadixSort::SortKeys( + d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedRadixSort::SortKeys( + d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys_out <-- [6, 7, 8, 0, 3, 5, 9] +``` + + + + +Sorts segments of keys into ascending order. (`~N` auxiliary storage required). + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedRadixSort::SortKeys( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The sorting operation is given a pair of key buffers managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). +The contents of both buffers may be altered by the sorting operation. +Upon completion, the sorting operation will update the "current" indicator within the DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). +When input a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. +Let `cur = d_keys.Current()` and `alt = d_keys.Alternate()`. The range `[cur, cur + num_items)` shall not overlap `[alt, alt + num_items)`. Both ranges shall not overlap `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys.Current()[i]`, `d_keys[i].Alternate()[i]` will not be accessed nor modified. +Note, the size of any segment may not exceed `INT_MAX`. Please consider using `DeviceSegmentedSort` instead, if the size of at least one of your segments could exceed `INT_MAX`. +@devicestorageP +@devicestorage + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +The total number of items within the segmented array, including items not covered by segments. `num_items` should match the largest element within the range `[d_end_offsets, d_end_offsets + num_segments)`. + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the *i*\ :sup:`th` is considered empty. +//! + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `sizeof(unsigned int) * 8`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `int` keys. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers for +// sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_key_alt_buf; // e.g., [-, -, -, -, -, -, -] +... + +// Create a DoubleBuffer to wrap the pair of device pointers +cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedRadixSort::SortKeys( + d_temp_storage, temp_storage_bytes, d_keys, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedRadixSort::SortKeys( + d_temp_storage, temp_storage_bytes, d_keys, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys.Current() <-- [6, 7, 8, 0, 3, 5, 9] +``` + + + + +### SortKeysDescending inline static + + + + +Sorts segments of keys into descending order. (`~2N` auxiliary storage required). + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedRadixSort::SortKeysDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The contents of the input data are not altered by the sorting operation +When input a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. +The range `[d_keys_out, d_keys_out + num_items)` shall not overlap `[d_keys_in, d_keys_in + num_items)`, `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +@devicestorageNP For sorting using only `O(P)` temporary storage, see the sorting interface using DoubleBuffer wrappers below. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys_in[i]`, `d_keys_out[i]` will not be accessed nor modified. +Note, the size of any segment may not exceed `INT_MAX`. Please consider using `DeviceSegmentedSort` instead, if the size of at least one of your segments could exceed `INT_MAX`. +@devicestorage + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Device-accessible pointer to the input data of key data to sort + + + +Device-accessible pointer to the sorted output sequence of key data + + + +The total number of items within the segmented array, including items not covered by segments. `num_items` should match the largest element within the range `[d_end_offsets, d_end_offsets + num_segments)`. + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the *i*\ :sup:`th` is considered empty. +//! + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., sizeof(unsigned int) * 8) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `int` keys. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_keys_out; // e.g., [-, -, -, -, -, -, -] +... + +// Create a DoubleBuffer to wrap the pair of device pointers +cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedRadixSort::SortKeysDescending( + d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedRadixSort::SortKeysDescending( + d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys_out <-- [8, 7, 6, 9, 5, 3, 0] +``` + + + + +Sorts segments of keys into descending order. (`~N` auxiliary storage required). + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedRadixSort::SortKeysDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + int begin_bit = 0, + int end_bit = sizeof(KeyT) *8, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The sorting operation is given a pair of key buffers managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). +The contents of both buffers may be altered by the sorting operation. +Upon completion, the sorting operation will update the "current" indicator within the DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). +When input a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +An optional bit subrange `[begin_bit, end_bit)` of differentiating key bits can be specified. This can reduce overall sorting overhead and yield a corresponding performance improvement. +Let `cur = d_keys.Current()` and `alt = d_keys.Alternate()`. The range `[cur, cur + num_items)` shall not overlap `[alt, alt + num_items)`. Both ranges shall not overlap `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys.Current()[i]`, `d_keys[i].Alternate()[i]` will not be accessed nor modified. +Note, the size of any segment may not exceed `INT_MAX`. Please consider using `DeviceSegmentedSort` instead, if the size of at least one of your segments could exceed `INT_MAX`. +@devicestorageP +@devicestorage + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +The total number of items within the segmented array, including items not covered by segments. `num_items` should match the largest element within the range `[d_end_offsets, d_end_offsets + num_segments)`. + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the *i*\ :sup:`th` is considered empty. +//! + + + +**[optional]** The least-significant bit index (inclusive) needed for key comparison + + + +**[optional]** The most-significant bit index (exclusive) needed for key comparison (e.g., `sizeof(unsigned int) * 8`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `int` keys. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_key_alt_buf; // e.g., [-, -, -, -, -, -, -] +... + +// Create a DoubleBuffer to wrap the pair of device pointers +cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedRadixSort::SortKeysDescending( + d_temp_storage, temp_storage_bytes, d_keys, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedRadixSort::SortKeysDescending( + d_temp_storage, temp_storage_bytes, d_keys, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys.Current() <-- [8, 7, 6, 9, 5, 3, 0] +``` + + + + +--- + +## Utility methods + +### GetName inline static constexpr + + +```cpp showLineNumbers={false} +static constexpr const char * cub::DeviceSegmentedRadixSort::GetName() +``` + diff --git a/fern/cudapages/cub/cub/cub/DeviceSegmentedReduce.mdx b/fern/cudapages/cub/cub/cub/DeviceSegmentedReduce.mdx new file mode 100644 index 0000000..9ec2643 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DeviceSegmentedReduce.mdx @@ -0,0 +1,1181 @@ +--- +title: cub::DeviceSegmentedReduce +description: "" +--- + +DeviceSegmentedReduce provides device-wide, parallel operations for computing a reduction across multiple sequences of data items residing within device-accessible memory. + +--- + +## Static methods + +### Reduce inline static + + + + +Computes a device-wide segmented reduction using the specified binary `reduction_op` functor. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedReduce::Reduce( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + ReductionOpT reduction_op, + T initial_value, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Does not support binary reduction operators that are non-commutative. +Provides "run-to-run" determinism for pseudo-associative reduction (e.g., addition of floating point types) on the same GPU device. However, results for pseudo-associative reduction may be inconsistent from one device to a another device of a different compute-capability because CUB can employ different tile-sizing for different architectures. +When input a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +Let `s` be in `[0, num_segments)`. The range `[d_out + d_begin_offsets[s], d_out + d_end_offsets[s])` shall not overlap `[d_in + d_begin_offsets[s], d_in + d_end_offsets[s])`, `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)`. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + + +**[inferred]** Binary reduction functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Data element type that is convertible to the `value` type of `InputIteratorT` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +The number of segments that comprise the segmented reduction data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_in`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_in``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the *i*\ :sup:`th` is considered empty. +//! + + + +Binary reduction functor + + + +Initial value of the reduction for each segment + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +Computes a device-wide segmented reduction using the specified binary `reduction_op` functor and a fixed segment size. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedReduce::Reduce( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ::cuda::std::int64_t num_segments, + int segment_size, + ReductionOpT reduction_op, + T initial_value, + cudaStream_t stream = 0 +) +``` + + +*Added in v3.2.0. First appears in CUDA Toolkit 13.2.* + + +Does not support binary reduction operators that are non-commutative. +Provides "run-to-run" determinism for pseudo-associative reduction (e.g., addition of floating point types) on the same GPU device. However, results for pseudo-associative reduction may be inconsistent from one device to a another device of a different compute-capability because CUB can employ different tile-sizing for different architectures. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (may be a simple pointer type) + + + +**[inferred]** Binary reduction functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Data element type that is convertible to the `value` type of `InputIteratorT` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregates + + + +The number of segments that comprise the segmented reduction data + + + +The fixed segment size of each segment + + + +Binary reduction functor + + + +Initial value of the reduction for each segment + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +### Sum inline static + + + + +Computes a device-wide segmented sum using the addition (`+`) operator. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedReduce::Sum( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Uses `0` as the initial value of the reduction for each segment. +When input a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +Does not support `+` operators that are non-commutative. +Let `s` be in `[0, num_segments)`. The range `[d_out + d_begin_offsets[s], d_out + d_end_offsets[s])` shall not overlap `[d_in + d_begin_offsets[s], d_in + d_end_offsets[s])`, `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)`. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +The number of segments that comprise the segmented reduction data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments`, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_in`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_in``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the *i*\ :sup:`th` is considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +Computes a device-wide segmented sum using the addition (`+`) operator. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedReduce::Sum( + InputIteratorT d_in, + OutputIteratorT d_out, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + EnvT env = {} +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Uses `0` as the initial value of the reduction for each segment. +When input a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +Does not support `+` operators that are non-commutative. +Let `s` be in `[0, num_segments)`. The range `[d_out + d_begin_offsets[s], d_out + d_end_offsets[s])` shall not overlap `[d_in + d_begin_offsets[s], d_in + d_end_offsets[s])`, `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)`. +Can use a specific stream or cuda memory resource through the `env` parameter +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + + +**[inferred]** Execution environment type. Default is `cuda::std::execution::env<>`. + + +**Parameters** + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +The number of segments that comprise the segmented reduction data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments`, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_in`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_in``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the *i*\ :sup:`th` is considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``. +//! + + + + + +Computes a device-wide segmented sum using the addition (`+`) operator. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedReduce::Sum( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ::cuda::std::int64_t num_segments, + int segment_size, + cudaStream_t stream = 0 +) +``` + + +*Added in v3.2.0. First appears in CUDA Toolkit 13.2.* + + +Uses `0` as the initial value of the reduction for each segment. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +The number of segments that comprise the segmented reduction data + + + +The fixed segment size of each segment + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +### Min inline static + + + + +Computes a device-wide segmented minimum using the less-than (`<`) operator. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedReduce::Min( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Uses `::cuda::std::numeric_limits::max()` as the initial value of the reduction for each segment. +When input a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +Does not support `<` operators that are non-commutative. +Let `s` be in `[0, num_segments)`. The range `[d_out + d_begin_offsets[s], d_out + d_end_offsets[s])` shall not overlap `[d_in + d_begin_offsets[s], d_in + d_end_offsets[s])`, `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)`. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +The number of segments that comprise the segmented reduction data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_in`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_in``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the *i*\ :sup:`th` is considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +Computes a device-wide segmented minimum using the less-than (`<`) operator. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedReduce::Min( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ::cuda::std::int64_t num_segments, + int segment_size, + cudaStream_t stream = 0 +) +``` + + +*Added in v3.2.0. First appears in CUDA Toolkit 13.2.* + + +Uses `::cuda::std::numeric_limits::max()` as the initial value of the reduction for each segment. + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +The number of segments that comprise the segmented reduction data + + + +The fixed segment size of each segment + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +### ArgMin inline static + + + + +Finds the first device-wide minimum in each segment using the less-than (`<`) operator, also returning the in-segment index of that item. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedReduce::ArgMin( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The output value type of `d_out` is `cub::KeyValuePair` (assuming the value type of `d_in` is `T`) +The minimum of the *i*th segment is written to `d_out[i].value` and its offset in that segment is written to `d_out[i].key`. +The `{1, ::cuda::std::numeric_limits::max()}` tuple is produced for zero-length inputs +When input a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +Does not support `<` operators that are non-commutative. +Let `s` be in `[0, num_segments)`. The range `[d_out + d_begin_offsets[s], d_out + d_end_offsets[s])` shall not overlap `[d_in + d_begin_offsets[s], d_in + d_end_offsets[s])`, `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)`. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (of some type `T`) (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (having value type `KeyValuePair`) (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +The number of segments that comprise the segmented reduction data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_in`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_in``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the *i*\ :sup:`th` is considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +Finds the first device-wide minimum in each segment using the less-than (`<`) operator, also returning the in-segment index of that item. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedReduce::ArgMin( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ::cuda::std::int64_t num_segments, + int segment_size, + cudaStream_t stream = 0 +) +``` + + +*Added in v3.2.0. First appears in CUDA Toolkit 13.2.* + + +The output value type of `d_out` is `::cuda::std::pair` (assuming the value type of `d_in` is `T`) +The minimum of the *i*th segment is written to `d_out[i].second` and its offset in that segment is written to `d_out[i].first`. +The `{1, ::cuda::std::numeric_limits::max()}` tuple is produced for zero-length inputs + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (of some type `T`) (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (having value type `cuda::std::pair`) (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +The number of segments that comprise the segmented reduction data + + + +The fixed segment size of each segment + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +### Max inline static + + + + +Computes a device-wide segmented maximum using the greater-than (`>`) operator. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedReduce::Max( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Uses `::cuda::std::numeric_limits::lowest()` as the initial value of the reduction. +When input a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +Does not support `>` operators that are non-commutative. +Let `s` be in `[0, num_segments)`. The range `[d_out + d_begin_offsets[s], d_out + d_end_offsets[s])` shall not overlap `[d_in + d_begin_offsets[s], d_in + d_end_offsets[s])`, `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)`. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +The number of segments that comprise the segmented reduction data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_in`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_in``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the *i*\ :sup:`th` is considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +Computes a device-wide segmented maximum using the greater-than (`>`) operator. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedReduce::Max( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ::cuda::std::int64_t num_segments, + int segment_size, + cudaStream_t stream = 0 +) +``` + + +*Added in v3.2.0. First appears in CUDA Toolkit 13.2.* + + +Uses `::cuda::std::numeric_limits::lowest()` as the initial value of the reduction. + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +The number of segments that comprise the segmented reduction data + + + +The fixed segment size of each segment + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +### ArgMax inline static + + + + +Finds the first device-wide maximum in each segment using the greater-than (`>`) operator, also returning the in-segment index of that item + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedReduce::ArgMax( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The output value type of `d_out` is `cub::KeyValuePair` (assuming the value type of `d_in` is `T`) +The maximum of the *i*th segment is written to `d_out[i].value` and its offset in that segment is written to `d_out[i].key`. +The `{1, ::cuda::std::numeric_limits::lowest()}` tuple is produced for zero-length inputs +When input a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +Does not support `>` operators that are non-commutative. +Let `s` be in `[0, num_segments)`. The range `[d_out + d_begin_offsets[s], d_out + d_end_offsets[s])` shall not overlap `[d_in + d_begin_offsets[s], d_in + d_end_offsets[s])`, `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)`. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (of some type `T`) (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (having value type `KeyValuePair`) (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +The number of segments that comprise the segmented reduction data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length `num_segments`, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_in`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_in``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the *i*\ :sup:`th` is considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +Finds the first device-wide maximum in each segment using the greater-than (`>`) operator, also returning the in-segment index of that item + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedReduce::ArgMax( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ::cuda::std::int64_t num_segments, + int segment_size, + cudaStream_t stream = 0 +) +``` + + +*Added in v3.2.0. First appears in CUDA Toolkit 13.2.* + + +The output value type of `d_out` is `::cuda::std::pair` (assuming the value type of `d_in` is `T`) +The maximum of the *i*th segment is written to `d_out[i].second` and its offset in that segment is written to `d_out[i].first`. +The `{1, ::cuda::std::numeric_limits::lowest()}` tuple is produced for zero-length inputs + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (of some type `T`) (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the reduced aggregate (having value type `cuda::std::pair`) (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +The number of segments that comprise the segmented reduction data + + + +The fixed segment size of each segment + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + diff --git a/fern/cudapages/cub/cub/cub/DeviceSegmentedScan.mdx b/fern/cudapages/cub/cub/cub/DeviceSegmentedScan.mdx new file mode 100644 index 0000000..f99b279 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DeviceSegmentedScan.mdx @@ -0,0 +1,1186 @@ +--- +title: cub::DeviceSegmentedScan +description: "" +--- + +DeviceSegmentedScan provides device-wide, parallel operations for computing a batched prefix scan across multiple sequences of data items residing within device-accessible memory. + +--- + +## Static methods + +### ExclusiveSegmentedSum inline static + + + + +Computes a device-wide segmented exclusive prefix sum. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedScan::ExclusiveSegmentedSum( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + BeginOffsetIteratorInputT d_in_begin_offsets, + EndOffsetIteratorInputT d_in_end_offsets, + ::cuda::std::int64_t num_segments, + cudaStream_t stream = 0 +) +``` + + + +Results are not deterministic for computation of prefix sum on floating-point types and may vary from run to run. +When `d_in` and `d_out` are equal, the scan is performed in-place. The input and output sequences shall not overlap in any other way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading segmented scan inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing segmented scan outputs (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets in the input data sequence (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets in the input data sequence (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the input sequence of data items + + + +Random-access iterator to the output sequence of data items + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_in_begin_offsets[i]`` is the first +//! element of the \ *i*\ :sup:`th` data segment in ``d_in`` and in ``d_out``. +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_in_end_offsets[i] - 1`` is the last element of +//! the \ *i*\ :sup:`th` data segment in ``d_in``. +//! If ``d_in_end_offsets[i] - 1 <= d_in_begin_offsets[i]``, the \ *i*\ :sup:`th` +//! is considered empty. +//! + + + +The number of segments that comprise the segmented prefix scan data. + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the exclusive segmented prefix sum of an `int` device vector. + +```cpp showLineNumbers={false} +#include +// or, equivalently +// #include + +// Declare, allocate, and initialize device-accessible pointers for +// input and output +int num_segments; // e.g., 3 +int *d_in; // e.g., [8, 6, 7, 5, 3, -2, 9] +int *d_offsets; // e.g., [0, 2, 5, 7] +int *d_out; // e.g., [ , , , , , , ] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceScan::ExclusiveSegmentedSum( + d_temp_storage, temp_storage_bytes, + d_in, d_out, d_offsets, d_offsets + 1, num_segments); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run exclusive prefix sum +cub::DeviceScan::ExclusiveSegmentedSum( + d_temp_storage, temp_storage_bytes, + d_in, d_out, d_offsets, d_offsets + 1, num_segments); + +// d_out <-- [0, 8, 0, 7, 12, 0, -2] +``` + + + + +Computes a device-wide segmented exclusive prefix sum. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedScan::ExclusiveSegmentedSum( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + BeginOffsetIteratorInputT d_in_begin_offsets, + EndOffsetIteratorInputT d_in_end_offsets, + BeginOffsetIteratorOutputT d_out_begin_offsets, + ::cuda::std::int64_t num_segments, + cudaStream_t stream = 0 +) +``` + + + +Results are not deterministic for computation of prefix sum on floating-point types and may vary from run to run. +When `d_in` and `d_out` are equal, the scan is performed in-place. The input and output sequences shall not overlap in any other way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading segmented scan inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing segmented scan outputs (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets in the input data sequence (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets in the input data sequence (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets in the output sequence (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the input sequence of data items + + + +Random-access iterator to the output sequence of data items + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_in_begin_offsets[i]`` is the first +//! element of the \ *i*\ :sup:`th` data segment in ``d_in`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_in_end_offsets[i] - 1`` is the last element of +//! the \ *i*\ :sup:`th` data segment in ``d_in``. +//! If ``d_in_end_offsets[i] - 1 <= d_in_begin_offsets[i]``, the \ *i*\ :sup:`th` +//! is considered empty. +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_out_begin_offsets[i]`` is the first +//! element of the \ *i*\ :sup:`th` data segment in ``d_out`` +//! + + + +The number of segments that comprise the segmented prefix scan data. + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +### ExclusiveSegmentedScan inline static + + + + +Computes a device-wide segmented exclusive prefix scan using the specified binary associative `scan_op` functor. The `init_value` value is applied as the initial value, and is assigned to the first element in each output segment. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedScan::ExclusiveSegmentedScan( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + BeginOffsetIteratorInputT d_in_begin_offsets, + EndOffsetIteratorInputT d_in_end_offsets, + ::cuda::std::int64_t num_segments, + ScanOpT scan_op, + InitValueT init_value, + cudaStream_t stream = 0 +) +``` + + + +Supports non-commutative scan operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. +When `d_in` and `d_out` are equal, the scan is performed in-place. The input and output sequences shall not overlap in any other way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading segmented scan inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing segmented scan outputs (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets in the input data sequence (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets in the input data sequence (may be a simple pointer type) + + + +**[inferred]** Binary associative scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Type of the `init_value` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the input sequence of data items + + + +Random-access iterator to the output sequence of data items + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_in_begin_offsets[i]`` is the first +//! element of the \ *i*\ :sup:`th` data segment in ``d_in`` and in ``d_out`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_in_end_offsets[i] - 1`` is the last element of +//! the \ *i*\ :sup:`th` data segment in ``d_in``. +//! If ``d_in_end_offsets[i] - 1 <= d_in_begin_offsets[i]``, the \ *i*\ :sup:`th` +//! is considered empty. +//! + + + +The number of segments that comprise the segmented prefix scan data. + + + +Binary associative scan functor + + + +Initial value to seed the exclusive scan for each segment in the output sequence + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +Computes a device-wide segmented exclusive prefix scan using the specified binary associative `scan_op` functor. The `init_value` value is applied as the initial value, and is assigned to the first element in each output segment. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedScan::ExclusiveSegmentedScan( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + BeginOffsetIteratorInputT d_in_begin_offsets, + EndOffsetIteratorInputT d_in_end_offsets, + BeginOffsetIteratorOutputT d_out_begin_offsets, + ::cuda::std::int64_t num_segments, + ScanOpT scan_op, + InitValueT init_value, + cudaStream_t stream = 0 +) +``` + + + +Supports non-commutative scan operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. +When `d_in` and `d_out` are equal, the scan is performed in-place. The input and output sequences shall not overlap in any other way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading segmented scan inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing segmented scan outputs (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets in the input data sequence (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets in the input data sequence (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets in the output sequence (may be a simple pointer type) + + + +**[inferred]** Binary associative scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Type of the `init_value` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the input sequence of data items + + + +Random-access iterator to the output sequence of data items + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_in_begin_offsets[i]`` is the first +//! element of the \ *i*\ :sup:`th` data segment in ``d_in`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_in_end_offsets[i] - 1`` is the last element of +//! the \ *i*\ :sup:`th` data segment in ``d_in``. +//! If ``d_in_end_offsets[i] - 1 <= d_in_begin_offsets[i]``, the \ *i*\ :sup:`th` +//! is considered empty. +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_out_begin_offsets[i]`` is the first +//! element of the \ *i*\ :sup:`th` data segment in ``d_out`` +//! + + + +The number of segments that comprise the segmented prefix scan data. + + + +Binary associative scan functor + + + +Initial value to seed the exclusive scan for each segment in the output sequence + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +### InclusiveSegmentedSum inline static + + + + +Computes a device-wide segmented inclusive prefix sum. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedScan::InclusiveSegmentedSum( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + BeginOffsetIteratorInputT d_in_begin_offsets, + EndOffsetIteratorInputT d_in_end_offsets, + ::cuda::std::int64_t num_segments, + cudaStream_t stream = 0 +) +``` + + + +Results are not deterministic for computation of prefix sum on floating-point types and may vary from run to run. +When `d_in` and `d_out` are equal, the scan is performed in-place. The input and output sequences shall not overlap in any other way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading segmented scan inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing segmented scan outputs (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets in the input data sequence (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets in the input data sequence (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the input sequence of data items + + + +Random-access iterator to the output sequence of data items + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_in_begin_offsets[i]`` is the first +//! element of the \ *i*\ :sup:`th` data segment in ``d_in`` and in ``d_out`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_in_end_offsets[i] - 1`` is the last element of +//! the \ *i*\ :sup:`th` data segment in ``d_in``. +//! If ``d_in_end_offsets[i] - 1 <= d_in_begin_offsets[i]``, the \ *i*\ :sup:`th` +//! is considered empty. +//! + + + +The number of segments that comprise the segmented prefix scan data. + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +Computes a device-wide segmented inclusive prefix sum. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedScan::InclusiveSegmentedSum( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + BeginOffsetIteratorInputT d_in_begin_offsets, + EndOffsetIteratorInputT d_in_end_offsets, + BeginOffsetIteratorOutputT d_out_begin_offsets, + ::cuda::std::int64_t num_segments, + cudaStream_t stream = 0 +) +``` + + + +Results are not deterministic for computation of prefix sum on floating-point types and may vary from run to run. +When `d_in` and `d_out` are equal, the scan is performed in-place. The input and output sequences shall not overlap in any other way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading segmented scan inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing segmented scan outputs (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets in the input data sequence (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets in the input data sequence (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets in the output sequence (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the input sequence of data items + + + +Random-access iterator to the output sequence of data items + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_in_begin_offsets[i]`` is the first +//! element of the \ *i*\ :sup:`th` data segment in ``d_in`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_in_end_offsets[i] - 1`` is the last element of +//! the \ *i*\ :sup:`th` data segment in ``d_in``. +//! If ``d_in_end_offsets[i] - 1 <= d_in_begin_offsets[i]``, the \ *i*\ :sup:`th` +//! is considered empty. +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_out_begin_offsets[i]`` is the first +//! element of the \ *i*\ :sup:`th` data segment in ``d_out`` +//! + + + +The number of segments that comprise the segmented prefix scan data. + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +### InclusiveSegmentedScan inline static + + + + +Computes a device-wide segmented inclusive prefix scan using the specified binary associative `scan_op` functor. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedScan::InclusiveSegmentedScan( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + BeginOffsetIteratorInputT d_in_begin_offsets, + EndOffsetIteratorInputT d_in_end_offsets, + ::cuda::std::int64_t num_segments, + ScanOpT scan_op, + cudaStream_t stream = 0 +) +``` + + + +Supports non-commutative scan operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. +When `d_in` and `d_out` are equal, the scan is performed in-place. The input and output sequences shall not overlap in any other way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading segmented scan inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing segmented scan outputs (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets in the input data sequence (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets in the input data sequence (may be a simple pointer type) + + + +**[inferred]** Binary associative scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the input sequence of data items + + + +Random-access iterator to the output sequence of data items + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_in_begin_offsets[i]`` is the first +//! element of the \ *i*\ :sup:`th` data segment in ``d_in`` and in ``d_out`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_in_end_offsets[i] - 1`` is the last element of +//! the \ *i*\ :sup:`th` data segment in ``d_in``. +//! If ``d_in_end_offsets[i] - 1 <= d_in_begin_offsets[i]``, the \ *i*\ :sup:`th` +//! is considered empty. +//! + + + +The number of segments that comprise the segmented prefix scan data. + + + +Binary associative scan functor + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +Computes a device-wide segmented inclusive prefix scan using the specified binary associative `scan_op` functor. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedScan::InclusiveSegmentedScan( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + BeginOffsetIteratorInputT d_in_begin_offsets, + EndOffsetIteratorInputT d_in_end_offsets, + BeginOffsetIteratorOutputT d_out_begin_offsets, + ::cuda::std::int64_t num_segments, + ScanOpT scan_op, + cudaStream_t stream = 0 +) +``` + + + +Supports non-commutative scan operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. +When `d_in` and `d_out` are equal, the scan is performed in-place. The input and output sequences shall not overlap in any other way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading segmented scan inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing segmented scan outputs (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets in the input data sequence (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets in the input data sequence (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets in the output sequence (may be a simple pointer type) + + + +**[inferred]** Binary associative scan functor type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the input sequence of data items + + + +Random-access iterator to the output sequence of data items + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_in_begin_offsets[i]`` is the first +//! element of the \ *i*\ :sup:`th` data segment in ``d_in`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_in_end_offsets[i] - 1`` is the last element of +//! the \ *i*\ :sup:`th` data segment in ``d_in``. +//! If ``d_in_end_offsets[i] - 1 <= d_in_begin_offsets[i]``, the \ *i*\ :sup:`th` +//! is considered empty. +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_out_begin_offsets[i]`` is the first +//! element of the \ *i*\ :sup:`th` data segment in ``d_out`` +//! + + + +The number of segments that comprise the segmented prefix scan data. + + + +Binary associative scan functor + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +### InclusiveSegmentedScanInit inline static + + + + +Computes a device-wide segmented inclusive prefix scan using the specified binary associative `scan_op` functor. The result of applying the `scan_op` binary operator to `init_value` value and the first value in each input segment is assigned to the first value of the corresponding output segment. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedScan::InclusiveSegmentedScanInit( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + BeginOffsetIteratorInputT d_in_begin_offsets, + EndOffsetIteratorInputT d_in_end_offsets, + ::cuda::std::int64_t num_segments, + ScanOpT scan_op, + InitValueT init_value, + cudaStream_t stream = 0 +) +``` + + + +Supports non-commutative scan operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. +When `d_in` and `d_out` are equal, the scan is performed in-place. The input and output sequences shall not overlap in any other way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading segmented scan inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing segmented scan outputs (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets in the input data sequence (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets in the input data sequence (may be a simple pointer type) + + + +**[inferred]** Binary associative scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Type of the `init_value` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the input sequence of data items + + + +Random-access iterator to the output sequence of data items + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_in_begin_offsets[i]`` is the first +//! element of the \ *i*\ :sup:`th` data segment in ``d_in`` and in ``d_out`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_in_end_offsets[i] - 1`` is the last element of +//! the \ *i*\ :sup:`th` data segment in ``d_in``. +//! If ``d_in_end_offsets[i] - 1 <= d_in_begin_offsets[i]``, the \ *i*\ :sup:`th` +//! is considered empty. +//! + + + +The number of segments that comprise the segmented prefix scan data. + + + +Binary associative scan functor + + + +Initial value to seed the exclusive scan for each segment in the output sequence + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +Computes a device-wide segmented inclusive prefix scan using the specified binary associative `scan_op` functor. The result of applying the `scan_op` binary operator to `init_value` value and the first value in each input segment is assigned to the first value of the corresponding output segment. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedScan::InclusiveSegmentedScanInit( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + BeginOffsetIteratorInputT d_in_begin_offsets, + EndOffsetIteratorInputT d_in_end_offsets, + BeginOffsetIteratorOutputT d_out_begin_offsets, + ::cuda::std::int64_t num_segments, + ScanOpT scan_op, + InitValueT init_value, + cudaStream_t stream = 0 +) +``` + + + +Supports non-commutative scan operators. +Results are not deterministic for pseudo-associative operators (e.g., addition of floating-point types). Results for pseudo-associative operators may vary from run to run. +When `d_in` and `d_out` are equal, the scan is performed in-place. The input and output sequences shall not overlap in any other way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading segmented scan inputs (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing segmented scan outputs (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets in the input data sequence (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets in the input data sequence (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets in the output sequence (may be a simple pointer type) + + + +**[inferred]** Binary associative scan functor type having member `T operator()(const T &a, const T &b)` + + + +**[inferred]** Type of the `init_value` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the input sequence of data items + + + +Random-access iterator to the output sequence of data items + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_in_begin_offsets[i]`` is the first +//! element of the \ *i*\ :sup:`th` data segment in ``d_in`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_in_end_offsets[i] - 1`` is the last element of +//! the \ *i*\ :sup:`th` data segment in ``d_in``. +//! If ``d_in_end_offsets[i] - 1 <= d_in_begin_offsets[i]``, the \ *i*\ :sup:`th` +//! is considered empty. +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_out_begin_offsets[i]`` is the first +//! element of the \ *i*\ :sup:`th` data segment in ``d_out`` +//! + + + +The number of segments that comprise the segmented prefix scan data. + + + +Binary associative scan functor + + + +Initial value to seed the exclusive scan for each segment in the output sequence + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + diff --git a/fern/cudapages/cub/cub/cub/DeviceSegmentedSort.mdx b/fern/cudapages/cub/cub/cub/DeviceSegmentedSort.mdx new file mode 100644 index 0000000..390dc4c --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DeviceSegmentedSort.mdx @@ -0,0 +1,2522 @@ +--- +title: cub::DeviceSegmentedSort +description: "" +--- + +DeviceSegmentedSort provides device-wide, parallel operations for computing a batched sort across multiple, non-overlapping sequences of data items residing within device-accessible memory. + +## Example + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_keys_out; // e.g., [-, -, -, -, -, -, -] +int *d_values_in; // e.g., [0, 1, 2, 3, 4, 5, 6] +int *d_values_out; // e.g., [-, -, -, -, -, -, -] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedSort::SortPairs( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedSort::SortPairs( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys_out <-- [6, 7, 8, 0, 3, 5, 9] +// d_values_out <-- [1, 2, 0, 5, 4, 3, 6] +``` + +--- + +## Keys-only + +### SortKeysDescendingNoNVTX inline static + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::SortKeysDescendingNoNVTX( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::SortKeysDescendingNoNVTX( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + + + + +### SortKeysNoNVTX inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::SortKeysNoNVTX( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +### SortPairsNoNVTX inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::SortPairsNoNVTX( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +### SortKeys inline static + + + + +Sorts segments of keys into ascending order. Approximately `num_items + 2 * num_segments` auxiliary storage required. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::SortKeys( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The contents of the input data are not altered by the sorting operation. +When the input is a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets+1`). +SortKeys is not guaranteed to be stable. That is, suppose that `i` and `j` are equivalent: neither one is less than the other. It is not guaranteed that the relative order of these two elements will be preserved by sort. +The range `[d_keys_out, d_keys_out + num_items)` shall not overlap `[d_keys_in, d_keys_in + num_items)`, `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys_in[i]`, `d_keys_out[i]` will not be accessed nor modified. + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to `temp_storage_bytes` and no work is done + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Device-accessible pointer to the input data of key data to sort + + + +Device-accessible pointer to the sorted output sequence of key data + + + +The total number of items to sort (across all segments) + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the i-th segment is considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `int` keys. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible +// pointers for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_keys_out; // e.g., [-, -, -, -, -, -, -] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedSort::SortKeys( + d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedSort::SortKeys( + d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys_out <-- [6, 7, 8, 0, 3, 5, 9] +``` + + + + +Sorts segments of keys into ascending order. Approximately `2 * num_segments` auxiliary storage required. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::SortKeys( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The sorting operation is given a pair of key buffers managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). +The contents of both buffers may be altered by the sorting operation. +Upon completion, the sorting operation will update the "current" indicator within the DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits and the targeted device architecture). +When the input is a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets +1`). +SortKeys is not guaranteed to be stable. That is, suppose that `i` and `j` are equivalent: neither one is less than the other. It is not guaranteed that the relative order of these two elements will be preserved by sort. +Let `cur = d_keys.Current()` and `alt = d_keys.Alternate()`. The range `[cur, cur + num_items)` shall not overlap `[alt, alt + num_items)`. Both ranges shall not overlap `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys.Current()[i]`, `d_keys[i].Alternate()[i]` will not be accessed nor modified. + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to `temp_storage_bytes` and no work is done + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +The total number of items to sort (across all segments) + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the ``i``-th segment is considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `i` nt keys. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible +// pointers for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_key_alt_buf; // e.g., [-, -, -, -, -, -, -] +... + +// Create a DoubleBuffer to wrap the pair of device pointers +cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedSort::SortKeys( + d_temp_storage, temp_storage_bytes, d_keys, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedSort::SortKeys( + d_temp_storage, temp_storage_bytes, d_keys, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys.Current() <-- [6, 7, 8, 0, 3, 5, 9] +``` + + + + +### SortKeysDescending inline static + + + + +Sorts segments of keys into descending order. Approximately `num_items + 2 * num_segments` auxiliary storage required. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::SortKeysDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The contents of the input data are not altered by the sorting operation. +When the input is a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +SortKeysDescending is not guaranteed to be stable. That is, suppose that `i` and `j` are equivalent: neither one is less than the other. It is not guaranteed that the relative order of these two elements will be preserved by sort. +The range `[d_keys_out, d_keys_out + num_items)` shall not overlap `[d_keys_in, d_keys_in + num_items)`, `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys_in[i]`, `d_keys_out[i]` will not be accessed nor modified. + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to `temp_storage_bytes` and no work is done + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Device-accessible pointer to the input data of key data to sort + + + +Device-accessible pointer to the sorted output sequence of key data + + + +The total number of items to sort (across all segments) + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the ``i``-th segment is considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `i` nt keys. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_keys_out; // e.g., [-, -, -, -, -, -, -] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedSort::SortKeysDescending( + d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedSort::SortKeysDescending( + d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys_out <-- [8, 7, 6, 9, 5, 3, 0] +``` + + + + +Sorts segments of keys into descending order. Approximately `2 * num_segments` auxiliary storage required. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::SortKeysDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The sorting operation is given a pair of key buffers managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). +The contents of both buffers may be altered by the sorting operation. +Upon completion, the sorting operation will update the "current" indicator within the DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits and the targeted device architecture). +When the input is a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +SortKeysDescending is not guaranteed to be stable. That is, suppose that `i` and `j` are equivalent: neither one is less than the other. It is not guaranteed that the relative order of these two elements will be preserved by sort. +Let `cur = d_keys.Current()` and `alt = d_keys.Alternate()`. The range `[cur, cur + num_items)` shall not overlap `[alt, alt + num_items)`. Both ranges shall not overlap `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys.Current()[i]`, `d_keys[i].Alternate()[i]` will not be accessed nor modified. + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +The total number of items to sort (across all segments) + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i] - 1<= d_begin_offsets[i]``, the ``i``-th segment is considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `i` nt keys. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers for +// sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_key_alt_buf; // e.g., [-, -, -, -, -, -, -] +... + +// Create a DoubleBuffer to wrap the pair of device pointers +cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedSort::SortKeysDescending( + d_temp_storage, temp_storage_bytes, d_keys, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedSort::SortKeysDescending( + d_temp_storage, temp_storage_bytes, d_keys, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys.Current() <-- [8, 7, 6, 9, 5, 3, 0] +``` + + + + +### StableSortKeys inline static + + + + +Sorts segments of keys into ascending order. Approximately `num_items + 2 * num_segments` auxiliary storage required. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::StableSortKeys( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The contents of the input data are not altered by the sorting operation. +When the input is a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +StableSortKeys is stable: it preserves the relative ordering of equivalent elements. That is, if `x` and `y` are elements such that `x` precedes `y`, and if the two elements are equivalent (neither `x < y` nor `y < x`) then a postcondition of stable sort is that `x` still precedes `y`. +The range `[d_keys_out, d_keys_out + num_items)` shall not overlap `[d_keys_in, d_keys_in + num_items)`, `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys_in[i]`, `d_keys_out[i]` will not be accessed nor modified. + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to `temp_storage_bytes` and no work is done + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Device-accessible pointer to the input data of key data to sort + + + +Device-accessible pointer to the sorted output sequence of key data + + + +The total number of items to sort (across all segments) + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the ``i``-th segment is considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `i` nt keys. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_keys_out; // e.g., [-, -, -, -, -, -, -] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedSort::StableSortKeys( + d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedSort::StableSortKeys( + d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys_out <-- [6, 7, 8, 0, 3, 5, 9] +``` + + + + +Sorts segments of keys into ascending order. Approximately `2 * num_segments` auxiliary storage required. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::StableSortKeys( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The sorting operation is given a pair of key buffers managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). +The contents of both buffers may be altered by the sorting operation. +Upon completion, the sorting operation will update the "current" indicator within the DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits and the targeted device architecture). +When the input is a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +StableSortKeys is stable: it preserves the relative ordering of equivalent elements. That is, if `x` and `y` are elements such that `x` precedes `y`, and if the two elements are equivalent (neither `x < y` nor `y < x`) then a postcondition of stable sort is that `x` still precedes `y`. +Let `cur = d_keys.Current()` and `alt = d_keys.Alternate()`. The range `[cur, cur + num_items)` shall not overlap `[alt, alt + num_items)`. Both ranges shall not overlap `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys.Current()[i]`, `d_keys[i].Alternate()[i]` will not be accessed nor modified. + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to `temp_storage_bytes` and no work is done + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +The total number of items to sort (across all segments) + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the ``i``-th segment is +//! considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `i` nt keys. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_key_alt_buf; // e.g., [-, -, -, -, -, -, -] +... + +// Create a DoubleBuffer to wrap the pair of device pointers +cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedSort::StableSortKeys( + d_temp_storage, temp_storage_bytes, d_keys, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedSort::StableSortKeys( + d_temp_storage, temp_storage_bytes, d_keys, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys.Current() <-- [6, 7, 8, 0, 3, 5, 9] +``` + + + + +### StableSortKeysDescending inline static + + + + +Sorts segments of keys into descending order. Approximately `num_items + 2 * num_segments` auxiliary storage required. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::StableSortKeysDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The contents of the input data are not altered by the sorting operation. +When the input is a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +StableSortKeysDescending is stable: it preserves the relative ordering of equivalent elements. That is, if `x` and `y` are elements such that `x` precedes `y`, and if the two elements are equivalent (neither `x < y` nor `y < x`) then a postcondition of stable sort is that `x` still precedes `y`. +The range `[d_keys_out, d_keys_out + num_items)` shall not overlap `[d_keys_in, d_keys_in + num_items)`, `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys_in[i]`, `d_keys_out[i]` will not be accessed nor modified. + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Device-accessible pointer to the input data of key data to sort + + + +Device-accessible pointer to the sorted output sequence of key data + + + +The total number of items to sort (across all segments) + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and +//! ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the ``i``-th segment is +//! considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `i` nt keys. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_keys_out; // e.g., [-, -, -, -, -, -, -] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedSort::StableSortKeysDescending( + d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedSort::StableSortKeysDescending( + d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys_out <-- [8, 7, 6, 9, 5, 3, 0] +``` + + + + +Sorts segments of keys into descending order. Approximately `2 * num_segments` auxiliary storage required. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::StableSortKeysDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The sorting operation is given a pair of key buffers managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). +The contents of both buffers may be altered by the sorting operation. +Upon completion, the sorting operation will update the "current" indicator within the DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits and the targeted device architecture). +When the input is a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +StableSortKeysDescending is stable: it preserves the relative ordering of equivalent elements. That is, if `x` and `y` are elements such that `x` precedes `y`, and if the two elements are equivalent (neither `x < y` nor `y < x`) then a postcondition of stable sort is that `x` still precedes `y`. +Let `cur = d_keys.Current()` and `alt = d_keys.Alternate()`. The range `[cur, cur + num_items)` shall not overlap `[alt, alt + num_items)`. Both ranges shall not overlap `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +Segments are not required to be contiguous. For all index values ```i` outside the specified segments `d_keys.Current()[i]`, `d_keys[i].Alternate()[i]` will not be accessed nor modified. + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +The total number of items to sort (across all segments) + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and +//! ``d_values_*``. If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the +//! ``i``-th segment is considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `i` nt keys. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_key_alt_buf; // e.g., [-, -, -, -, -, -, -] +... + +// Create a DoubleBuffer to wrap the pair of device pointers +cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedSort::StableSortKeysDescending( + d_temp_storage, temp_storage_bytes, d_keys, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedSort::StableSortKeysDescending( + d_temp_storage, temp_storage_bytes, d_keys, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys.Current() <-- [8, 7, 6, 9, 5, 3, 0] +``` + + + + +--- + +## Key-value pairs + +### SortPairsDescendingNoNVTX inline static + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::SortPairsDescendingNoNVTX( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::SortPairsDescendingNoNVTX( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + + + + +### SortPairsNoNVTX inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::SortPairsNoNVTX( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +### SortPairs inline static + + + + +Sorts segments of key-value pairs into ascending order. Approximately `2 * num_items + 2 * num_segments` auxiliary storage required. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::SortPairs( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The contents of the input data are not altered by the sorting operation. +When the input is a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +SortPairs is not guaranteed to be stable. That is, suppose that `i` and `j` are equivalent: neither one is less than the other. It is not guaranteed that the relative order of these two elements will be preserved by sort. +Let `in` be one of `{d_keys_in, d_values_in}` and `out` be any of `{d_keys_out, d_values_out}`. The range `[out, out + num_items)` shall not overlap `[in, in + num_items)`, `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys_in[i]`, `d_values_in[i]`, `d_keys_out[i]`, `d_values_out[i]` will not be accessed nor modified. + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Value type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Device-accessible pointer to the input data of key data to sort + + + +Device-accessible pointer to the sorted output sequence of key data + + + +Device-accessible pointer to the corresponding input sequence of associated value items + + + +Device-accessible pointer to the correspondingly-reordered output sequence of associated value items + + + +The total number of items to sort (across all segments) + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i]-1 <= d_begin_offsets[i]``, the ``i``-th segment is +//! considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `i` nt keys with associated vector of `i` nt values. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_keys_out; // e.g., [-, -, -, -, -, -, -] +int *d_values_in; // e.g., [0, 1, 2, 3, 4, 5, 6] +int *d_values_out; // e.g., [-, -, -, -, -, -, -] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedSort::SortPairs( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedSort::SortPairs( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys_out <-- [6, 7, 8, 0, 3, 5, 9] +// d_values_out <-- [1, 2, 0, 5, 4, 3, 6] +``` + + + + +Sorts segments of key-value pairs into ascending order. Approximately `2 * num_segments` auxiliary storage required. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::SortPairs( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The sorting operation is given a pair of key buffers and a corresponding pair of associated value buffers. Each pair is managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). +The contents of both buffers within each pair may be altered by the sorting operation. +Upon completion, the sorting operation will update the "current" indicator within each DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). +When the input is a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +SortPairs is not guaranteed to be stable. That is, suppose that `i` and `j` are equivalent: neither one is less than the other. It is not guaranteed that the relative order of these two elements will be preserved by sort. +Let `cur` be one of `{d_keys.Current(), d_values.Current()}` and `alt` be any of `{d_keys.Alternate(), d_values.Alternate()}`. The range `[cur, cur + num_items)` shall not overlap `[alt, alt + num_items)`. Both ranges shall not overlap `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys.Current()[i]`, `d_values.Current()[i]`, `d_keys.Alternate()[i]`, `d_values.Alternate()[i]` will not be accessed nor modified. + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Value type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Double-buffer of values whose "current" device-accessible buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + + + +The total number of items to sort (across all segments) + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the i-th segment is +//! considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `i` nt keys with associated vector of `i` nt values. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_key_alt_buf; // e.g., [-, -, -, -, -, -, -] +int *d_value_buf; // e.g., [0, 1, 2, 3, 4, 5, 6] +int *d_value_alt_buf; // e.g., [-, -, -, -, -, -, -] +... + +// Create a set of DoubleBuffers to wrap pairs of device pointers +cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); +cub::DoubleBuffer d_values(d_value_buf, d_value_alt_buf); + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedSort::SortPairs( + d_temp_storage, temp_storage_bytes, d_keys, d_values, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedSort::SortPairs( + d_temp_storage, temp_storage_bytes, d_keys, d_values, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys.Current() <-- [6, 7, 8, 0, 3, 5, 9] +// d_values.Current() <-- [5, 4, 3, 1, 2, 0, 6] +``` + + + + +### SortPairsDescending inline static + + + + +Sorts segments of key-value pairs into descending order. Approximately `2 * num_items + 2 * num_segments` auxiliary storage required. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::SortPairsDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The contents of the input data are not altered by the sorting operation. +When the input is a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +SortPairs is not guaranteed to be stable. That is, suppose that `i` and `j` are equivalent: neither one is less than the other. It is not guaranteed that the relative order of these two elements will be preserved by sort. +Let `in` be one of `{d_keys_in, d_values_in}` and `out` be any of `{d_keys_out, d_values_out}`. The range `[out, out + num_items)` shall not overlap `[in, in + num_items)`, `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys_in[i]`, `d_values_in[i]`, `d_keys_out[i]`, `d_values_out[i]` will not be accessed nor modified. + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Value type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Device-accessible pointer to the input data of key data to sort + + + +Device-accessible pointer to the sorted output sequence of key data + + + +Device-accessible pointer to the corresponding input sequence of associated value items + + + +Device-accessible pointer to the correspondingly-reordered output sequence of associated value items + + + +The total number of items to sort (across all segments) + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the i-th segment is +//! considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `i` nt keys with associated vector of `i` nt values. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers for +// sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_keys_out; // e.g., [-, -, -, -, -, -, -] +int *d_values_in; // e.g., [0, 1, 2, 3, 4, 5, 6] +int *d_values_out; // e.g., [-, -, -, -, -, -, -] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedSort::SortPairsDescending( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedSort::SortPairsDescending( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys_out <-- [8, 7, 6, 9, 5, 3, 0] +// d_values_out <-- [0, 2, 1, 6, 3, 4, 5] +``` + + + + +Sorts segments of key-value pairs into descending order. Approximately `2 * num_segments` auxiliary storage required. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::SortPairsDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The sorting operation is given a pair of key buffers and a corresponding pair of associated value buffers. Each pair is managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). +The contents of both buffers within each pair may be altered by the sorting operation. +Upon completion, the sorting operation will update the "current" indicator within each DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). +When the input is a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +SortPairsDescending is not guaranteed to be stable. That is, suppose that `i` and `j` are equivalent: neither one is less than the other. It is not guaranteed that the relative order of these two elements will be preserved by sort. +Let `cur` be one of `{d_keys.Current(), d_values.Current()}` and `alt` be any of `{d_keys.Alternate(), d_values.Alternate()}`. The range `[cur, cur + num_items)` shall not overlap `[alt, alt + num_items)`. Both ranges shall not overlap `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys.Current()[i]`, `d_values.Current()[i]`, `d_keys.Alternate()[i]`, `d_values.Alternate()[i]` will not be accessed nor modified. + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Value type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to `temp_storage_bytes` and no work is done + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Double-buffer of values whose "current" device-accessible buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + + + +The total number of items to sort (across all segments) + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the ``i``-th segment is +//! considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `i` nt keys with associated vector of `i` nt values. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers for +// sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_key_alt_buf; // e.g., [-, -, -, -, -, -, -] +int *d_value_buf; // e.g., [0, 1, 2, 3, 4, 5, 6] +int *d_value_alt_buf; // e.g., [-, -, -, -, -, -, -] +... + +// Create a set of DoubleBuffers to wrap pairs of device pointers +cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); +cub::DoubleBuffer d_values(d_value_buf, d_value_alt_buf); + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedSort::SortPairsDescending( + d_temp_storage, temp_storage_bytes, d_keys, d_values, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedSort::SortPairsDescending( + d_temp_storage, temp_storage_bytes, d_keys, d_values, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys.Current() <-- [8, 7, 6, 9, 5, 3, 0] +// d_values.Current() <-- [0, 2, 1, 6, 3, 4, 5] +``` + + + + +### StableSortPairs inline static + + + + +Sorts segments of key-value pairs into ascending order. Approximately `2 * num_items + 2 * num_segments` auxiliary storage required. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::StableSortPairs( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The contents of the input data are not altered by the sorting operation. +When the input is a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +StableSortPairs is stable: it preserves the relative ordering of equivalent elements. That is, if `x` and `y` are elements such that `x` precedes `y`, and if the two elements are equivalent (neither `x < y` nor `y < x`) then a postcondition of stable sort is that `x` still precedes `y`. +Let `in` be one of `{d_keys_in, d_values_in}` and `out` be any of `{d_keys_out, d_values_out}`. The range `[out, out + num_items)` shall not overlap `[in, in + num_items)`, `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys_in[i]`, `d_values_in[i]`, `d_keys_out[i]`, `d_values_out[i]` will not be accessed nor modified. + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Value type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Device-accessible pointer to the input data of key data to sort + + + +Device-accessible pointer to the sorted output sequence of key data + + + +Device-accessible pointer to the corresponding input sequence of associated value items + + + +Device-accessible pointer to the correspondingly-reordered output sequence of associated value items + + + +The total number of items to sort (across all segments) + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the ``i``-th segment is +//! considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `i` nt keys with associated vector of `i` nt values. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_keys_out; // e.g., [-, -, -, -, -, -, -] +int *d_values_in; // e.g., [0, 1, 2, 3, 4, 5, 6] +int *d_values_out; // e.g., [-, -, -, -, -, -, -] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedSort::StableSortPairs( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedSort::StableSortPairs( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys_out <-- [6, 7, 8, 0, 3, 5, 9] +// d_values_out <-- [1, 2, 0, 5, 4, 3, 6] +``` + + + + +Sorts segments of key-value pairs into ascending order. Approximately `2 * num_segments` auxiliary storage required. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::StableSortPairs( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The sorting operation is given a pair of key buffers and a corresponding pair of associated value buffers. Each pair is managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). +The contents of both buffers within each pair may be altered by the sorting operation. +Upon completion, the sorting operation will update the "current" indicator within each DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). +When the input is a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +StableSortPairs is stable: it preserves the relative ordering of equivalent elements. That is, if `x` and `y` are elements such that `x` precedes `y`, and if the two elements are equivalent (neither `x < y` nor `y < x`) then a postcondition of stable sort is that `x` still precedes `y`. +Let `cur` be one of `{d_keys.Current(), d_values.Current()}` and `alt` be any of `{d_keys.Alternate(), d_values.Alternate()}`. The range `[cur, cur + num_items)` shall not overlap `[alt, alt + num_items)`. Both ranges shall not overlap `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys.Current()[i]`, `d_values.Current()[i]`, `d_keys.Alternate()[i]`, `d_values.Alternate()[i]` will not be accessed nor modified. + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Value type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Double-buffer of values whose "current" device-accessible buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + + + +The total number of items to sort (across all segments) + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i]-1 <= d_begin_offsets[i]``, the ``i``-th segment is +//! considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `i` nt keys with associated vector of `i` nt values. + +```cpp showLineNumbers={false} +#include +// or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_key_alt_buf; // e.g., [-, -, -, -, -, -, -] +int *d_value_buf; // e.g., [0, 1, 2, 3, 4, 5, 6] +int *d_value_alt_buf; // e.g., [-, -, -, -, -, -, -] +... + +// Create a set of DoubleBuffers to wrap pairs of device pointers +cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); +cub::DoubleBuffer d_values(d_value_buf, d_value_alt_buf); + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedSort::StableSortPairs( + d_temp_storage, temp_storage_bytes, d_keys, d_values, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedSort::StableSortPairs( + d_temp_storage, temp_storage_bytes, d_keys, d_values, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys.Current() <-- [6, 7, 8, 0, 3, 5, 9] +// d_values.Current() <-- [5, 4, 3, 1, 2, 0, 6] +``` + + + + +### StableSortPairsDescending inline static + + + + +Sorts segments of key-value pairs into descending order. Approximately `2 * num_items + 2 * num_segments` auxiliary storage required. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::StableSortPairsDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The contents of the input data are not altered by the sorting operation. +When the input is a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +StableSortPairsDescending is stable: it preserves the relative ordering of equivalent elements. That is, if `x` and `y` are elements such that `x` precedes `y`, and if the two elements are equivalent (neither `x < y` nor `y < x`) then a postcondition of stable sort is that `x` still precedes `y`. +Let `in` be one of `{d_keys_in, d_values_in}` and `out` be any of `{d_keys_out, d_values_out}`. The range `[out, out + num_items)` shall not overlap `[in, in + num_items)`, `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys_in[i]`, `d_values_in[i]`, `d_keys_out[i]`, `d_values_out[i]` will not be accessed nor modified. + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Value type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Device-accessible pointer to the input data of key data to sort + + + +Device-accessible pointer to the sorted output sequence of key data + + + +Device-accessible pointer to the corresponding input sequence of associated value items + + + +Device-accessible pointer to the correspondingly-reordered output sequence of associated value items + + + +The total number of items to sort (across all segments) + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the ``i``-th segment is +//! considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `i` nt keys with associated vector of `i` nt values. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_keys_in; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_keys_out; // e.g., [-, -, -, -, -, -, -] +int *d_values_in; // e.g., [0, 1, 2, 3, 4, 5, 6] +int *d_values_out; // e.g., [-, -, -, -, -, -, -] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedSort::StableSortPairsDescending( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedSort::StableSortPairsDescending( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys_out <-- [8, 7, 6, 9, 5, 3, 0] +// d_values_out <-- [0, 2, 1, 6, 3, 4, 5] +``` + + + + +Sorts segments of key-value pairs into descending order. Approximately `2 * num_segments` auxiliary storage required. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::StableSortPairsDescending( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The sorting operation is given a pair of key buffers and a corresponding pair of associated value buffers. Each pair is managed by a DoubleBuffer structure that indicates which of the two buffers is "current" (and thus contains the input data to be sorted). +The contents of both buffers within each pair may be altered by the sorting operation. +Upon completion, the sorting operation will update the "current" indicator within each DoubleBuffer wrapper to reference which of the two buffers now contains the sorted output sequence (a function of the number of key bits specified and the targeted device architecture). +When the input is a contiguous sequence of segments, a single sequence `segment_offsets` (of length `num_segments + 1`) can be aliased for both the `d_begin_offsets` and `d_end_offsets` parameters (where the latter is specified as `segment_offsets + 1`). +StableSortPairsDescending is stable: it preserves the relative ordering of equivalent elements. That is, if `x` and `y` are elements such that `x` precedes `y`, and if the two elements are equivalent (neither `x < y` nor `y < x`) then a postcondition of stable sort is that `x` still precedes `y`. +Let `cur` be one of `{d_keys.Current(), d_values.Current()}` and `alt` be any of `{d_keys.Alternate(), d_values.Alternate()}`. The range `[cur, cur + num_items)` shall not overlap `[alt, alt + num_items)`. Both ranges shall not overlap `[d_begin_offsets, d_begin_offsets + num_segments)` nor `[d_end_offsets, d_end_offsets + num_segments)` in any way. +Segments are not required to be contiguous. For all index values `i` outside the specified segments `d_keys.Current()[i]`, `d_values.Current()[i]`, `d_keys.Alternate()[i]`, `d_values.Alternate()[i]` will not be accessed nor modified. + + +**Template parameters** + + +**[inferred]** Key type + + + +**[inferred]** Value type + + + +**[inferred]** Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Reference to the double-buffer of keys whose "current" device-accessible buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Double-buffer of values whose "current" device-accessible buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + + + +The total number of items to sort (across all segments) + + + +The number of segments that comprise the sorting data + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of beginning offsets of +//! length ``num_segments``, such that ``d_begin_offsets[i]`` is the first +//! element of the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*`` +//! + + + +Embed:rst:leading-asterisk +//! Random-access input iterator to the sequence of ending offsets of length +//! ``num_segments``, such that ``d_end_offsets[i] - 1`` is the last element of +//! the *i*\ :sup:`th` data segment in ``d_keys_*`` and ``d_values_*``. +//! If ``d_end_offsets[i] - 1 <= d_begin_offsets[i]``, the ``i``-th segment is +//! considered empty. +//! + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the batched sorting of three segments (with one zero-length segment) of `i` nt keys with associated vector of `i` nt values. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for sorting data +int num_items; // e.g., 7 +int num_segments; // e.g., 3 +int *d_offsets; // e.g., [0, 3, 3, 7] +int *d_key_buf; // e.g., [8, 6, 7, 5, 3, 0, 9] +int *d_key_alt_buf; // e.g., [-, -, -, -, -, -, -] +int *d_value_buf; // e.g., [0, 1, 2, 3, 4, 5, 6] +int *d_value_alt_buf; // e.g., [-, -, -, -, -, -, -] +... + +// Create a set of DoubleBuffers to wrap pairs of device pointers +cub::DoubleBuffer d_keys(d_key_buf, d_key_alt_buf); +cub::DoubleBuffer d_values(d_value_buf, d_value_alt_buf); + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSegmentedSort::StableSortPairsDescending( + d_temp_storage, temp_storage_bytes, d_keys, d_values, + num_items, num_segments, d_offsets, d_offsets + 1); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run sorting operation +cub::DeviceSegmentedSort::StableSortPairsDescending( + d_temp_storage, temp_storage_bytes, d_keys, d_values, + num_items, num_segments, d_offsets, d_offsets + 1); + +// d_keys.Current() <-- [8, 7, 6, 9, 5, 3, 0] +// d_values.Current() <-- [0, 2, 1, 6, 3, 4, 5] +``` + + + + +--- + +## Utility methods + +### GetName inline static constexpr + + +```cpp showLineNumbers={false} +static constexpr const char * cub::DeviceSegmentedSort::GetName() +``` + + +### SortKeysNoNVTX inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSegmentedSort::SortKeysNoNVTX( + void *d_temp_storage, + size_t &temp_storage_bytes, + const KeyT *d_keys_in, + KeyT *d_keys_out, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + cudaStream_t stream = 0 +) +``` + diff --git a/fern/cudapages/cub/cub/cub/DeviceSelect.mdx b/fern/cudapages/cub/cub/cub/DeviceSelect.mdx new file mode 100644 index 0000000..89581c6 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DeviceSelect.mdx @@ -0,0 +1,1326 @@ +--- +title: cub::DeviceSelect +description: "" +--- + +DeviceSelect provides device-wide, parallel operations for compacting selected items from sequences of data items residing within device-accessible memory. It is similar to DevicePartition, except that non-selected items are discarded, whereas DevicePartition retains them. + +## Performance considerations + +@linear_performance{select-flagged, select-if, and select-unique} + +--- + +## Methods + +### select_impl inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSelect::select_impl( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + FlagIteratorT d_flags, + OutputIteratorT d_out, + NumSelectedIteratorT d_num_selected_out, + OffsetT num_items, + SelectOpT select_op, + EqualityOpT equality_op, + cudaStream_t stream +) +``` + + +--- + +## Static methods + +### Flagged inline static + + + + +Uses the `d_flags` sequence to selectively copy the corresponding items from `d_in` into `d_out`. The total number of items selected is written to `d_num_selected_out`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSelect::Flagged( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + FlagIterator d_flags, + OutputIteratorT d_out, + NumSelectedIteratorT d_num_selected_out, + ::cuda::std::int64_t num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The value type of `d_flags` must be castable to `bool` (e.g., `bool`, `char`, `int`, etc.). +Copies of the selected items are compacted into `d_out` and maintain their original relative ordering. +| The range `[d_out, d_out + *d_num_selected_out)` shall not overlap `[d_in, d_in + num_items)`, | `[d_flags, d_flags + num_items)` nor `d_num_selected_out` in any way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading selection flags (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing selected items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the number of items selected (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the input sequence of selection flags + + + +Pointer to the output sequence of selected data items + + + +Pointer to the output total number of items selected (i.e., length of `d_out`) + + + +Total number of input items (i.e., length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the compaction of items selected from an `int` device vector. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers for input, +// flags, and output +int num_items; // e.g., 8 +int *d_in; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +char *d_flags; // e.g., [1, 0, 0, 1, 0, 1, 1, 0] +int *d_out; // e.g., [ , , , , , , , ] +int *d_num_selected_out; // e.g., [ ] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSelect::Flagged( + d_temp_storage, temp_storage_bytes, + d_in, d_flags, d_out, d_num_selected_out, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run selection +cub::DeviceSelect::Flagged( + d_temp_storage, temp_storage_bytes, + d_in, d_flags, d_out, d_num_selected_out, num_items); + +// d_out <-- [1, 4, 6, 7] +// d_num_selected_out <-- [4] +``` + + + + +nodiscard + +Uses the `d_flags` sequence to selectively copy the corresponding items from `d_in` into `d_out`. The total number of items selected is written to `d_num_selected_out`. + +This is an environment-based API that allows customization of: + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSelect::Flagged( + InputIteratorT d_in, + FlagIterator d_flags, + OutputIteratorT d_out, + NumSelectedIteratorT d_num_selected_out, + NumItemsT num_items, + EnvT env = {} +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Stream: Query via `cuda::get_stream` +Memory resource: Query via `cuda::mr::get_memory_resource` +The value type of `d_flags` must be castable to `bool` (e.g., `bool`, `char`, `int`, etc.). +Copies of the selected items are compacted into `d_out` and maintain their original relative ordering. +| The range `[d_out, d_out + *d_num_selected_out)` shall not overlap `[d_in, d_in + num_items)`, | `[d_flags, d_flags + num_items)` nor `d_num_selected_out` in any way. + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading selection flags (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing selected items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the number of items selected (may be a simple pointer type) + + + +**[inferred]** Type of num_items + + + +**[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`) + + +**Parameters** + + +Pointer to the input sequence of data items + + + +Pointer to the input sequence of selection flags + + + +Pointer to the output sequence of selected data items + + + +Pointer to the output total number of items selected (i.e., length of `d_out`) + + + +Total number of input items (i.e., length of `d_in`) + + + +**[optional]** Execution environment. Default is `cuda::std::execution::env{}`. + + + + + +Uses the `d_flags` sequence to selectively compact the items in `d_data``. The total number of items selected is written to `d_num_selected_out`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSelect::Flagged( + void *d_temp_storage, + size_t &temp_storage_bytes, + IteratorT d_data, + FlagIterator d_flags, + NumSelectedIteratorT d_num_selected_out, + ::cuda::std::int64_t num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The value type of `d_flags` must be castable to `bool` (e.g., `bool`, `char`, `int`, etc.). +Copies of the selected items are compacted in-place and maintain their original relative ordering. +| The `d_data` may equal `d_flags`. The range `[d_data, d_data + num_items)` shall not overlap | `[d_flags, d_flags + num_items)` in any other way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access iterator type for reading and writing selected items (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading selection flags (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the number of items selected (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the sequence of data items + + + +Pointer to the input sequence of selection flags + + + +Pointer to the output total number of items selected + + + +Total number of input items (i.e., length of `d_data`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the compaction of items selected from an `int` device vector. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers for input, +// flags, and output +int num_items; // e.g., 8 +int *d_data; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +char *d_flags; // e.g., [1, 0, 0, 1, 0, 1, 1, 0] +int *d_num_selected_out; // e.g., [ ] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSelect::Flagged( + d_temp_storage, temp_storage_bytes, + d_in, d_flags, d_num_selected_out, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run selection +cub::DeviceSelect::Flagged( + d_temp_storage, temp_storage_bytes, + d_in, d_flags, d_num_selected_out, num_items); + +// d_data <-- [1, 4, 6, 7] +// d_num_selected_out <-- [4] +``` + + + + +### If inline static + + + + +nodiscard + +Uses the `select_op` functor to selectively copy items from `d_in` into `d_out`. The total number of items selected is written to `d_num_selected_out`. + +This is an environment-based API that allows customization of: + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSelect::If( + InputIteratorT d_in, + OutputIteratorT d_out, + NumSelectedIteratorT d_num_selected_out, + NumItemsT num_items, + SelectOp select_op, + EnvT env = {} +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Stream: Query via `cuda::get_stream` +Memory resource: Query via `cuda::mr::get_memory_resource` +Copies of the selected items are compacted into `d_out` and maintain their original relative ordering. +| The range `[d_out, d_out + *d_num_selected_out)` shall not overlap | `[d_in, d_in + num_items)` nor `d_num_selected_out` in any way. + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing selected items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the number of items selected (may be a simple pointer type) + + + +**[inferred]** Selection operator type having member `bool operator()(const T &a)` + + + +**[inferred]** Type of num_items + + + +**[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`) + + +**Parameters** + + +Pointer to the input sequence of data items + + + +Pointer to the output sequence of selected data items + + + +Pointer to the output total number of items selected (i.e., length of `d_out`) + + + +Total number of input items (i.e., length of `d_in`) + + + +Unary selection operator + + + +**[optional]** Execution environment. Default is `cuda::std::execution::env{}`. + + + + + +Uses the `select_op` functor to selectively copy items from `d_in` into `d_out`. The total number of items selected is written to `d_num_selected_out`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSelect::If( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + NumSelectedIteratorT d_num_selected_out, + ::cuda::std::int64_t num_items, + SelectOp select_op, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +Copies of the selected items are compacted into `d_out` and maintain their original relative ordering. +| The range `[d_out, d_out + *d_num_selected_out)` shall not overlap | `[d_in, d_in + num_items)` nor `d_num_selected_out` in any way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing selected items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the number of items selected (may be a simple pointer type) + + + +**[inferred]** Selection operator type having member `bool operator()(const T &a)` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output sequence of selected data items + + + +Pointer to the output total number of items selected (i.e., length of `d_out`) + + + +Total number of input items (i.e., length of `d_in`) + + + +Unary selection operator + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the compaction of items selected from an `int` device vector. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Functor type for selecting values less than some criteria +struct LessThan +{ + int compare; + + __host__ __device__ __forceinline__ + LessThan(int compare) : compare(compare) {} + + __host__ __device__ __forceinline__ + bool operator()(const int &a) const { + return (a < compare); + } +}; + +// Declare, allocate, and initialize device-accessible pointers +// for input and output +int num_items; // e.g., 8 +int *d_in; // e.g., [0, 2, 3, 9, 5, 2, 81, 8] +int *d_out; // e.g., [ , , , , , , , ] +int *d_num_selected_out; // e.g., [ ] +LessThan select_op(7); +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSelect::If( + d_temp_storage, temp_storage_bytes, + d_in, d_out, d_num_selected_out, num_items, select_op); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run selection +cub::DeviceSelect::If( + d_temp_storage, temp_storage_bytes, + d_in, d_out, d_num_selected_out, num_items, select_op); + +// d_out <-- [0, 2, 3, 5, 2] +// d_num_selected_out <-- [5] +``` + + + + +Uses the `select_op` functor to selectively compact items in `d_data`. The total number of items selected is written to `d_num_selected_out`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSelect::If( + void *d_temp_storage, + size_t &temp_storage_bytes, + IteratorT d_data, + NumSelectedIteratorT d_num_selected_out, + ::cuda::std::int64_t num_items, + SelectOp select_op, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +| Copies of the selected items are compacted in `d_data` and maintain | their original relative ordering. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading and writing items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the number of items selected (may be a simple pointer type) + + + +**[inferred]** Selection operator type having member `bool operator()(const T &a)` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the sequence of data items + + + +Pointer to the output total number of items selected + + + +Total number of input items (i.e., length of `d_data`) + + + +Unary selection operator + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the compaction of items selected from an `int` device vector. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Functor type for selecting values less than some criteria +struct LessThan +{ + int compare; + + __host__ __device__ __forceinline__ + LessThan(int compare) : compare(compare) {} + + __host__ __device__ __forceinline__ + bool operator()(const int &a) const { + return (a < compare); + } +}; + +// Declare, allocate, and initialize device-accessible pointers +// for input and output +int num_items; // e.g., 8 +int *d_data; // e.g., [0, 2, 3, 9, 5, 2, 81, 8] +int *d_num_selected_out; // e.g., [ ] +LessThan select_op(7); +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSelect::If( + d_temp_storage, temp_storage_bytes, + d_data, d_num_selected_out, num_items, select_op); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run selection +cub::DeviceSelect::If( + d_temp_storage, temp_storage_bytes, + d_data, d_num_selected_out, num_items, select_op); + +// d_data <-- [0, 2, 3, 5, 2] +// d_num_selected_out <-- [5] +``` + + + + +### FlaggedIf inline static + + + + +Uses the `select_op` functor applied to `d_flags` to selectively copy the corresponding items from `d_in` into `d_out`. The total number of items selected is written to `d_num_selected_out`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSelect::FlaggedIf( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + FlagIterator d_flags, + OutputIteratorT d_out, + NumSelectedIteratorT d_num_selected_out, + ::cuda::std::int64_t num_items, + SelectOp select_op, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The expression `select_op(flag)` must be convertible to `bool`, where the type of `flag` corresponds to the value type of `FlagIterator`. +Copies of the selected items are compacted into `d_out` and maintain their original relative ordering. +| The range `[d_out, d_out + *d_num_selected_out)` shall not overlap | `[d_in, d_in + num_items)` nor `d_num_selected_out` in any way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading selection flags (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing selected items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the number of items selected (may be a simple pointer type) + + + +**[inferred]** Selection operator type having member `bool operator()(const T &a)` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the input sequence of selection flags + + + +Pointer to the output sequence of selected data items + + + +Pointer to the output total number of items selected (i.e., length of `d_out`) + + + +Total number of input items (i.e., length of `d_in`) + + + +Unary selection operator + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +Uses the `select_op` functor applied to `d_flags` to selectively compact the corresponding items in `d_data`. The total number of items selected is written to `d_num_selected_out`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSelect::FlaggedIf( + void *d_temp_storage, + size_t &temp_storage_bytes, + IteratorT d_data, + FlagIterator d_flags, + NumSelectedIteratorT d_num_selected_out, + ::cuda::std::int64_t num_items, + SelectOp select_op, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The expression `select_op(flag)` must be convertible to `bool`, where the type of `flag` corresponds to the value type of `FlagIterator`. +Copies of the selected items are compacted in-place and maintain their original relative ordering. +| The `d_data` may equal `d_flags`. The range `[d_data, d_data + num_items)` shall not overlap | `[d_flags, d_flags + num_items)` in any other way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access iterator type for reading and writing selected items (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading selection flags (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the number of items selected (may be a simple pointer type) + + + +**[inferred]** Selection operator type having member `bool operator()(const T &a)` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the sequence of data items + + + +Pointer to the input sequence of selection flags + + + +Pointer to the output total number of items selected + + + +Total number of input items (i.e., length of `d_data`) + + + +Unary selection operator + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + + + + +### Unique inline static + +Given an input sequence `d_in` having runs of consecutive equal-valued keys, only the first key from each run is selectively copied to `d_out`. The total number of items selected is written to `d_num_selected_out`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSelect::Unique( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + NumSelectedIteratorT d_num_selected_out, + ::cuda::std::int64_t num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The `==` equality operator is used to determine whether keys are equivalent +Copies of the selected items are compacted into `d_out` and maintain their original relative ordering. +| The range `[d_out, d_out + *d_num_selected_out)` shall not overlap | `[d_in, d_in + num_items)` nor `d_num_selected_out` in any way. +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input items (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing selected items (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the number of items selected (may be a simple pointer type) + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output sequence of selected data items + + + +Pointer to the output total number of items selected (i.e., length of `d_out`) + + + +Total number of input items (i.e., length of `d_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the compaction of items selected from an `int` device vector. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for input and output +int num_items; // e.g., 8 +int *d_in; // e.g., [0, 2, 2, 9, 5, 5, 5, 8] +int *d_out; // e.g., [ , , , , , , , ] +int *d_num_selected_out; // e.g., [ ] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSelect::Unique( + d_temp_storage, temp_storage_bytes, + d_in, d_out, d_num_selected_out, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run selection +cub::DeviceSelect::Unique( + d_temp_storage, temp_storage_bytes, + d_in, d_out, d_num_selected_out, num_items); + +// d_out <-- [0, 2, 9, 5, 8] +// d_num_selected_out <-- [5] +``` + +### UniqueByKey inline static + + + + +Given an input sequence `d_keys_in` and `d_values_in` with runs of key-value pairs with consecutive equal-valued keys, only the first key and its value from each run is selectively copied to `d_keys_out` and `d_values_out`. The total number of items selected is written to `d_num_selected_out`. + + +```cpp showLineNumbers={false} +template +static ::cuda::std::enable_if_t, cudaError_t> cub::DeviceSelect::UniqueByKey( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyInputIteratorT d_keys_in, + ValueInputIteratorT d_values_in, + KeyOutputIteratorT d_keys_out, + ValueOutputIteratorT d_values_out, + NumSelectedIteratorT d_num_selected_out, + NumItemsT num_items, + EqualityOpT equality_op, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The user-provided equality operator, `equality_op`, is used to determine whether keys are equivalent +Copies of the selected items are compacted into `d_out` and maintain their original relative ordering. +In-place operations are not supported. There must be no overlap between any of the provided ranges: +`[d_keys_in, d_keys_in + num_items)` +`[d_keys_out, d_keys_out + *d_num_selected_out)` +`[d_values_in, d_values_in + num_items)` +`[d_values_out, d_values_out + *d_num_selected_out)` +`[d_num_selected_out, d_num_selected_out + 1)` +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input keys (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading input values (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing selected keys (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing selected values (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the number of items selected (may be a simple pointer type) + + + +**[inferred]** Type of num_items + + + +**[inferred]** Type of equality_op + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of keys + + + +Pointer to the input sequence of values + + + +Pointer to the output sequence of selected keys + + + +Pointer to the output sequence of selected values + + + +Pointer to the total number of items selected (i.e., length of `d_keys_out` or `d_values_out`) + + + +Total number of input items (i.e., length of `d_keys_in` or `d_values_in`) + + + +Binary predicate to determine equality + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the compaction of items selected from an `int` device vector. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for input and output +int num_items; // e.g., 8 +int *d_keys_in; // e.g., [0, 2, 2, 9, 5, 5, 5, 8] +int *d_values_in; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +int *d_keys_out; // e.g., [ , , , , , , , ] +int *d_values_out; // e.g., [ , , , , , , , ] +int *d_num_selected_out; // e.g., [ ] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSelect::UniqueByKey( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_values_in, + d_keys_out, d_values_out, d_num_selected_out, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run selection +cub::DeviceSelect::UniqueByKey( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_values_in, + d_keys_out, d_values_out, d_num_selected_out, num_items); + +// d_keys_out <-- [0, 2, 9, 5, 8] +// d_values_out <-- [1, 2, 4, 5, 8] +// d_num_selected_out <-- [5] +``` + + + + +Given an input sequence `d_keys_in` and `d_values_in` with runs of key-value pairs with consecutive equal-valued keys, only the first key and its value from each run is selectively copied to `d_keys_out` and `d_values_out`. The total number of items selected is written to `d_num_selected_out`. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceSelect::UniqueByKey( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyInputIteratorT d_keys_in, + ValueInputIteratorT d_values_in, + KeyOutputIteratorT d_keys_out, + ValueOutputIteratorT d_values_out, + NumSelectedIteratorT d_num_selected_out, + NumItemsT num_items, + cudaStream_t stream = 0 +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The `==` equality operator is used to determine whether keys are equivalent +Copies of the selected items are compacted into `d_out` and maintain their original relative ordering. +In-place operations are not supported. There must be no overlap between any of the provided ranges: +`[d_keys_in, d_keys_in + num_items)` +`[d_keys_out, d_keys_out + *d_num_selected_out)` +`[d_values_in, d_values_in + num_items)` +`[d_values_out, d_values_out + *d_num_selected_out)` +`[d_num_selected_out, d_num_selected_out + 1)` +@devicestorage + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input keys (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading input values (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing selected keys (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing selected values (may be a simple pointer type) + + + +**[inferred]** Output iterator type for recording the number of items selected (may be a simple pointer type) + + + +**[inferred]** Type of num_items + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Pointer to the input sequence of keys + + + +Pointer to the input sequence of values + + + +Pointer to the output sequence of selected keys + + + +Pointer to the output sequence of selected values + + + +Pointer to the total number of items selected (i.e., length of `d_keys_out` or `d_values_out`) + + + +Total number of input items (i.e., length of `d_keys_in` or `d_values_in`) + + + +Embed:rst:leading-asterisk +//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`. +//! + + +**Example** + +The code snippet below illustrates the compaction of items selected from an `int` device vector. + +```cpp showLineNumbers={false} +#include // or equivalently + +// Declare, allocate, and initialize device-accessible pointers +// for input and output +int num_items; // e.g., 8 +int *d_keys_in; // e.g., [0, 2, 2, 9, 5, 5, 5, 8] +int *d_values_in; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +int *d_keys_out; // e.g., [ , , , , , , , ] +int *d_values_out; // e.g., [ , , , , , , , ] +int *d_num_selected_out; // e.g., [ ] +... + +// Determine temporary device storage requirements +void *d_temp_storage = nullptr; +size_t temp_storage_bytes = 0; +cub::DeviceSelect::UniqueByKey( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_values_in, + d_keys_out, d_values_out, d_num_selected_out, num_items); + +// Allocate temporary storage +cudaMalloc(&d_temp_storage, temp_storage_bytes); + +// Run selection +cub::DeviceSelect::UniqueByKey( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_values_in, + d_keys_out, d_values_out, d_num_selected_out, num_items); + +// d_keys_out <-- [0, 2, 9, 5, 8] +// d_values_out <-- [1, 2, 4, 5, 8] +// d_num_selected_out <-- [5] +``` + + + diff --git a/fern/cudapages/cub/cub/cub/DeviceTopK.mdx b/fern/cudapages/cub/cub/cub/DeviceTopK.mdx new file mode 100644 index 0000000..76713a5 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DeviceTopK.mdx @@ -0,0 +1,338 @@ +--- +title: cub::DeviceTopK +description: "" +--- + +DeviceTopK provides device-wide, parallel operations for finding the largest (or smallest) K items from sequences of unordered data items residing within device-accessible memory. + +## Performance considerations + +@linear_performance{top-k} + +--- + +## Static methods + +### MaxPairs inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceTopK::MaxPairs( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyInputIteratorT d_keys_in, + KeyOutputIteratorT d_keys_out, + ValueInputIteratorT d_values_in, + ValueOutputIteratorT d_values_out, + NumItemsT num_items, + NumOutItemsT k, + EnvT env = {} +) +``` + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input keys (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing output keys (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading input values (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for writing output values (may be a simple pointer type) + + + +The integral type of variable num_items + + + +The integral type of variable k + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the input sequence containing the keys + + + +Random-access iterator to the output sequence of keys, where K values will be written to + + + +Random-access iterator to the input sequence containing the values associated to each key + + + +Random-access iterator to the output sequence of values, corresponding to the top k keys, where k values will be written to + + + +Number of items to be read and processed from `d_keys_in` and `d_values_in` each + + + +The value of K, which is the number of largest pairs to find from `num_items` pairs. Capped to a maximum of `num_items`. + + + +Embed:rst:leading-asterisk +//! **[optional]** Execution environment. Default is `cuda::std::execution::env{}`. +//! + + +### MinPairs inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceTopK::MinPairs( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyInputIteratorT d_keys_in, + KeyOutputIteratorT d_keys_out, + ValueInputIteratorT d_values_in, + ValueOutputIteratorT d_values_out, + NumItemsT num_items, + NumOutItemsT k, + EnvT env = {} +) +``` + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input keys (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing output keys (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for reading input values (may be a simple pointer type) + + + +**[inferred]** Random-access input iterator type for writing output values (may be a simple pointer type) + + + +The integral type of variable num_items + + + +The integral type of variable k + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the input sequence containing the keys + + + +Random-access iterator to the output sequence of keys, where K values will be written to + + + +Random-access iterator to the input sequence containing the values associated to each key + + + +Random-access iterator to the output sequence of values, corresponding to the top k keys, where k values will be written to + + + +Number of items to be read and processed from `d_keys_in` and `d_values_in` each + + + +The value of K, which is the number of lowest pairs to find from `num_items` pairs. Capped to a maximum of `num_items`. + + + +Embed:rst:leading-asterisk +//! **[optional]** Execution environment. Default is `cuda::std::execution::env{}`. +//! + + +### MaxKeys inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceTopK::MaxKeys( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyInputIteratorT d_keys_in, + KeyOutputIteratorT d_keys_out, + NumItemsT num_items, + NumOutItemsT k, + EnvT env = {} +) +``` + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input keys (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing output keys (may be a simple pointer type) + + + +The integral type of variable num_items + + + +The integral type of variable k + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the input sequence containing the keys + + + +Random-access iterator to the output sequence of keys, where K values will be written to + + + +Number of items to be read and processed from `d_keys_in` + + + +The value of K, which is the number of largest pairs to find from `num_items` pairs. Capped to a maximum of `num_items`. + + + +Embed:rst:leading-asterisk +//! **[optional]** Execution environment. Default is `cuda::std::execution::env{}`. +//! + + +### MinKeys inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceTopK::MinKeys( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyInputIteratorT d_keys_in, + KeyOutputIteratorT d_keys_out, + NumItemsT num_items, + NumOutItemsT k, + EnvT env = {} +) +``` + + +**Template parameters** + + +**[inferred]** Random-access input iterator type for reading input keys (may be a simple pointer type) + + + +**[inferred]** Random-access output iterator type for writing output keys (may be a simple pointer type) + + + +The integral type of variable num_items + + + +The integral type of variable k + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +Random-access iterator to the input sequence containing the keys + + + +Random-access iterator to the output sequence of keys, where K values will be written to + + + +Number of items to be read and processed from `d_keys_in` + + + +The value of K, which is the number of largest pairs to find from `num_items` pairs. Capped to a maximum of `num_items`. + + + +Embed:rst:leading-asterisk +//! **[optional]** Execution environment. Default is `cuda::std::execution::env{}`. +//! + diff --git a/fern/cudapages/cub/cub/cub/DeviceTransform.mdx b/fern/cudapages/cub/cub/cub/DeviceTransform.mdx new file mode 100644 index 0000000..cdbf879 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DeviceTransform.mdx @@ -0,0 +1,476 @@ +--- +title: cub::DeviceTransform +description: "[DeviceTransform](/library/api/cub::_device_transform) provides device-wide, parallel operations for transforming elements tuple-wise from multiple input sequences into an output sequence." +--- + +`DeviceTransform` provides device-wide, parallel operations for transforming elements tuple-wise from multiple input sequences into an output sequence. + +--- + +## Methods + +### TransformInternal inline static + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceTransform::TransformInternal( + ::cuda::std::tuple inputs, + RandomAccessIteratorOut output, + NumItemsT num_items, + Predicate predicate, + TransformOp transform_op, + Env env +) +``` + + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceTransform::TransformInternal( + ::cuda::std::tuple inputs, + ::cuda::std::tuple outputs, + NumItemsT num_items, + Predicate predicate, + TransformOp transform_op, + Env env +) +``` + + + + + +--- + +## Static methods + +### Transform inline static + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceTransform::Transform( + ::cuda::std::tuple inputs, + ::cuda::std::tuple outputs, + NumItemsT num_items, + TransformOp transform_op, + Env env = {} +) +``` + + +**Parameters** + + +A tuple of iterators to the input sequences where num_items elements are read from each. The iterators' value types must be trivially relocatable. + + + +A tuple of iterators to the output sequences where num_items results are written to each. Each sequence may point to the beginning of one of the input sequences, performing the transformation inplace. Any output sequence must not overlap with any of the input sequence in any other way. + + + +The number of elements in each input and output sequence. + + + +An n-ary function object, where n is the number of input sequences. The input iterators' value types must be convertible to the parameters of the function object's call operator. The return type of the call operator must be a tuple where each tuple element is assignable to the corresponding dereferenced output iterators. + + + +Execution environment, or cudaStream_t. Default is `cuda::std::execution::env{}`, which will run on stream\ :sub:`0` + + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceTransform::Transform( + ::cuda::std::tuple inputs, + RandomAccessIteratorOut output, + NumItemsT num_items, + TransformOp transform_op, + Env env = {} +) +``` + + +*Added in v2.8.0. First appears in CUDA Toolkit 12.9.* + +**Parameters** + + +A tuple of iterators to the input sequences where num_items elements are read from each. + + + +An iterator to the output sequence where num_items results are written to. May point to the beginning of one of the input sequences, performing the transformation inplace. The output sequence must not overlap with any of the input sequence in any other way. + + + +The number of elements in each input sequence. + + + +An n-ary function object, where n is the number of input sequences. The input iterators' value types must be convertible to the parameters of the function object's call operator. The return type of the call operator must be assignable to the dereferenced output iterator. + + + +Execution environment, or cudaStream_t. Default is `cuda::std::execution::env{}`, which will run on stream\ :sub:`0` + + + + + +Transforms one input sequence into one output sequence, by applying a transformation operation on each input element and writing the result to the corresponding output element. No guarantee is given on the identity (i.e. address) of the objects passed to the call operator of the transformation operation. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceTransform::Transform( + RandomAccessIteratorIn input, + RandomAccessIteratorOut output, + NumItemsT num_items, + TransformOp transform_op, + Env env = {} +) +``` + + +*Added in v2.8.0. First appears in CUDA Toolkit 12.9.* + +**Parameters** + + +An iterator to the input sequence where num_items elements are read from. + + + +An iterator to the output sequence where num_items results are written to. May point to the same sequence as `input`, performing the transformation inplace. The output sequence must not overlap with the input sequence in any other way. + + + +The number of elements in each input sequence. + + + +A unary function object. The input iterator's value type must be convertible to the parameter of the function object's call operator. The return type of the call operator must be assignable to the dereferenced output iterator. + + + +Execution environment, or cudaStream_t. Default is `cuda::std::execution::env{}`, which will run on stream\ :sub:`0` + + + + + +### Generate inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceTransform::Generate( + RandomAccessIteratorOut output, + NumItemsT num_items, + Generator generator, + Env env = {} +) +``` + + +*Added in v2.8.0. First appears in CUDA Toolkit 12.9.* + +**Parameters** + + +An iterator to the output sequence where num_items results are written to. + + + +The number of elements to write to the output sequence. + + + +A nullary function object. The return type of the call operator must be assignable to the dereferenced output iterator. + + + +Execution environment, or cudaStream_t. Default is `cuda::std::execution::env{}`, which will run on stream\ :sub:`0` + + +### Fill inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceTransform::Fill( + RandomAccessIteratorOut output, + NumItemsT num_items, + Value value, + Env env = {} +) +``` + + +*Added in v2.8.0. First appears in CUDA Toolkit 12.9.* + +**Parameters** + + +An iterator to the output sequence where num_items results are written to. + + + +The number of elements to write to the output sequence. + + + +The value to write. Must be assignable to the dereferenced output iterator. + + + +Execution environment, or cudaStream_t. Default is `cuda::std::execution::env{}`, which will run on stream\ :sub:`0` + + +### TransformIf inline static + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceTransform::TransformIf( + ::cuda::std::tuple inputs, + RandomAccessIteratorOut output, + NumItemsT num_items, + Predicate predicate, + TransformOp transform_op, + Env env = {} +) +``` + + +*Added in v2.8.0. First appears in CUDA Toolkit 12.9.* + +**Parameters** + + +A tuple of iterators to the input sequences where num_items elements are read from each. + + + +An iterator to the output sequence where num_items results are written to. May point to the beginning of one of the input sequences, performing the transformation inplace. The output sequence must not overlap with any of the input sequence in any other way. + + + +The number of elements in each input sequence. + + + +An n-ary function object, where n is the number of input sequences. The input iterators' value types must be convertible to the parameters of the function object's call operator, which must return a boolean value. + + + +An n-ary function object, where n is the number of input sequences. The input iterators' value types must be convertible to the parameters of the function object's call operator. The return type of the call operator must be assignable to the dereferenced output iterator. Will only be invoked if `predicate` returns true. + + + +Execution environment, or cudaStream_t. Default is `cuda::std::execution::env{}`, which will run on stream\ :sub:`0` + + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceTransform::TransformIf( + RandomAccessIteratorIn input, + RandomAccessIteratorOut output, + NumItemsT num_items, + Predicate predicate, + TransformOp transform_op, + Env env = {} +) +``` + + +*Added in v2.8.0. First appears in CUDA Toolkit 12.9.* + +**Parameters** + + +An iterator to the input sequence where num_items elements are read from. + + + +An iterator to the output sequence where num_items results are written to. May point to the same sequence as `input`, performing the transformation inplace. The output sequence must not overlap with the input sequence in any other way. + + + +The number of elements in each input sequence. + + + +A unary function objects returning `bool`. The input iterators' value types must be convertible to the parameters of the function object's call operator. + + + +A unary function object. The input iterator's value type must be convertible to the parameter of the function object's call operator. The return type of the call operator must be assignable to the dereferenced output iterator. Will only be invoked if `predicate` returns true. + + + +Execution environment, or cudaStream_t. Default is `cuda::std::execution::env{}`, which will run on stream\ :sub:`0` + + + + + +### TransformStableArgumentAddresses inline static + + + + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceTransform::TransformStableArgumentAddresses( + ::cuda::std::tuple inputs, + RandomAccessIteratorOut output, + NumItemsT num_items, + TransformOp transform_op, + Env env = {} +) +``` + + +*Added in v2.8.0. First appears in CUDA Toolkit 12.9.* + +**Parameters** + + +A tuple of iterators to the input sequences where num_items elements are read from each. + + + +An iterator to the output sequence where num_items results are written to. May point to the beginning of one of the input sequences, performing the transformation inplace. The output sequence must not overlap with any of the input sequence in any other way. + + + +The number of elements in each input sequence. + + + +An n-ary function object, where n is the number of input sequences. The input iterators' value types must be convertible to the parameters of the function object's call operator. The return type of the call operator must be assignable to the dereferenced output iterator. + + + +Execution environment, or cudaStream_t. Default is `cuda::std::execution::env{}`, which will run on stream\ :sub:`0` + + + + + +Transforms one input sequence into one output sequence, by applying a transformation operation on corresponding input elements and writing the result to the corresponding output element. The objects passed to the call operator of the transformation operation are guaranteed to reside in the input sequences and are never copied. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DeviceTransform::TransformStableArgumentAddresses( + RandomAccessIteratorIn input, + RandomAccessIteratorOut output, + NumItemsT num_items, + TransformOp transform_op, + Env env = {} +) +``` + + +*Added in v2.8.0. First appears in CUDA Toolkit 12.9.* + +**Parameters** + + +An iterator to the input sequence where num_items elements are read from. + + + +An iterator to the output sequence where num_items results are written to. May point to the beginning of one of the input sequences, performing the transformation inplace. The output sequence must not overlap with any of the input sequence in any other way. + + + +The number of elements in each input sequence. + + + +An n-ary function object, where n is the number of input sequences. The input iterators' value types must be convertible to the parameters of the function object's call operator. The return type of the call operator must be assignable to the dereferenced output iterator. + + + +Execution environment, or cudaStream_t. Default is `cuda::std::execution::env{}`, which will run on stream\ :sub:`0` + + + + diff --git a/fern/cudapages/cub/cub/cub/DispatchAdjacentDifference.mdx b/fern/cudapages/cub/cub/cub/DispatchAdjacentDifference.mdx new file mode 100644 index 0000000..a757d8d --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DispatchAdjacentDifference.mdx @@ -0,0 +1,110 @@ +--- +title: cub::DispatchAdjacentDifference +description: "" +--- + + + + + + + + + + + + + + + + + + + + + + + + + + + + +--- + +## Constructors + +### DispatchAdjacentDifference inline + + +```cpp showLineNumbers={false} +cub::DispatchAdjacentDifference::DispatchAdjacentDifference( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_input, + OutputIteratorT d_output, + OffsetT num_items, + DifferenceOpT difference_op, + cudaStream_t stream +) +``` + + +--- + +## Methods + +### Invoke inline + +Invocation. + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchAdjacentDifference::Invoke() +``` + + +--- + +## Static methods + +### Dispatch inline static + + +```cpp showLineNumbers={false} +static cudaError_t cub::DispatchAdjacentDifference::Dispatch( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_input, + OutputIteratorT d_output, + OffsetT num_items, + DifferenceOpT difference_op, + cudaStream_t stream +) +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `InputT` | `detail::it_value_t< InputIteratorT >` | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `d_temp_storage` | `void *` | | +| `temp_storage_bytes` | `size_t &` | | +| `d_input` | `InputIteratorT` | | +| `d_output` | `OutputIteratorT` | | +| `num_items` | `OffsetT` | | +| `difference_op` | `DifferenceOpT` | | +| `stream` | `cudaStream_t` | | diff --git a/fern/cudapages/cub/cub/cub/DispatchHistogram.mdx b/fern/cudapages/cub/cub/cub/DispatchHistogram.mdx new file mode 100644 index 0000000..6cd7af0 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DispatchHistogram.mdx @@ -0,0 +1,661 @@ +--- +title: cub::DispatchHistogram +description: "Utility class for dispatching the appropriately-tuned kernels for [DeviceHistogram](/library/api/cub::_device_histogram)." +--- + +Utility class for dispatching the appropriately-tuned kernels for [DeviceHistogram](/library/api/cub::_device_histogram). + + + + + +Number of channels interleaved in the input data (may be greater than the number of channels being actively histogrammed) + + + +Number of channels actively being histogrammed + + + +Random-access input iterator type for reading input items (may be a simple pointer type) + + + +Integer type for counting sample occurrences per histogram bin + + + +Type for specifying bin level boundaries + + + +Signed integer type for global offsets + + + +Implementation detail, do not specify directly, requirements on the content of this type are subject to breaking change. + + + + + + + + + + + + + + +--- + +## Static methods + +### DispatchRange inline static + + + + +Dispatch routine for HistogramRange with host-side decode operator initialization, specialized for sample types larger than 8bit. + +This variant initializes the decode operators on the host before kernel launch. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DispatchHistogram::DispatchRange( + void *d_temp_storage, + size_t &temp_storage_bytes, + SampleIteratorT d_samples, + ::cuda::std::array d_output_histograms, + ::cuda::std::array num_output_levels, + ::cuda::std::array d_levels, + OffsetT num_row_pixels, + OffsetT num_rows, + OffsetT row_stride_samples, + cudaStream_t stream, + ::cuda::std::false_type, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}, + MaxPolicyT max_policy = {} +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +The pointer to the multi-channel input sequence of data samples. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four RGBA 8-bit samples). + + + +The pointers to the histogram counter output arrays, one for each active channel. For channeli, the allocation length of `d_histograms[i]` should be `num_output_levels[i] - 1`. + + + +The number of boundaries (levels) for delineating histogram samples in each active channel. Implies that the number of bins for channeli is `num_output_levels[i] - 1`. + + + +The pointers to the arrays of boundaries (levels), one for each active channel. Bin ranges are defined by consecutive boundary pairings: lower sample value boundaries are inclusive and upper sample value boundaries are exclusive. + + + +The number of multi-channel pixels per row in the region of interest + + + +The number of rows in the region of interest + + + +The number of samples between starts of consecutive rows in the region of interest + + + +CUDA stream to launch kernels within. Default is stream0. + + + + + +Dispatch routine for HistogramRange with host-side decode operator initialization, specialized for 8-bit sample types (computes 256-bin privatized histograms and then reduces to user-specified levels). + +This variant initializes the decode operators on the host before kernel launch. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DispatchHistogram::DispatchRange( + void *d_temp_storage, + size_t &temp_storage_bytes, + SampleIteratorT d_samples, + ::cuda::std::array d_output_histograms, + ::cuda::std::array num_output_levels, + ::cuda::std::array d_levels, + OffsetT num_row_pixels, + OffsetT num_rows, + OffsetT row_stride_samples, + cudaStream_t stream, + ::cuda::std::true_type, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}, + MaxPolicyT max_policy = {} +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +The pointer to the multi-channel input sequence of data samples. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four RGBA 8-bit samples). + + + +The pointers to the histogram counter output arrays, one for each active channel. For channeli, the allocation length of `d_histograms[i]` should be `num_output_levels[i] - 1`. + + + +The number of boundaries (levels) for delineating histogram samples in each active channel. Implies that the number of bins for channeli is `num_output_levels[i] - 1`. + + + +The pointers to the arrays of boundaries (levels), one for each active channel. Bin ranges are defined by consecutive boundary pairings: lower sample value boundaries are inclusive and upper sample value boundaries are exclusive. + + + +The number of multi-channel pixels per row in the region of interest + + + +The number of rows in the region of interest + + + +The number of samples between starts of consecutive rows in the region of interest + + + +CUDA stream to launch kernels within. Default is stream0. + + + + + +### DispatchEven inline static + + + + +Dispatch routine for HistogramEven with host-side decode operator initialization, specialized for sample types larger than 8-bit. + +This variant initializes the decode operators on the host before kernel launch. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DispatchHistogram::DispatchEven( + void *d_temp_storage, + size_t &temp_storage_bytes, + SampleIteratorT d_samples, + ::cuda::std::array d_output_histograms, + ::cuda::std::array num_output_levels, + ::cuda::std::array lower_level, + ::cuda::std::array upper_level, + OffsetT num_row_pixels, + OffsetT num_rows, + OffsetT row_stride_samples, + cudaStream_t stream, + ::cuda::std::false_type, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}, + MaxPolicyT max_policy = {} +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +The pointer to the input sequence of sample items. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four RGBA 8-bit samples). + + + +The pointers to the histogram counter output arrays, one for each active channel. For channeli, the allocation length of `d_histograms[i]` should be `num_output_levels[i] - 1`. + + + +The number of bin level boundaries for delineating histogram samples in each active channel. Implies that the number of bins for channeli is `num_output_levels[i] - 1`. + + + +The lower sample value bound (inclusive) for the lowest histogram bin in each active channel. + + + +The upper sample value bound (exclusive) for the highest histogram bin in each active channel. + + + +The number of multi-channel pixels per row in the region of interest + + + +The number of rows in the region of interest + + + +The number of samples between starts of consecutive rows in the region of interest + + + +CUDA stream to launch kernels within. Default is stream0. + + + + + +Dispatch routine for HistogramEven with host-side decode operator initialization, specialized for 8-bit sample types (computes 256-bin privatized histograms and then reduces to user-specified levels). + +This variant initializes the decode operators on the host before kernel launch. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DispatchHistogram::DispatchEven( + void *d_temp_storage, + size_t &temp_storage_bytes, + SampleIteratorT d_samples, + ::cuda::std::array d_output_histograms, + ::cuda::std::array num_output_levels, + ::cuda::std::array lower_level, + ::cuda::std::array upper_level, + OffsetT num_row_pixels, + OffsetT num_rows, + OffsetT row_stride_samples, + cudaStream_t stream, + ::cuda::std::true_type, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}, + MaxPolicyT max_policy = {} +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +The pointer to the input sequence of sample items. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four RGBA 8-bit samples). + + + +The pointers to the histogram counter output arrays, one for each active channel. For channeli, the allocation length of `d_histograms[i]` should be `num_output_levels[i] - 1`. + + + +The number of bin level boundaries for delineating histogram samples in each active channel. Implies that the number of bins for channeli is `num_output_levels[i] - 1`. + + + +The lower sample value bound (inclusive) for the lowest histogram bin in each active channel. + + + +The upper sample value bound (exclusive) for the highest histogram bin in each active channel. + + + +The number of multi-channel pixels per row in the region of interest + + + +The number of rows in the region of interest + + + +The number of samples between starts of consecutive rows in the region of interest + + + +CUDA stream to launch kernels within. Default is stream0. + + + + + +### __dispatch_range_device_init inline static + + + + +Dispatch routine for HistogramRange with device-side decode operator initialization, specialized for sample types larger than 8bit. + +This variant initializes the decode operators inside the kernel from level arrays. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DispatchHistogram::__dispatch_range_device_init( + void *d_temp_storage, + size_t &temp_storage_bytes, + SampleIteratorT d_samples, + ::cuda::std::array d_output_histograms, + NumOutputLevelsArrayT num_output_levels, + LevelsArrayT d_levels, + OffsetT num_row_pixels, + OffsetT num_rows, + OffsetT row_stride_samples, + cudaStream_t stream, + ::cuda::std::false_type, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}, + MaxPolicyT max_policy = {} +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +The pointer to the multi-channel input sequence of data samples. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four RGBA 8-bit samples). + + + +The pointers to the histogram counter output arrays, one for each active channel. For channeli, the allocation length of `d_histograms[i]` should be `num_output_levels[i] - 1`. + + + +The number of boundaries (levels) for delineating histogram samples in each active channel. Implies that the number of bins for channeli is `num_output_levels[i] - 1`. + + + +The pointers to the arrays of boundaries (levels), one for each active channel. Bin ranges are defined by consecutive boundary pairings: lower sample value boundaries are inclusive and upper sample value boundaries are exclusive. + + + +The number of multi-channel pixels per row in the region of interest + + + +The number of rows in the region of interest + + + +The number of samples between starts of consecutive rows in the region of interest + + + +CUDA stream to launch kernels within. Default is stream0. + + + + + +Dispatch routine for HistogramRange with device-side decode operator initialization, specialized for 8-bit sample types (computes 256-bin privatized histograms and then reduces to user-specified levels). + +This variant initializes the decode operators inside the kernel from level arrays. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DispatchHistogram::__dispatch_range_device_init( + void *d_temp_storage, + size_t &temp_storage_bytes, + SampleIteratorT d_samples, + ::cuda::std::array d_output_histograms, + NumOutputLevelsArrayT num_output_levels, + LevelsArrayT d_levels, + OffsetT num_row_pixels, + OffsetT num_rows, + OffsetT row_stride_samples, + cudaStream_t stream, + ::cuda::std::true_type, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}, + MaxPolicyT max_policy = {} +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +The pointer to the multi-channel input sequence of data samples. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four RGBA 8-bit samples). + + + +The pointers to the histogram counter output arrays, one for each active channel. For channeli, the allocation length of `d_histograms[i]` should be `num_output_levels[i] - 1`. + + + +The number of boundaries (levels) for delineating histogram samples in each active channel. Implies that the number of bins for channeli is `num_output_levels[i] - 1`. + + + +The pointers to the arrays of boundaries (levels), one for each active channel. Bin ranges are defined by consecutive boundary pairings: lower sample value boundaries are inclusive and upper sample value boundaries are exclusive. + + + +The number of multi-channel pixels per row in the region of interest + + + +The number of rows in the region of interest + + + +The number of samples between starts of consecutive rows in the region of interest + + + +CUDA stream to launch kernels within. Default is stream0. + + + + + +### __dispatch_even_device_init inline static + + + + +Dispatch routine for HistogramEven with device-side decode operator initialization, specialized for sample types larger than 8-bit. + +This variant initializes the decode operators inside the kernel from level bounds. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DispatchHistogram::__dispatch_even_device_init( + void *d_temp_storage, + size_t &temp_storage_bytes, + SampleIteratorT d_samples, + ::cuda::std::array d_output_histograms, + ::cuda::std::array num_output_levels, + LowerLevelArrayT lower_level, + UpperLevelArrayT upper_level, + OffsetT num_row_pixels, + OffsetT num_rows, + OffsetT row_stride_samples, + cudaStream_t stream, + ::cuda::std::false_type, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}, + MaxPolicyT max_policy = {} +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +The pointer to the input sequence of sample items. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four RGBA 8-bit samples). + + + +The pointers to the histogram counter output arrays, one for each active channel. For channeli, the allocation length of `d_histograms[i]` should be `num_output_levels[i] - 1`. + + + +The number of bin level boundaries for delineating histogram samples in each active channel. Implies that the number of bins for channeli is `num_output_levels[i] - 1`. + + + +The lower sample value bound (inclusive) for the lowest histogram bin in each active channel. + + + +The upper sample value bound (exclusive) for the highest histogram bin in each active channel. + + + +The number of multi-channel pixels per row in the region of interest + + + +The number of rows in the region of interest + + + +The number of samples between starts of consecutive rows in the region of interest + + + +CUDA stream to launch kernels within. Default is stream0. + + + + + +Dispatch routine for HistogramEven with device-side decode operator initialization, specialized for 8-bit sample types (computes 256-bin privatized histograms and then reduces to user-specified levels). + +This variant initializes the decode operators inside the kernel from level bounds. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DispatchHistogram::__dispatch_even_device_init( + void *d_temp_storage, + size_t &temp_storage_bytes, + SampleIteratorT d_samples, + ::cuda::std::array d_output_histograms, + ::cuda::std::array num_output_levels, + LowerLevelArrayT lower_level, + UpperLevelArrayT upper_level, + OffsetT num_row_pixels, + OffsetT num_rows, + OffsetT row_stride_samples, + cudaStream_t stream, + ::cuda::std::true_type, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}, + MaxPolicyT max_policy = {} +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Reference to size in bytes of `d_temp_storage` allocation + + + +The pointer to the input sequence of sample items. The samples from different channels are assumed to be interleaved (e.g., an array of 32-bit pixels where each pixel consists of four RGBA 8-bit samples). + + + +The pointers to the histogram counter output arrays, one for each active channel. For channeli, the allocation length of `d_histograms[i]` should be `num_output_levels[i] - 1`. + + + +The number of bin level boundaries for delineating histogram samples in each active channel. Implies that the number of bins for channeli is `num_output_levels[i] - 1`. + + + +The lower sample value bound (inclusive) for the lowest histogram bin in each active channel. + + + +The upper sample value bound (exclusive) for the highest histogram bin in each active channel. + + + +The number of multi-channel pixels per row in the region of interest + + + +The number of rows in the region of interest + + + +The number of samples between starts of consecutive rows in the region of interest + + + +CUDA stream to launch kernels within. Default is stream0. + + + + diff --git a/fern/cudapages/cub/cub/cub/DispatchMergeSort.mdx b/fern/cudapages/cub/cub/cub/DispatchMergeSort.mdx new file mode 100644 index 0000000..b199c4b --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DispatchMergeSort.mdx @@ -0,0 +1,132 @@ +--- +title: cub::DispatchMergeSort +description: "" +--- + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +--- + +## Constructors + +### DispatchMergeSort inline + + +```cpp showLineNumbers={false} +cub::DispatchMergeSort::DispatchMergeSort( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyInputIteratorT d_input_keys, + ValueInputIteratorT d_input_items, + KeyIteratorT d_output_keys, + ValueIteratorT d_output_items, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream, + int ptx_version, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {} +) +``` + + +--- + +## Methods + +### Invoke inline + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchMergeSort::Invoke( + ActivePolicyT policy = {} +) +``` + + +--- + +## Static methods + +### Dispatch inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DispatchMergeSort::Dispatch( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyInputIteratorT d_input_keys, + ValueInputIteratorT d_input_items, + KeyIteratorT d_output_keys, + ValueIteratorT d_output_items, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}, + MaxPolicyT max_policy = {} +) +``` + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `KEYS_ONLY` static constexpr | `bool` | Whether or not there are values to be trucked along with keys. | +| `d_temp_storage` | `void *` | Device-accessible allocation of temporary storage. | +| `temp_storage_bytes` | `size_t &` | Reference to size in bytes of `d_temp_storage` allocation. | +| `d_input_keys` | `KeyInputIteratorT` | Pointer to the input sequence of unsorted input keys. | +| `d_input_items` | `ValueInputIteratorT` | Pointer to the input sequence of unsorted input values. | +| `d_output_keys` | `KeyIteratorT` | Pointer to the output sequence of sorted input keys. | +| `d_output_items` | `ValueIteratorT` | Pointer to the output sequence of sorted input values. | +| `num_items` | `OffsetT` | Number of items to sort. | +| `compare_op` | `CompareOpT` | Comparison function object which returns true if the first argument is ordered before the second. | +| `stream` | `cudaStream_t` | CUDA stream to launch kernels within. Default is stream0. | +| `ptx_version` | `int` | | +| `kernel_source` | `KernelSource` | | +| `launcher_factory` | `KernelLauncherFactory` | | diff --git a/fern/cudapages/cub/cub/cub/DispatchRadixSort.mdx b/fern/cudapages/cub/cub/cub/DispatchRadixSort.mdx new file mode 100644 index 0000000..de4ac38 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DispatchRadixSort.mdx @@ -0,0 +1,377 @@ +--- +title: cub::DispatchRadixSort +description: "Utility class for dispatching the appropriately-tuned kernels for device-wide radix sort." +--- + +Utility class for dispatching the appropriately-tuned kernels for device-wide radix sort. + + + + + + + + +Key type + + + +Value type + + + +Signed integer type for global offsets + + + +Implementation detail, do not specify directly, requirements on the content of this type are subject to breaking change. + + + + + + + + + + + + + + +--- + +## Constructors + +### DispatchRadixSort inline + + +```cpp showLineNumbers={false} +cub::DispatchRadixSort::DispatchRadixSort( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + OffsetT num_items, + int begin_bit, + int end_bit, + bool is_overwrite_okay, + cudaStream_t stream, + int ptx_version, + DecomposerT decomposer = {}, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {} +) +``` + + +--- + +## Methods + +### InvokeSingleTile inline + +Invoke a single block to sort in-core. + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchRadixSort::InvokeSingleTile( + SingleTileKernelT single_tile_kernel, + ActivePolicyT policy = {} +) +``` + + +**Template parameters** + + +Umbrella policy active for the target device + + + +Function type of cub::DeviceRadixSortSingleTileKernel + + +**Parameters** + + +Kernel function pointer to parameterization of cub::DeviceRadixSortSingleTileKernel + + +### InvokePass inline + +Invoke a three-kernel sorting pass at the current bit. + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchRadixSort::InvokePass( + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + OffsetT *d_spine, + int, + int ¤t_bit, + PassConfigT &pass_config +) +``` + + +### InvokeOnesweep inline + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchRadixSort::InvokeOnesweep( + ActivePolicyT policy = {} +) +``` + + +### InvokePasses inline + +Invocation (run multiple digit passes). + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchRadixSort::InvokePasses( + UpsweepKernelT upsweep_kernel, + UpsweepKernelT alt_upsweep_kernel, + ScanKernelT scan_kernel, + DownsweepKernelT downsweep_kernel, + DownsweepKernelT alt_downsweep_kernel, + ActivePolicyT policy = {} +) +``` + + +**Template parameters** + + +Umbrella policy active for the target device + + + +Function type of cub::DeviceRadixSortUpsweepKernel + + + +Function type of cub::SpineScanKernel + + + +Function type of cub::DeviceRadixSortDownsweepKernel + + +**Parameters** + + +Kernel function pointer to parameterization of cub::DeviceRadixSortUpsweepKernel + + + +Alternate kernel function pointer to parameterization of cub::DeviceRadixSortUpsweepKernel + + + +Kernel function pointer to parameterization of cub::SpineScanKernel + + + +Kernel function pointer to parameterization of cub::DeviceRadixSortDownsweepKernel + + + +Alternate kernel function pointer to parameterization of cub::DeviceRadixSortDownsweepKernel + + +### InvokeCopy inline + + +```cpp showLineNumbers={false} +cudaError_t cub::DispatchRadixSort::InvokeCopy() +``` + + +### Invoke inline + +Invocation. + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchRadixSort::Invoke( + ActivePolicyT = {} +) +``` + + +### __invoke inline + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchRadixSort::__invoke( + PolicyGetter policy_getter +) +``` + + +### __invoke_single_tile inline + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchRadixSort::__invoke_single_tile( + SingleTileKernelT single_tile_kernel, + detail::radix_sort::radix_sort_downsweep_policy policy +) +``` + + +### __invoke_onesweep inline + + +```cpp showLineNumbers={false} +cudaError_t cub::DispatchRadixSort::__invoke_onesweep( + detail::radix_sort::radix_sort_policy policy +) +``` + + +### __invoke_passes inline + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchRadixSort::__invoke_passes( + UpsweepKernelT upsweep_kernel, + UpsweepKernelT alt_upsweep_kernel, + ScanKernelT scan_kernel, + DownsweepKernelT downsweep_kernel, + DownsweepKernelT alt_downsweep_kernel, + const detail::radix_sort::radix_sort_policy &policy +) +``` + + +--- + +## Static methods + +### Dispatch inline static + +Internal dispatch routine. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DispatchRadixSort::Dispatch( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + OffsetT num_items, + int begin_bit, + int end_bit, + bool is_overwrite_okay, + cudaStream_t stream, + DecomposerT decomposer = {}, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}, + MaxPolicyT max_policy = {} +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to [`temp_storage_bytes`](/library/api/cub::_dispatch_radix_sort::temp_storage_bytes) and no work is done. + + + +Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_dispatch_radix_sort::d_temp_storage) allocation + + + +Double-buffer whose current buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Double-buffer whose current buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + + + +Number of items to sort + + + +The beginning (least-significant) bit index needed for key comparison + + + +The past-the-end (most-significant) bit index needed for key comparison + + + +Whether is okay to overwrite source buffers + + + +CUDA stream to launch kernels within. Default is stream0. + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `KEYS_ONLY` static constexpr | `bool` | | +| `d_temp_storage` | `void *` | Device-accessible allocation of temporary storage. | +| `temp_storage_bytes` | `size_t &` | Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_dispatch_radix_sort::d_temp_storage) allocation. | +| `d_keys` | `DoubleBuffer< KeyT > &` | Double-buffer whose current buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys. | +| `d_values` | `DoubleBuffer< ValueT > &` | Double-buffer whose current buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values. | +| `num_items` | `OffsetT` | Number of items to sort. | +| `begin_bit` | `int` | The beginning (least-significant) bit index needed for key comparison. | +| `end_bit` | `int` | The past-the-end (most-significant) bit index needed for key comparison. | +| `stream` | `cudaStream_t` | CUDA stream to launch kernels within. Default is stream0. | +| `ptx_version` | `int` | PTX version. | +| `is_overwrite_okay` | `bool` | Whether is okay to overwrite source buffers. | +| `decomposer` | `DecomposerT` | | +| `kernel_source` | `KernelSource` | | +| `launcher_factory` | `KernelLauncherFactory` | | + +--- + +## Inner classes + +### PassConfig + + +```cpp showLineNumbers={false} +struct cub::DispatchRadixSort::PassConfig +``` + + +Pass configuration structure. + +| Name | Type | Description | +|---|---|---| +| `upsweep_kernel` | `UpsweepKernelT` | | +| `upsweep_config` | `detail::KernelConfig` | | +| `scan_kernel` | `ScanKernelT` | | +| `scan_config` | `detail::KernelConfig` | | +| `downsweep_kernel` | `DownsweepKernelT` | | +| `downsweep_config` | `detail::KernelConfig` | | +| `radix_bits` | `int` | | +| `radix_digits` | `int` | | +| `max_downsweep_grid_size` | `int` | | +| `even_share` | `GridEvenShare< OffsetT >` | | diff --git a/fern/cudapages/cub/cub/cub/DispatchReduce.mdx b/fern/cudapages/cub/cub/cub/DispatchReduce.mdx new file mode 100644 index 0000000..37b9d0c --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DispatchReduce.mdx @@ -0,0 +1,241 @@ +--- +title: cub::DispatchReduce +description: "Utility class for dispatching the appropriately-tuned kernels for device-wide reduction." +--- + +Utility class for dispatching the appropriately-tuned kernels for device-wide reduction. + + + + + +Random-access input iterator type for reading input items (may be a simple pointer type) + + + +Output iterator type for recording the reduced aggregate (may be a simple pointer type) + + + +Signed integer type for global offsets + + + +Binary reduction functor type having member `auto operator()(const T &a, const U &b)` + + + +Initial value type + + + + + + + + + + + + + + + + + + + + +--- + +## Constructors + +### DispatchReduce inline + +Constructor. + + +```cpp showLineNumbers={false} +cub::DispatchReduce::DispatchReduce( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + OffsetT num_items, + ReductionOpT reduction_op, + InitT init, + cudaStream_t stream, + int ptx_version, + TransformOpT transform_op = {}, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {} +) +``` + + +--- + +## Methods + +### InvokeSingleTile inline + +Invoke a single block block to reduce in-core. + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchReduce::InvokeSingleTile( + SingleTileKernelT single_tile_kernel, + ActivePolicyT policy = {} +) +``` + + +**Template parameters** + + +Umbrella policy active for the target device + + + +Function type of cub::DeviceReduceSingleTileKernel + + +**Parameters** + + +Kernel function pointer to parameterization of cub::DeviceReduceSingleTileKernel + + +### InvokePasses inline + +Invoke two-passes to reduce. + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchReduce::InvokePasses( + ReduceKernelT reduce_kernel, + SingleTileKernelT single_tile_kernel, + ActivePolicyT active_policy = {} +) +``` + + +**Template parameters** + + +Umbrella policy active for the target device + + + +Function type of cub::DeviceReduceKernel + + + +Function type of cub::DeviceReduceSingleTileKernel + + +**Parameters** + + +Kernel function pointer to parameterization of cub::DeviceReduceKernel + + + +Kernel function pointer to parameterization of cub::DeviceReduceSingleTileKernel + + +### Invoke inline + +Invocation. + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchReduce::Invoke( + ActivePolicyT active_policy = {} +) +``` + + +--- + +## Static methods + +### Dispatch inline static + +Internal dispatch routine for computing a device-wide reduction. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DispatchReduce::Dispatch( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + OffsetT num_items, + ReductionOpT reduction_op, + InitT init, + cudaStream_t stream, + TransformOpT transform_op = {}, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}, + MaxPolicyT max_policy = {} +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to [`temp_storage_bytes`](/library/api/cub::_dispatch_reduce::temp_storage_bytes) and no work is done. + + + +Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_dispatch_reduce::d_temp_storage) allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +Total number of input items (i.e., length of [`d_in`](/library/api/cub::_dispatch_reduce::d_in)) + + + +Binary reduction functor + + + +The initial value of the reduction + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `d_temp_storage` | `void *` | Device-accessible allocation of temporary storage. | +| `temp_storage_bytes` | `size_t &` | Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_dispatch_reduce::d_temp_storage) allocation. | +| `d_in` | `InputIteratorT` | Pointer to the input sequence of data items. | +| `d_out` | `OutputIteratorT` | Pointer to the output aggregate. | +| `num_items` | `OffsetT` | Total number of input items (i.e., length of [`d_in`](/library/api/cub::_dispatch_reduce::d_in)). | +| `reduction_op` | `ReductionOpT` | Binary reduction functor. | +| `init` | `InitT` | The initial value of the reduction. | +| `stream` | `cudaStream_t` | CUDA stream to launch kernels within. Default is stream0. | +| `ptx_version` | `int` | | +| `transform_op` | `TransformOpT` | | +| `kernel_source` | `KernelSource` | | +| `launcher_factory` | `KernelLauncherFactory` | | diff --git a/fern/cudapages/cub/cub/cub/DispatchReduceByKey.mdx b/fern/cudapages/cub/cub/cub/DispatchReduceByKey.mdx new file mode 100644 index 0000000..118fa7f --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DispatchReduceByKey.mdx @@ -0,0 +1,210 @@ +--- +title: cub::DispatchReduceByKey +description: "Utility class for dispatching the appropriately-tuned kernels for DeviceReduceByKey." +--- + +Utility class for dispatching the appropriately-tuned kernels for DeviceReduceByKey. + + + + + +Random-access input iterator type for keys + + + +Random-access output iterator type for keys + + + +Random-access input iterator type for values + + + +Random-access output iterator type for values + + + +Output iterator type for recording number of segments encountered + + + +KeyT equality operator type + + + +ValueT reduction operator type + + + +Signed integer type for global offsets + + + + + + +Implementation detail, do not specify directly, requirements on the content of this type are subject to breaking change. + + + + + +--- + +## Constructors + +### DispatchReduceByKey inline + + +```cpp showLineNumbers={false} +cub::DispatchReduceByKey::DispatchReduceByKey( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + UniqueOutputIteratorT d_unique_out, + ValuesInputIteratorT d_values_in, + AggregatesOutputIteratorT d_aggregates_out, + NumRunsOutputIteratorT d_num_runs_out, + EqualityOpT equality_op, + ReductionOpT reduction_op, + OffsetT num_items, + cudaStream_t stream +) +``` + + +--- + +## Methods + +### Invoke inline + + + + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchReduceByKey::Invoke( + ScanInitKernelT init_kernel, + ReduceByKeyKernelT reduce_by_key_kernel +) +``` + + + + + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchReduceByKey::Invoke() +``` + + + + + +--- + +## Static methods + +### Dispatch inline static + +Internal dispatch routine. + + +```cpp showLineNumbers={false} +static cudaError_t cub::DispatchReduceByKey::Dispatch( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + UniqueOutputIteratorT d_unique_out, + ValuesInputIteratorT d_values_in, + AggregatesOutputIteratorT d_aggregates_out, + NumRunsOutputIteratorT d_num_runs_out, + EqualityOpT equality_op, + ReductionOpT reduction_op, + OffsetT num_items, + cudaStream_t stream +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to [`temp_storage_bytes`](/library/api/cub::_dispatch_reduce_by_key::temp_storage_bytes) and no work is done. + + + +Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_dispatch_reduce_by_key::d_temp_storage) allocation + + + +Pointer to the input sequence of keys + + + +Pointer to the output sequence of unique keys (one key per run) + + + +Pointer to the input sequence of corresponding values + + + +Pointer to the output sequence of value aggregates (one aggregate per run) + + + +Pointer to total number of runs encountered (i.e., the length of d_unique_out) + + + +KeyT equality operator + + + +ValueT reduction operator + + + +Total number of items to select from + + + +CUDA stream to launch kernels within. Default is stream0. + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `ValueInputT` | `cub::detail::it_value_t< ValuesInputIteratorT >` | +| `streaming_context_t` | `NullType` | +| `ScanTileStateT` | `ReduceByKeyScanTileState< AccumT, OffsetT >` | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `INIT_KERNEL_THREADS` static constexpr | `int` | | +| `d_temp_storage` | `void *` | | +| `temp_storage_bytes` | `size_t &` | | +| `d_keys_in` | `KeysInputIteratorT` | | +| `d_unique_out` | `UniqueOutputIteratorT` | | +| `d_values_in` | `ValuesInputIteratorT` | | +| `d_aggregates_out` | `AggregatesOutputIteratorT` | | +| `d_num_runs_out` | `NumRunsOutputIteratorT` | | +| `equality_op` | `EqualityOpT` | | +| `reduction_op` | `ReductionOpT` | | +| `num_items` | `OffsetT` | | +| `stream` | `cudaStream_t` | | diff --git a/fern/cudapages/cub/cub/cub/DispatchScan.mdx b/fern/cudapages/cub/cub/cub/DispatchScan.mdx new file mode 100644 index 0000000..1ecc522 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DispatchScan.mdx @@ -0,0 +1,241 @@ +--- +title: cub::DispatchScan +description: "Utility class for dispatching the appropriately-tuned kernels for [DeviceScan](/library/api/cub::_device_scan)." +--- + +Utility class for dispatching the appropriately-tuned kernels for [DeviceScan](/library/api/cub::_device_scan). + + + + + +Random-access input iterator type for reading scan inputs (may be a simple pointer type) + + + +Random-access output iterator type for writing scan outputs (may be a simple pointer type) + + + +Binary scan functor type having member `auto operator()(const T &a, const U &b)` + + + +The init_value element type for ScanOpT (cub::NullType for inclusive scans) + + + +Unsigned integer type for global offsets + + + + + + +Enum flag to specify whether to enforce inclusive scan. + + + + + + + + + + + + + + +--- + +## Constructors + +### DispatchScan inline + + +```cpp showLineNumbers={false} +cub::DispatchScan::DispatchScan( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + OffsetT num_items, + ScanOpT scan_op, + InitValueT init_value, + cudaStream_t stream, + int ptx_version, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {} +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to [`temp_storage_bytes`](/library/api/cub::_dispatch_scan::temp_storage_bytes) and no work is done. + + + +Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_dispatch_scan::d_temp_storage) allocation + + + +Iterator to the input sequence of data items + + + +Iterator to the output sequence of data items + + + +Total number of input items (i.e., the length of [`d_in`](/library/api/cub::_dispatch_scan::d_in)) + + + +Binary scan functor + + + +Initial value to seed the exclusive scan + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + + +Object specifying implementation kernels + + + +Object to execute implementation kernels on the given stream + + +--- + +## Methods + +### Invoke inline + + + + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchScan::Invoke( + InitKernelT init_kernel, + ScanKernelT scan_kernel, + ActivePolicyT policy = {} +) +``` + + + + + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchScan::Invoke( + ActivePolicyT active_policy = {} +) +``` + + + + + +--- + +## Static methods + +### Dispatch inline static + +Internal dispatch routine. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DispatchScan::Dispatch( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ScanOpT scan_op, + InitValueT init_value, + OffsetT num_items, + cudaStream_t stream, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}, + MaxPolicyT max_policy = {} +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to [`temp_storage_bytes`](/library/api/cub::_dispatch_scan::temp_storage_bytes) and no work is done. + + + +Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_dispatch_scan::d_temp_storage) allocation + + + +Iterator to the input sequence of data items + + + +Iterator to the output sequence of data items + + + +Binary scan functor + + + +Initial value to seed the exclusive scan + + + +Total number of input items (i.e., the length of [`d_in`](/library/api/cub::_dispatch_scan::d_in)) + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + + +Object specifying implementation kernels + + + +Object to execute implementation kernels on the given stream + + + +Struct encoding chain of algorithm tuning policies + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `INIT_KERNEL_THREADS` static constexpr | `int` | | +| `d_temp_storage` | `void *` | Device-accessible allocation of temporary storage. | +| `temp_storage_bytes` | `size_t &` | Reference to size in bytes of `d_temp_storage` allocation. | +| `d_in` | `InputIteratorT` | Iterator to the input sequence of data items. | +| `d_out` | `OutputIteratorT` | Iterator to the output sequence of data items. | +| `scan_op` | `ScanOpT` | Binary scan functor. | +| `init_value` | `InitValueT` | Initial value to seed the exclusive scan. | +| `num_items` | `OffsetT` | Total number of input items (i.e., the length of `d_in`). | +| `stream` | `cudaStream_t` | CUDA stream to launch kernels within. Default is stream0. | +| `ptx_version` | `int` | | +| `kernel_source` | `KernelSource` | | +| `launcher_factory` | `KernelLauncherFactory` | | diff --git a/fern/cudapages/cub/cub/cub/DispatchScanByKey.mdx b/fern/cudapages/cub/cub/cub/DispatchScanByKey.mdx new file mode 100644 index 0000000..1a9d52b --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DispatchScanByKey.mdx @@ -0,0 +1,242 @@ +--- +title: cub::DispatchScanByKey +description: "Utility class for dispatching the appropriately-tuned kernels for [DeviceScan](/library/api/cub::_device_scan)." +--- + +Utility class for dispatching the appropriately-tuned kernels for [DeviceScan](/library/api/cub::_device_scan). + + + + + +Random-access input iterator type + + + +Random-access input iterator type + + + +Random-access output iterator type + + + +Equality functor type + + + +Scan functor type + + + +The init_value element for ScanOpT type (cub::NullType for inclusive scan) + + + +Unsigned integer type for global offsets + + + + + + + + + + + +--- + +## Constructors + +### DispatchScanByKey inline + + +```cpp showLineNumbers={false} +cub::DispatchScanByKey::DispatchScanByKey( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + ValuesInputIteratorT d_values_in, + ValuesOutputIteratorT d_values_out, + EqualityOp equality_op, + ScanOpT scan_op, + InitValueT init_value, + OffsetT num_items, + cudaStream_t stream, + int ptx_version +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to [`temp_storage_bytes`](/library/api/cub::_dispatch_scan_by_key::temp_storage_bytes) and no work is done. + + + +Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_dispatch_scan_by_key::d_temp_storage) allocation + + + +Iterator to the input sequence of key items + + + +Iterator to the input sequence of value items + + + +Iterator to the input sequence of value items + + + +Binary equality functor + + + +Binary scan functor + + + +Initial value to seed the exclusive scan + + + +Total number of input items (i.e., the length of `d_in`) + + + +CUDA stream to launch kernels within. + + +--- + +## Methods + +### Invoke inline + + + + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchScanByKey::Invoke( + InitKernel init_kernel, + ScanKernel scan_kernel +) +``` + + + + + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchScanByKey::Invoke() +``` + + + + + +--- + +## Static methods + +### Dispatch inline static + +Internal dispatch routine. + + +```cpp showLineNumbers={false} +static cudaError_t cub::DispatchScanByKey::Dispatch( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + ValuesInputIteratorT d_values_in, + ValuesOutputIteratorT d_values_out, + EqualityOp equality_op, + ScanOpT scan_op, + InitValueT init_value, + OffsetT num_items, + cudaStream_t stream +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to [`temp_storage_bytes`](/library/api/cub::_dispatch_scan_by_key::temp_storage_bytes) and no work is done. + + + +Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_dispatch_scan_by_key::d_temp_storage) allocation + + + +Iterator to the input sequence of key items + + + +Iterator to the input sequence of value items + + + +Iterator to the input sequence of value items + + + +Binary equality functor + + + +Binary scan functor + + + +Initial value to seed the exclusive scan + + + +Total number of input items (i.e., the length of `d_in`) + + + +CUDA stream to launch kernels within. + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `KeyT` | `cub::detail::it_value_t< KeysInputIteratorT >` | +| `InputT` | `cub::detail::it_value_t< ValuesInputIteratorT >` | +| `ScanByKeyTileStateT` | `ReduceByKeyScanTileState< AccumT, int >` | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `INIT_KERNEL_THREADS` static constexpr | `int` | | +| `d_temp_storage` | `void *` | Device-accessible allocation of temporary storage. | +| `temp_storage_bytes` | `size_t &` | Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_dispatch_scan_by_key::d_temp_storage) allocation. | +| `d_keys_in` | `KeysInputIteratorT` | Iterator to the input sequence of key items. | +| `d_values_in` | `ValuesInputIteratorT` | Iterator to the input sequence of value items. | +| `d_values_out` | `ValuesOutputIteratorT` | Iterator to the input sequence of value items. | +| `equality_op` | `EqualityOp` | Binary equality functor. | +| `scan_op` | `ScanOpT` | Binary scan functor. | +| `init_value` | `InitValueT` | Initial value to seed the exclusive scan. | +| `num_items` | `OffsetT` | Total number of input items (i.e., the length of `d_in`). | +| `stream` | `cudaStream_t` | CUDA stream to launch kernels within. | +| `ptx_version` | `int` | | diff --git a/fern/cudapages/cub/cub/cub/DispatchSegmentedRadixSort.mdx b/fern/cudapages/cub/cub/cub/DispatchSegmentedRadixSort.mdx new file mode 100644 index 0000000..97b0686 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DispatchSegmentedRadixSort.mdx @@ -0,0 +1,274 @@ +--- +title: cub::DispatchSegmentedRadixSort +description: "Utility class for dispatching the appropriately-tuned kernels for segmented device-wide radix sort." +--- + +Utility class for dispatching the appropriately-tuned kernels for segmented device-wide radix sort. + + + + + + + + +Key type + + + +Value type + + + +Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + + +Integer type to index items within a segment + + + + + + + + + + + + + + + + + +--- + +## Constructors + +### DispatchSegmentedRadixSort inline + +Constructor. + + +```cpp showLineNumbers={false} +cub::DispatchSegmentedRadixSort::DispatchSegmentedRadixSort( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + int begin_bit, + int end_bit, + bool is_overwrite_okay, + cudaStream_t stream, + int ptx_version, + DecomposerT decomposer = {}, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {} +) +``` + + +--- + +## Methods + +### InvokePass inline + +Invoke a three-kernel sorting pass at the current bit. + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchSegmentedRadixSort::InvokePass( + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + int ¤t_bit, + PassConfigT &pass_config +) +``` + + +### InvokePasses inline + +Invocation (run multiple digit passes). + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchSegmentedRadixSort::InvokePasses( + SegmentedKernelT segmented_kernel, + SegmentedKernelT alt_segmented_kernel, + ActivePolicyT policy = {} +) +``` + + +**Template parameters** + + +Umbrella policy active for the target device + + + +Function type of cub::DeviceSegmentedRadixSortKernel + + +**Parameters** + + +Kernel function pointer to parameterization of cub::DeviceSegmentedRadixSortKernel + + + +Alternate kernel function pointer to parameterization of cub::DeviceSegmentedRadixSortKernel + + +### Invoke inline + +Invocation. + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchSegmentedRadixSort::Invoke( + ActivePolicyT policy = {} +) +``` + + +--- + +## Static methods + +### Dispatch inline static + +Internal dispatch routine. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DispatchSegmentedRadixSort::Dispatch( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + ::cuda::std::int64_t num_items, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + int begin_bit, + int end_bit, + bool is_overwrite_okay, + cudaStream_t stream, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}, + MaxPolicyT max_policy = {} +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to [`temp_storage_bytes`](/library/api/cub::_dispatch_segmented_radix_sort::temp_storage_bytes) and no work is done. + + + +Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_dispatch_segmented_radix_sort::d_temp_storage) allocation + + + +Double-buffer whose current buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys + + + +Double-buffer whose current buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values + + + +Number of items to sort + + + +The number of segments that comprise the sorting data + + + +Random-access input iterator to the sequence of beginning offsets of length [`num_segments`](/library/api/cub::_dispatch_segmented_radix_sort::num_segments), such that [`d_begin_offsets`](/library/api/cub::_dispatch_segmented_radix_sort::d_begin_offsets)`[i]` is the first element of the *i*th data segment in `d_keys_*` and `d_values_*` + + + +Random-access input iterator to the sequence of ending offsets of length [`num_segments`](/library/api/cub::_dispatch_segmented_radix_sort::num_segments), such that [`d_end_offsets`](/library/api/cub::_dispatch_segmented_radix_sort::d_end_offsets)`[i]-1` is the last element of the *i*th data segment in `d_keys_*` and `d_values_*`. If [`d_end_offsets`](/library/api/cub::_dispatch_segmented_radix_sort::d_end_offsets)`[i]-1` <= [`d_begin_offsets`](/library/api/cub::_dispatch_segmented_radix_sort::d_begin_offsets)`[i]`, the *i*th is considered empty. + + + +The beginning (least-significant) bit index needed for key comparison + + + +The past-the-end (most-significant) bit index needed for key comparison + + + +Whether is okay to overwrite source buffers + + + +CUDA stream to launch kernels within. Default is stream0. + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `KEYS_ONLY` static constexpr | `bool` | | +| `d_temp_storage` | `void *` | Device-accessible allocation of temporary storage. | +| `temp_storage_bytes` | `size_t &` | Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_dispatch_segmented_radix_sort::d_temp_storage) allocation. | +| `d_keys` | `DoubleBuffer< KeyT > &` | Double-buffer whose current buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys. | +| `d_values` | `DoubleBuffer< ValueT > &` | Double-buffer whose current buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values. | +| `num_items` | `::cuda::std::int64_t` | Number of items to sort. | +| `num_segments` | `::cuda::std::int64_t` | The number of segments that comprise the sorting data. | +| `d_begin_offsets` | `BeginOffsetIteratorT` | Random-access input iterator to the sequence of beginning offsets of length [`num_segments`](/library/api/cub::_dispatch_segmented_radix_sort::num_segments), such that [`d_begin_offsets`](/library/api/cub::_dispatch_segmented_radix_sort::d_begin_offsets)`[i]` is the first element of the *i*th data segment in `d_keys_*` and `d_values_*`. | +| `d_end_offsets` | `EndOffsetIteratorT` | Random-access input iterator to the sequence of ending offsets of length [`num_segments`](/library/api/cub::_dispatch_segmented_radix_sort::num_segments), such that [`d_end_offsets`](/library/api/cub::_dispatch_segmented_radix_sort::d_end_offsets)`[i]-1` is the last element of the *i*th data segment in `d_keys_*` and `d_values_*`. | +| `begin_bit` | `int` | The beginning (least-significant) bit index needed for key comparison. | +| `end_bit` | `int` | The past-the-end (most-significant) bit index needed for key comparison. | +| `stream` | `cudaStream_t` | CUDA stream to launch kernels within. Default is stream0. | +| `ptx_version` | `int` | PTX version. | +| `is_overwrite_okay` | `bool` | Whether is okay to overwrite source buffers. | +| `decomposer` | `DecomposerT` | | +| `kernel_source` | `KernelSource` | | +| `launcher_factory` | `KernelLauncherFactory` | | + +--- + +## Inner classes + +### PassConfig + + +```cpp showLineNumbers={false} +struct cub::DispatchSegmentedRadixSort::PassConfig +``` + + +[PassConfig](/library/api/cub::DispatchSegmentedRadixSort::PassConfig) data structure. + +| Name | Type | Description | +|---|---|---| +| `segmented_kernel` | `SegmentedKernelT` | | +| `segmented_config` | `detail::KernelConfig` | | +| `radix_bits` | `int` | | +| `radix_digits` | `int` | | diff --git a/fern/cudapages/cub/cub/cub/DispatchSegmentedReduce.mdx b/fern/cudapages/cub/cub/cub/DispatchSegmentedReduce.mdx new file mode 100644 index 0000000..e20d207 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DispatchSegmentedReduce.mdx @@ -0,0 +1,218 @@ +--- +title: cub::DispatchSegmentedReduce +description: "Utility class for dispatching the appropriately-tuned kernels for device-wide reduction." +--- + +Utility class for dispatching the appropriately-tuned kernels for device-wide reduction. + + + + + +Random-access input iterator type for reading input items (may be a simple pointer type) + + + +Output iterator type for recording the reduced aggregate (may be a simple pointer type) + + + +Random-access input iterator type for reading segment beginning offsets (may be a simple pointer type) + + + +Random-access input iterator type for reading segment ending offsets (may be a simple pointer type) + + + +Signed integer type for global offsets + + + +Binary reduction functor type having member `auto operator()(const T &a, const U &b)` + + + +Value type + + + + + + + + + + + + + + + + + +--- + +## Constructors + +### DispatchSegmentedReduce inline + +Constructor. + + +```cpp showLineNumbers={false} +cub::DispatchSegmentedReduce::DispatchSegmentedReduce( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + ReductionOpT reduction_op, + InitT init, + cudaStream_t stream, + int ptx_version, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {} +) +``` + + +--- + +## Methods + +### InvokePasses inline + +Invocation. + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchSegmentedReduce::InvokePasses( + DeviceSegmentedReduceKernelT segmented_reduce_kernel, + ActivePolicyT policy = {} +) +``` + + +**Template parameters** + + +Umbrella policy active for the target device + + + +Function type of cub::DeviceSegmentedReduceKernel + + +**Parameters** + + +Kernel function pointer to instantiation of cub::DeviceSegmentedReduceKernel + + +### Invoke inline + +Invocation. + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchSegmentedReduce::Invoke( + ActivePolicyT policy = {} +) +``` + + +--- + +## Static methods + +### Dispatch inline static + +Internal dispatch routine for computing a device-wide reduction. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DispatchSegmentedReduce::Dispatch( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ::cuda::std::int64_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + ReductionOpT reduction_op, + InitT init, + cudaStream_t stream, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}, + MaxPolicyT max_policy = {} +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to [`temp_storage_bytes`](/library/api/cub::_dispatch_segmented_reduce::temp_storage_bytes) and no work is done. + + + +Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_dispatch_segmented_reduce::d_temp_storage) allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the output aggregate + + + +The number of segments that comprise the sorting data + + + +Random-access input iterator to the sequence of beginning offsets of length [`num_segments`](/library/api/cub::_dispatch_segmented_reduce::num_segments), such that [`d_begin_offsets`](/library/api/cub::_dispatch_segmented_reduce::d_begin_offsets)`[i]` is the first element of the *i*th data segment in `d_keys_*` and `d_values_*` + + + +Random-access input iterator to the sequence of ending offsets of length [`num_segments`](/library/api/cub::_dispatch_segmented_reduce::num_segments), such that [`d_end_offsets`](/library/api/cub::_dispatch_segmented_reduce::d_end_offsets)`[i] - 1` is the last element of the *i*th data segment in `d_keys_*` and `d_values_*`. If [`d_end_offsets`](/library/api/cub::_dispatch_segmented_reduce::d_end_offsets)`[i] - 1 <= `[`d_begin_offsets`](/library/api/cub::_dispatch_segmented_reduce::d_begin_offsets)`[i]`, the *i*th is considered empty. + + + +Binary reduction functor + + + +The initial value of the reduction + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `d_temp_storage` | `void *` | Device-accessible allocation of temporary storage. | +| `temp_storage_bytes` | `size_t &` | Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_dispatch_segmented_reduce::d_temp_storage) allocation. | +| `d_in` | `InputIteratorT` | Pointer to the input sequence of data items. | +| `d_out` | `OutputIteratorT` | Pointer to the output aggregate. | +| `num_segments` | `::cuda::std::int64_t` | The number of segments that comprise the segmented reduction data. | +| `d_begin_offsets` | `BeginOffsetIteratorT` | Random-access input iterator to the sequence of beginning offsets of length [`num_segments`](/library/api/cub::_dispatch_segmented_reduce::num_segments), such that [`d_begin_offsets`](/library/api/cub::_dispatch_segmented_reduce::d_begin_offsets)`[i]` is the first element of the *i*th data segment in `d_keys_*` and `d_values_*`. | +| `d_end_offsets` | `EndOffsetIteratorT` | Random-access input iterator to the sequence of ending offsets of length [`num_segments`](/library/api/cub::_dispatch_segmented_reduce::num_segments), such that [`d_end_offsets`](/library/api/cub::_dispatch_segmented_reduce::d_end_offsets)`[i] - 1` is the last element of the *i*th data segment in `d_keys_*` and `d_values_*`. | +| `reduction_op` | `ReductionOpT` | Binary reduction functor. | +| `init` | `InitT` | The initial value of the reduction. | +| `stream` | `cudaStream_t` | CUDA stream to launch kernels within. Default is stream0. | +| `ptx_version` | `int` | | +| `kernel_source` | `KernelSource` | | +| `launcher_factory` | `KernelLauncherFactory` | | diff --git a/fern/cudapages/cub/cub/cub/DispatchSegmentedSort.mdx b/fern/cudapages/cub/cub/cub/DispatchSegmentedSort.mdx new file mode 100644 index 0000000..b38c67c --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DispatchSegmentedSort.mdx @@ -0,0 +1,190 @@ +--- +title: cub::DispatchSegmentedSort +description: "" +--- + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +--- + +## Methods + +### Invoke inline + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchSegmentedSort::Invoke( + ActivePolicyT policy = {} +) +``` + + +### GetNumPasses inline + + +```cpp showLineNumbers={false} +int cub::DispatchSegmentedSort::GetNumPasses( + int radix_bits +) +``` + + +### GetFinalSelector inline + + +```cpp showLineNumbers={false} +int cub::DispatchSegmentedSort::GetFinalSelector( + int selector, + int radix_bits +) +``` + + +### GetFinalOutput inline + + +```cpp showLineNumbers={false} +template +T * cub::DispatchSegmentedSort::GetFinalOutput( + int radix_bits, + DoubleBuffer &buffer +) +``` + + +### SortWithPartitioning inline + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchSegmentedSort::SortWithPartitioning( + LargeKernelT large_kernel, + SmallKernelT small_kernel, + size_t three_way_partition_temp_storage_bytes, + cub::detail::device_double_buffer &d_keys_double_buffer, + cub::detail::device_double_buffer &d_values_double_buffer, + typename KernelSource::LargeSegmentsSelectorT &large_segments_selector, + typename KernelSource::SmallSegmentsSelectorT &small_segments_selector, + cub::detail::temporary_storage::alias &device_partition_temp_storage, + cub::detail::temporary_storage::alias &large_and_medium_segments_indices, + cub::detail::temporary_storage::alias &small_segments_indices, + cub::detail::temporary_storage::alias &group_sizes, + WrappedPolicyT wrapped_policy +) +``` + + +### SortWithoutPartitioning inline + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchSegmentedSort::SortWithoutPartitioning( + FallbackKernelT fallback_kernel, + cub::detail::device_double_buffer &d_keys_double_buffer, + cub::detail::device_double_buffer &d_values_double_buffer, + WrappedPolicyT wrapped_policy +) +``` + + +--- + +## Static methods + +### Dispatch inline static + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DispatchSegmentedSort::Dispatch( + void *d_temp_storage, + size_t &temp_storage_bytes, + DoubleBuffer &d_keys, + DoubleBuffer &d_values, + ::cuda::std::int64_t num_items, + global_segment_offset_t num_segments, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + bool is_overwrite_okay, + cudaStream_t stream, + KernelSource kernel_source = {}, + PartitionKernelSource partition_kernel_source = {}, + KernelLauncherFactory launcher_factory = {}, + MaxPolicyT max_policy = {}, + PartitionMaxPolicyT partition_max_policy = {} +) +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `local_segment_index_t` | `detail::segmented_sort::local_segment_index_t` | +| `global_segment_offset_t` | `detail::segmented_sort::global_segment_offset_t` | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `KEYS_ONLY` static constexpr | `int` | | +| `num_selected_groups` static constexpr | `size_t` | | +| `d_temp_storage` | `void *` | Device-accessible allocation of temporary storage. | +| `temp_storage_bytes` | `size_t &` | Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_dispatch_segmented_sort::d_temp_storage) allocation. | +| `d_keys` | `DoubleBuffer< KeyT > &` | Double-buffer whose current buffer contains the unsorted input keys and, upon return, is updated to point to the sorted output keys. | +| `d_values` | `DoubleBuffer< ValueT > &` | Double-buffer whose current buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values. | +| `num_items` | `::cuda::std::int64_t` | Number of items to sort. | +| `num_segments` | `global_segment_offset_t` | The number of segments that comprise the sorting data. | +| `d_begin_offsets` | `BeginOffsetIteratorT` | Random-access input iterator to the sequence of beginning offsets of length [`num_segments`](/library/api/cub::_dispatch_segmented_sort::num_segments), such that [`d_begin_offsets`](/library/api/cub::_dispatch_segmented_sort::d_begin_offsets)`[i]` is the first element of the *i*th data segment in `d_keys_*` and `d_values_*`. | +| `d_end_offsets` | `EndOffsetIteratorT` | Random-access input iterator to the sequence of ending offsets of length [`num_segments`](/library/api/cub::_dispatch_segmented_sort::num_segments), such that [`d_end_offsets`](/library/api/cub::_dispatch_segmented_sort::d_end_offsets)`[i]-1` is the last element of the *i*th data segment in `d_keys_*` and `d_values_*`. | +| `is_overwrite_okay` | `bool` | Whether is okay to overwrite source buffers. | +| `stream` | `cudaStream_t` | CUDA stream to launch kernels within. | +| `kernel_source` | `KernelSource` | | +| `partition_kernel_source` | `PartitionKernelSource` | | +| `launcher_factory` | `KernelLauncherFactory` | | +| `partition_max_policy` | `PartitionPolicyHub::MaxPolicy` | | diff --git a/fern/cudapages/cub/cub/cub/DispatchSelectIf.mdx b/fern/cudapages/cub/cub/cub/DispatchSelectIf.mdx new file mode 100644 index 0000000..d27187d --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DispatchSelectIf.mdx @@ -0,0 +1,248 @@ +--- +title: cub::DispatchSelectIf +description: "Utility class for dispatching the appropriately-tuned kernels for [DeviceSelect](/library/api/cub::_device_select) and [DevicePartition](/library/api/cub::_device_partition)." +--- + +Utility class for dispatching the appropriately-tuned kernels for [DeviceSelect](/library/api/cub::_device_select) and [DevicePartition](/library/api/cub::_device_partition). + + + + + +Random-access input iterator type for reading input items + + + +Random-access input iterator type for reading selection flags (NullType* if a selection functor or discontinuity flagging is used for selection) + + + +Random-access output iterator type for writing selected items + + + +Output iterator type for recording the number of items selected + + + +Selection operator type (NullType if selection flags or discontinuity flagging is used for selection) + + + +Equality operator type (NullType if selection functor or selection flags are used for selection) + + + +Signed integer type for global offsets + + + +[SelectImpl](/library/api/cub::SelectImpl) indicating whether to partition, just selection or selection where the memory for the input and output may alias each other. + + + + + + + + +--- + +## Constructors + +### DispatchSelectIf inline + + +```cpp showLineNumbers={false} +cub::DispatchSelectIf::DispatchSelectIf( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + FlagsInputIteratorT d_flags, + SelectedOutputIteratorT d_selected_out, + NumSelectedIteratorT d_num_selected_out, + SelectOpT select_op, + EqualityOpT equality_op, + OffsetT num_items, + cudaStream_t stream, + int ptx_version +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to [`temp_storage_bytes`](/library/api/cub::_dispatch_select_if::temp_storage_bytes) and no work is done. + + + +Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_dispatch_select_if::d_temp_storage) allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the input sequence of selection flags (if applicable) + + + +Pointer to the output sequence of selected data items + + + +Pointer to the total number of items selected (i.e., length of [`d_selected_out`](/library/api/cub::_dispatch_select_if::d_selected_out)) + + + +Selection operator + + + +Equality operator + + + +Total number of input items (i.e., length of [`d_in`](/library/api/cub::_dispatch_select_if::d_in)) + + + +CUDA stream to launch kernels within. Default is stream0. + + +--- + +## Methods + +### Invoke inline + + + + +Internal dispatch routine for computing a device-wide selection using the specified kernel functions. + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchSelectIf::Invoke( + ScanInitKernelPtrT scan_init_kernel, + SelectIfKernelPtrT select_if_kernel +) +``` + + + + + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchSelectIf::Invoke() +``` + + + + + +--- + +## Static methods + +### Dispatch inline static + +Internal dispatch routine. + + +```cpp showLineNumbers={false} +static cudaError_t cub::DispatchSelectIf::Dispatch( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + FlagsInputIteratorT d_flags, + SelectedOutputIteratorT d_selected_out, + NumSelectedIteratorT d_num_selected_out, + SelectOpT select_op, + EqualityOpT equality_op, + OffsetT num_items, + cudaStream_t stream +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When `nullptr`, the required allocation size is written to [`temp_storage_bytes`](/library/api/cub::_dispatch_select_if::temp_storage_bytes) and no work is done. + + + +Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_dispatch_select_if::d_temp_storage) allocation + + + +Pointer to the input sequence of data items + + + +Pointer to the input sequence of selection flags (if applicable) + + + +Pointer to the output sequence of selected data items + + + +Pointer to the total number of items selected (i.e., length of [`d_selected_out`](/library/api/cub::_dispatch_select_if::d_selected_out)) + + + +Selection operator + + + +Equality operator + + + +Total number of input items (i.e., length of [`d_in`](/library/api/cub::_dispatch_select_if::d_in)) + + + +CUDA stream to launch kernels within. Default is stream0. + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `per_partition_offset_t` | `detail::select::per_partition_offset_t` | +| `num_total_items_t` | `OffsetT` | +| `streaming_context_t` | `detail::select::streaming_context_t< num_total_items_t, use_streaming_context >` | +| `ScanTileStateT` | `ScanTileState< per_partition_offset_t >` | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `is_partitioning_invocation` static constexpr | `bool` | | +| `use_streaming_context` static constexpr | `bool` | | +| `INIT_KERNEL_THREADS` static constexpr | `int` | | +| `d_temp_storage` | `void *` | Device-accessible allocation of temporary storage. | +| `temp_storage_bytes` | `size_t &` | Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_dispatch_select_if::d_temp_storage) allocation. | +| `d_in` | `InputIteratorT` | Pointer to the input sequence of data items. | +| `d_flags` | `FlagsInputIteratorT` | Pointer to the input sequence of selection flags (if applicable). | +| `d_selected_out` | `SelectedOutputIteratorT` | Pointer to the output sequence of selected data items. | +| `d_num_selected_out` | `NumSelectedIteratorT` | Pointer to the total number of items selected (i.e., length of [`d_selected_out`](/library/api/cub::_dispatch_select_if::d_selected_out)). | +| `select_op` | `SelectOpT` | Selection operator. | +| `equality_op` | `EqualityOpT` | Equality operator. | +| `num_items` | `OffsetT` | Total number of input items (i.e., length of [`d_in`](/library/api/cub::_dispatch_select_if::d_in)). | +| `stream` | `cudaStream_t` | CUDA stream to launch kernels within. Default is stream0. | +| `ptx_version` | `int` | | diff --git a/fern/cudapages/cub/cub/cub/DispatchThreeWayPartitionIf.mdx b/fern/cudapages/cub/cub/cub/DispatchThreeWayPartitionIf.mdx new file mode 100644 index 0000000..ca12703 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DispatchThreeWayPartitionIf.mdx @@ -0,0 +1,142 @@ +--- +title: cub::DispatchThreeWayPartitionIf +description: "" +--- + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +--- + +## Methods + +### Invoke inline + + + + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchThreeWayPartitionIf::Invoke( + ActivePolicyT policy, + ScanInitKernelPtrT three_way_partition_init_kernel, + SelectIfKernelPtrT three_way_partition_kernel +) +``` + + + + + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchThreeWayPartitionIf::Invoke( + ActivePolicyT active_policy = {} +) +``` + + + + + +--- + +## Static methods + +### Dispatch inline static + +Internal dispatch routine. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DispatchThreeWayPartitionIf::Dispatch( + void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + FirstOutputIteratorT d_first_part_out, + SecondOutputIteratorT d_second_part_out, + UnselectedOutputIteratorT d_unselected_out, + NumSelectedIteratorT d_num_selected_out, + SelectFirstPartOp select_first_part_op, + SelectSecondPartOp select_second_part_op, + OffsetT num_items, + cudaStream_t stream, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}, + MaxPolicyT max_policy = {} +) +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `per_partition_offset_t` | `detail::three_way_partition::per_partition_offset_t` | +| `streaming_context_t` | `detail::three_way_partition::streaming_context_t< OffsetT >` | +| `ScanTileStateT` | `detail::three_way_partition::ScanTileStateT` | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `partition_size` static constexpr | `per_partition_offset_t` | | +| `INIT_KERNEL_THREADS` static constexpr | `int` | | +| `d_temp_storage` | `void *` | | +| `temp_storage_bytes` | `size_t &` | | +| `d_in` | `InputIteratorT` | | +| `d_first_part_out` | `FirstOutputIteratorT` | | +| `d_second_part_out` | `SecondOutputIteratorT` | | +| `d_unselected_out` | `UnselectedOutputIteratorT` | | +| `d_num_selected_out` | `NumSelectedIteratorT` | | +| `select_first_part_op` | `SelectFirstPartOp` | | +| `select_second_part_op` | `SelectSecondPartOp` | | +| `num_items` | `OffsetT` | | +| `stream` | `cudaStream_t` | | +| `kernel_source` | `KernelSource` | | +| `launcher_factory` | `KernelLauncherFactory` | | diff --git a/fern/cudapages/cub/cub/cub/DispatchUniqueByKey.mdx b/fern/cudapages/cub/cub/cub/DispatchUniqueByKey.mdx new file mode 100644 index 0000000..b261138 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/DispatchUniqueByKey.mdx @@ -0,0 +1,247 @@ +--- +title: cub::DispatchUniqueByKey +description: "Utility class for dispatching the appropriately-tuned kernels for [DeviceSelect](/library/api/cub::_device_select)." +--- + +Utility class for dispatching the appropriately-tuned kernels for [DeviceSelect](/library/api/cub::_device_select). + + + + + +Random-access input iterator type for keys + + + +Random-access input iterator type for values + + + +Random-access output iterator type for keys + + + +Random-access output iterator type for values + + + +Output iterator type for recording the number of items selected + + + +Equality operator type + + + +Signed integer type for global offsets + + + + + + + + + + + + + + + + + + + + + + + +--- + +## Constructors + +### DispatchUniqueByKey inline + + +```cpp showLineNumbers={false} +cub::DispatchUniqueByKey::DispatchUniqueByKey( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyInputIteratorT d_keys_in, + ValueInputIteratorT d_values_in, + KeyOutputIteratorT d_keys_out, + ValueOutputIteratorT d_values_out, + NumSelectedIteratorT d_num_selected_out, + EqualityOpT equality_op, + OffsetT num_items, + cudaStream_t stream, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {} +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to [`temp_storage_bytes`](/library/api/cub::_dispatch_unique_by_key::temp_storage_bytes) and no work is done. + + + +Pointer to the input sequence of keys + + + +Pointer to the input sequence of values + + + +Pointer to the output sequence of selected data items + + + +Pointer to the output sequence of selected data items + + + +Pointer to the total number of items selected (i.e., length of `d_keys_out` or `d_values_out`) + + + +Equality operator + + + +Total number of input items (i.e., length of `d_keys_in` or `d_values_in`) + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + +--- + +## Methods + +### Invoke inline + + + + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchUniqueByKey::Invoke( + InitKernelT init_kernel, + UniqueByKeySweepKernelT sweep_kernel, + ActivePolicyT policy = {} +) +``` + + + + + + +```cpp showLineNumbers={false} +template +cudaError_t cub::DispatchUniqueByKey::Invoke( + ActivePolicyT active_policy = {} +) +``` + + + + + +--- + +## Static methods + +### Dispatch inline static + +Internal dispatch routine. + + +```cpp showLineNumbers={false} +template +static cudaError_t cub::DispatchUniqueByKey::Dispatch( + void *d_temp_storage, + size_t &temp_storage_bytes, + KeyInputIteratorT d_keys_in, + ValueInputIteratorT d_values_in, + KeyOutputIteratorT d_keys_out, + ValueOutputIteratorT d_values_out, + NumSelectedIteratorT d_num_selected_out, + EqualityOpT equality_op, + OffsetT num_items, + cudaStream_t stream, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}, + MaxPolicyT max_policy = {} +) +``` + + +**Parameters** + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to [`temp_storage_bytes`](/library/api/cub::_dispatch_unique_by_key::temp_storage_bytes) and no work is done. + + + +Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_dispatch_unique_by_key::d_temp_storage) allocation + + + +Pointer to the input sequence of keys + + + +Pointer to the input sequence of values + + + +Pointer to the output sequence of selected data items + + + +Pointer to the output sequence of selected data items + + + +Pointer to the total number of items selected (i.e., length of `d_keys_out` or `d_values_out`) + + + +Equality operator + + + +Total number of input items (i.e., the length of `d_in`) + + + +**[optional]** CUDA stream to launch kernels within. Default is stream0. + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `INIT_KERNEL_THREADS` static constexpr | `int` | | +| `d_temp_storage` | `void *` | Device-accessible allocation of temporary storage. | +| `temp_storage_bytes` | `size_t &` | Reference to size in bytes of [`d_temp_storage`](/library/api/cub::_dispatch_unique_by_key::d_temp_storage) allocation. | +| `d_keys_in` | `KeyInputIteratorT` | Pointer to the input sequence of keys. | +| `d_values_in` | `ValueInputIteratorT` | Pointer to the input sequence of values. | +| `d_keys_out` | `KeyOutputIteratorT` | Pointer to the output sequence of selected data items. | +| `d_values_out` | `ValueOutputIteratorT` | Pointer to the output sequence of selected data items. | +| `d_num_selected_out` | `NumSelectedIteratorT` | Pointer to the total number of items selected (i.e., length of `d_keys_out` or `d_values_out`). | +| `equality_op` | `EqualityOpT` | Equality operator. | +| `num_items` | `OffsetT` | Total number of input items (i.e., length of `d_keys_in` or `d_values_in`). | +| `stream` | `cudaStream_t` | **[optional]** CUDA stream to launch kernels within. Default is stream0. | +| `kernel_source` | `KernelSource` | | +| `launcher_factory` | `KernelLauncherFactory` | | diff --git a/fern/cudapages/cub/cub/cub/GridEvenShare.mdx b/fern/cudapages/cub/cub/cub/GridEvenShare.mdx new file mode 100644 index 0000000..00ba382 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/GridEvenShare.mdx @@ -0,0 +1,164 @@ +--- +title: cub::GridEvenShare +description: "[GridEvenShare](/library/api/cub::_grid_even_share) is a descriptor utility for distributing input among CUDA thread blocks in an \"even-share\" fashion." +--- + +`GridEvenShare` is a descriptor utility for distributing input among CUDA thread blocks in an "even-share" fashion. + +Each thread block gets roughly the same number of input tiles. + +**Overview** + +Each thread block is assigned a consecutive sequence of input tiles. To help preserve alignment and eliminate the overhead of guarded loads for all but the last thread block, to `GridEvenShare` assigns one of three different amounts of work to a given thread block: "big", "normal", or "last". The "big" workloads are one scheduling grain larger than "normal". The "last" work unit for the last thread block may be partially-full if the input is not an even multiple of the scheduling grain size. + +Before invoking a child grid, a parent thread will typically construct an instance of `GridEvenShare`. The instance can be passed to child thread blocks which can initialize their per-thread block offsets using [`BlockInit()`](/library/api/cub::_grid_even_share::BlockInit()). + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + + + + + + + +--- + +## Constructors + +### GridEvenShare inline + +Constructor. + + +```cpp showLineNumbers={false} +cub::GridEvenShare::GridEvenShare() +``` + + +--- + +## Methods + +### DispatchInit inline + +Dispatch initializer. + +To be called prior to kernel launch. + + +```cpp showLineNumbers={false} +void cub::GridEvenShare::DispatchInit( + OffsetT num_items_, + int max_grid_size, + int tile_items +) +``` + + +**Parameters** + + +Total number of input items + + + +Maximum grid size allowable (actual grid size may be less if not warranted by the the number of input items) + + + +Number of data items per input tile + + +### BlockInit inline + + + + +Initializes ranges for the specified thread block index. + +Specialized for a "raking" access pattern in which each thread block is assigned a consecutive sequence of input tiles. + + +```cpp showLineNumbers={false} +template +void cub::GridEvenShare::BlockInit( + int block_id, + detail::constant_t +) +``` + + + + + +Block-initialization, specialized for a "raking" access pattern in which each thread block is assigned a consecutive sequence of input tiles. + + +```cpp showLineNumbers={false} +template +void cub::GridEvenShare::BlockInit( + int block_id, + detail::constant_t +) +``` + + + + + +Block-initialization, specialized for "strip mining" access pattern in which the input tiles assigned to each thread block are separated by a stride equal to the the extent of the grid. + + +```cpp showLineNumbers={false} +template +void cub::GridEvenShare::BlockInit() +``` + + + + + +Block-initialization, specialized for a "raking" access pattern in which each thread block is assigned a consecutive sequence of input tiles. + + +```cpp showLineNumbers={false} +template +void cub::GridEvenShare::BlockInit( + OffsetT1 block_offset, + OffsetT1 block_end +) +``` + + +**Parameters** + + +Threadblock begin offset (inclusive) + + + +Threadblock end offset (exclusive) + + + + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `total_tiles` | `int` | | +| `big_shares` | `int` | | +| `big_share_items` | `OffsetT` | | +| `normal_share_items` | `OffsetT` | | +| `normal_base_offset` | `OffsetT` | | +| `num_items` | `OffsetT` | Total number of input items. | +| `grid_size` | `int` | Grid size in thread blocks. | +| `block_offset` | `OffsetT` | OffsetT into input marking the beginning of the owning thread block's segment of input tiles. | +| `block_end` | `OffsetT` | OffsetT into input of marking the end (one-past) of the owning thread block's segment of input tiles. | +| `block_stride` | `OffsetT` | Stride between input tiles. | diff --git a/fern/cudapages/cub/cub/cub/GridQueue.mdx b/fern/cudapages/cub/cub/cub/GridQueue.mdx new file mode 100644 index 0000000..4255986 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/GridQueue.mdx @@ -0,0 +1,179 @@ +--- +title: cub::GridQueue +description: "[GridQueue](/library/api/cub::_grid_queue) is a descriptor utility for dynamic queue management." +--- + +`GridQueue` is a descriptor utility for dynamic queue management. + +**Overview** + +`GridQueue` descriptors provides abstractions for "filling" or "draining" globally-shared vectors. + +A "filling" `GridQueue` works by atomically-adding to a zero-initialized counter, returning a unique offset for the calling thread to write its items. The `GridQueue` maintains the total "fill-size". The fill counter must be reset using [GridQueue::ResetFill](/library/api/cub::_grid_queue::ResetFill) by the host or kernel instance prior to the kernel instance that will be filling. + +Similarly, a "draining" `GridQueue` works by atomically-incrementing a zero-initialized counter, returning a unique offset for the calling thread to read its items. Threads can safely drain until the array's logical fill-size is exceeded. The drain counter must be reset using [GridQueue::ResetDrain](/library/api/cub::_grid_queue::ResetDrain) or [GridQueue::FillAndResetDrain](/library/api/cub::_grid_queue::FillAndResetDrain) by the host or kernel instance prior to the kernel instance that will be filling. (For dynamic work distribution of existing data, the corresponding fill-size is simply the number of elements in the array.) + +Iterative work management can be implemented simply with a pair of flip-flopping work buffers, each with an associated set of fill and drain `GridQueue` descriptors. + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + + + + +Signed integer type for global offsets + + + + + +--- + +## Constructors + +### GridQueue inline + + + + +Constructs an invalid `GridQueue` descriptor. + + +```cpp showLineNumbers={false} +cub::GridQueue::GridQueue() +``` + + + + + +Constructs a `GridQueue` descriptor around the device storage allocation. + + +```cpp showLineNumbers={false} +cub::GridQueue::GridQueue( + void *d_storage +) +``` + + +**Parameters** + + +Device allocation to back the `GridQueue`. Must be at least as big as [`AllocationSize()`](/library/api/cub::_grid_queue::AllocationSize()). + + + + + +--- + +## Methods + +### FillAndResetDrain inline + +This operation sets the fill-size and resets the drain counter, preparing the `GridQueue` for draining in the next kernel instance. + +To be called by the host or by a kernel prior to the one which will be draining. + + +```cpp showLineNumbers={false} +cudaError_t cub::GridQueue::FillAndResetDrain( + OffsetT fill_size, + cudaStream_t stream = 0 +) +``` + + +### ResetDrain inline + +This operation resets the drain so that it may advance to meet the existing fill-size. + +To be called by the host or by a kernel prior to the one which will be draining. + + +```cpp showLineNumbers={false} +cudaError_t cub::GridQueue::ResetDrain( + cudaStream_t stream = 0 +) +``` + + +### ResetFill inline + +This operation resets the fill counter. + +To be called by the host or by a kernel prior to the one which will be filling. + + +```cpp showLineNumbers={false} +cudaError_t cub::GridQueue::ResetFill( + cudaStream_t stream = 0 +) +``` + + +### FillSize inline + +Returns the fill-size established by the parent or by the previous kernel. + + +```cpp showLineNumbers={false} +cudaError_t cub::GridQueue::FillSize( + OffsetT &fill_size, + cudaStream_t stream = 0 +) +``` + + +### Drain inline + +Drain `num_items` from the queue. + +Returns offset from which to read items. To be called from CUDA kernel. + + +```cpp showLineNumbers={false} +OffsetT cub::GridQueue::Drain( + OffsetT num_items +) +``` + + +### Fill inline + +Fill `num_items` into the queue. + +Returns offset from which to write items. To be called from CUDA kernel. + + +```cpp showLineNumbers={false} +OffsetT cub::GridQueue::Fill( + OffsetT num_items +) +``` + + +--- + +## Static methods + +### AllocationSize inline static + +Returns the device allocation size in bytes needed to construct a `GridQueue` instance. + + +```cpp showLineNumbers={false} +static size_t cub::GridQueue::AllocationSize() +``` + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `FILL` static constexpr | `int` | Counter indices. | +| `DRAIN` static constexpr | `int` | | +| `d_counters` | `OffsetT *` | Pair of counters. | diff --git a/fern/cudapages/cub/cub/cub/InequalityWrapper.mdx b/fern/cudapages/cub/cub/cub/InequalityWrapper.mdx new file mode 100644 index 0000000..f4e4c6f --- /dev/null +++ b/fern/cudapages/cub/cub/cub/InequalityWrapper.mdx @@ -0,0 +1,57 @@ +--- +title: cub::InequalityWrapper +description: "Inequality functor (wraps equality functor)." +--- + +Inequality functor (wraps equality functor). + + + + + + + + + + +--- + +## Constructors + +### InequalityWrapper inline + +Constructor. + + +```cpp showLineNumbers={false} +cub::InequalityWrapper::InequalityWrapper( + EqualityOp op +) +``` + + +--- + +## Methods + +### operator() inline + +Boolean inequality operator, returns `t != u`. + + +```cpp showLineNumbers={false} +template +bool cub::InequalityWrapper::operator()( + T &&t, + U &&u +) +``` + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `op` | `EqualityOp` | Wrapped equality operator. | diff --git a/fern/cudapages/cub/cub/cub/PtxVersionCacheTag.mdx b/fern/cudapages/cub/cub/cub/PtxVersionCacheTag.mdx new file mode 100644 index 0000000..a120828 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/PtxVersionCacheTag.mdx @@ -0,0 +1,4 @@ +--- +title: cub::PtxVersionCacheTag +description: "" +--- diff --git a/fern/cudapages/cub/cub/cub/RadixSortTwiddle.mdx b/fern/cudapages/cub/cub/cub/RadixSortTwiddle.mdx new file mode 100644 index 0000000..acd6b7c --- /dev/null +++ b/fern/cudapages/cub/cub/cub/RadixSortTwiddle.mdx @@ -0,0 +1,70 @@ +--- +title: cub::RadixSortTwiddle +description: "Twiddling keys for radix sort." +--- + +Twiddling keys for radix sort. + + + + + + + + + + + + + +--- + +## Static methods + +### In inline static + + +```cpp showLineNumbers={false} +template +static bit_ordered_type cub::RadixSortTwiddle::In( + bit_ordered_type key, + DecomposerT decomposer = {} +) +``` + + +### Out inline static + + +```cpp showLineNumbers={false} +template +static bit_ordered_type cub::RadixSortTwiddle::Out( + bit_ordered_type key, + DecomposerT decomposer = {} +) +``` + + +### DefaultKey inline static + + +```cpp showLineNumbers={false} +template +static bit_ordered_type cub::RadixSortTwiddle::DefaultKey( + DecomposerT decomposer = {} +) +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `traits` | `detail::radix::traits_t< KeyT >` | +| `bit_ordered_type` | `typename traits::bit_ordered_type` | +| `bit_ordered_conversion_policy` | `typename traits::bit_ordered_conversion_policy` | +| `bit_ordered_inversion_policy` | `typename traits::bit_ordered_inversion_policy` | diff --git a/fern/cudapages/cub/cub/cub/ReduceByKeyOp.mdx b/fern/cudapages/cub/cub/cub/ReduceByKeyOp.mdx new file mode 100644 index 0000000..85c6bce --- /dev/null +++ b/fern/cudapages/cub/cub/cub/ReduceByKeyOp.mdx @@ -0,0 +1,83 @@ +--- +title: cub::ReduceByKeyOp +description: "" +--- + + + + + +Binary reduction operator to apply to values + + + + + +--- + +## Constructors + +### ReduceByKeyOp inline + + + + +Constructor. + + +```cpp showLineNumbers={false} +cub::ReduceByKeyOp::ReduceByKeyOp() +``` + + + + + +Constructor. + + +```cpp showLineNumbers={false} +cub::ReduceByKeyOp::ReduceByKeyOp( + ReductionOpT op +) +``` + + + + + +--- + +## Methods + +### operator() inline + +Scan operator. + + +```cpp showLineNumbers={false} +template +KeyValuePairT cub::ReduceByKeyOp::operator()( + const KeyValuePairT &first, + const KeyValuePairT &second +) +``` + + +**Parameters** + + +First partial reduction + + + +Second partial reduction + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `op` | `ReductionOpT` | Wrapped reduction operator. | diff --git a/fern/cudapages/cub/cub/cub/ReduceByKeyScanTileState.mdx b/fern/cudapages/cub/cub/cub/ReduceByKeyScanTileState.mdx new file mode 100644 index 0000000..02c057b --- /dev/null +++ b/fern/cudapages/cub/cub/cub/ReduceByKeyScanTileState.mdx @@ -0,0 +1,21 @@ +--- +title: cub::ReduceByKeyScanTileState +description: "Tile status interface for reduction by key." +--- + +Tile status interface for reduction by key. + + + + + + + + + + + + + + + diff --git a/fern/cudapages/cub/cub/cub/ReduceByKeyScanTileState_ValueT_KeyT_false.mdx b/fern/cudapages/cub/cub/cub/ReduceByKeyScanTileState_ValueT_KeyT_false.mdx new file mode 100644 index 0000000..c83313c --- /dev/null +++ b/fern/cudapages/cub/cub/cub/ReduceByKeyScanTileState_ValueT_KeyT_false.mdx @@ -0,0 +1,44 @@ +--- +title: "cub::ReduceByKeyScanTileState< ValueT, KeyT, false >" +description: "Tile status interface for reduction by key, specialized for scan status and value types that cannot be combined into one machine word." +--- + +Tile status interface for reduction by key, specialized for scan status and value types that cannot be combined into one machine word. + + + + + + + + + + + + + +**Inherits from:** `cub::ScanTileState< KeyValuePair< KeyT, ValueT > >` (public) + +--- + +## Methods + +### ReduceByKeyScanTileState inline + +Constructor. + +": "/library/api/cub::ReduceByKeyScanTileState%3C ValueT, KeyT, false %3E"}}> +```cpp showLineNumbers={false} +cub::ReduceByKeyScanTileState::ReduceByKeyScanTileState() +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `SuperClass` | `ScanTileState< KeyValuePair< KeyT, ValueT > >` | diff --git a/fern/cudapages/cub/cub/cub/ReduceBySegmentOp.mdx b/fern/cudapages/cub/cub/cub/ReduceBySegmentOp.mdx new file mode 100644 index 0000000..9fa4308 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/ReduceBySegmentOp.mdx @@ -0,0 +1,95 @@ +--- +title: cub::ReduceBySegmentOp +description: "Reduce-by-segment functor." +--- + +Reduce-by-segment functor. + +Given two cub::KeyValuePair inputs `a` and `b` and a binary associative combining operator `f(const T &x, const T &y)`, an instance of this functor returns a cub::KeyValuePair whose `key` field is `a.key + b.key`, and whose `value` field is either `b.value` if `b.key` is non-zero, or `f(a.value, b.value)` otherwise. + +`ReduceBySegmentOp` is an associative, non-commutative binary combining operator for input sequences of cub::KeyValuePair pairings. Such sequences are typically used to represent a segmented set of values to be reduced and a corresponding set of {0,1}-valued integer "head flags" demarcating the first value of each segment. + + + + + +Binary reduction operator to apply to values + + + + + +--- + +## Constructors + +### ReduceBySegmentOp inline + + + + +Constructor. + + +```cpp showLineNumbers={false} +cub::ReduceBySegmentOp::ReduceBySegmentOp() +``` + + + + + +Constructor. + + +```cpp showLineNumbers={false} +cub::ReduceBySegmentOp::ReduceBySegmentOp( + ReductionOpT op +) +``` + + + + + +--- + +## Methods + +### operator() inline + +Scan operator. + + +```cpp showLineNumbers={false} +template +KeyValuePairT cub::ReduceBySegmentOp::operator()( + const KeyValuePairT &first, + const KeyValuePairT &second +) +``` + + +**Template parameters** + + +KeyValuePair pairing of T (value) and OffsetT (head flag) + + +**Parameters** + + +First partial reduction + + + +Second partial reduction + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `op` | `ReductionOpT` | Wrapped reduction operator. | diff --git a/fern/cudapages/cub/cub/cub/ScanTileState.mdx b/fern/cudapages/cub/cub/cub/ScanTileState.mdx new file mode 100644 index 0000000..2811669 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/ScanTileState.mdx @@ -0,0 +1,18 @@ +--- +title: cub::ScanTileState +description: "Tile status interface." +--- + +Tile status interface. + + + + + + + + + + + + diff --git a/fern/cudapages/cub/cub/cub/ScanTileState_T_false.mdx b/fern/cudapages/cub/cub/cub/ScanTileState_T_false.mdx new file mode 100644 index 0000000..6c21c9e --- /dev/null +++ b/fern/cudapages/cub/cub/cub/ScanTileState_T_false.mdx @@ -0,0 +1,180 @@ +--- +title: "cub::ScanTileState< T, false >" +description: "Tile status interface specialized for scan status and value types that can be combined into one machine word that can be read/written coherently in a single access." +--- + +Tile status interface specialized for scan status and value types that can be combined into one machine word that can be read/written coherently in a single access. + +Tile status interface specialized for scan status and value types that cannot be combined into one machine word. + + + + + + + + + + +--- + +## Methods + +### ScanTileState inline + +Constructor. + +": "/library/api/cub::ScanTileState%3C T, false %3E"}}> +```cpp showLineNumbers={false} +cub::ScanTileState::ScanTileState() +``` + + +### Init inline + +Initializer. + +": "/library/api/cub::ScanTileState%3C T, false %3E"}}> +```cpp showLineNumbers={false} +cudaError_t cub::ScanTileState::Init( + int num_tiles, + void *d_temp_storage, + size_t temp_storage_bytes +) +``` + + +**Parameters** + + +Number of tiles + + + +Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to `temp_storage_bytes` and no work is done. + + + +Size in bytes of `d_temp_storage` allocation Initializer + + +### InitializeStatus inline + +Initialize (from device). + +": "/library/api/cub::ScanTileState%3C T, false %3E"}}> +```cpp showLineNumbers={false} +void cub::ScanTileState::InitializeStatus( + int num_tiles +) +``` + + +### SetInclusive inline + +Update the specified tile's inclusive value and corresponding status. + +": "/library/api/cub::ScanTileState%3C T, false %3E"}}> +```cpp showLineNumbers={false} +template +void cub::ScanTileState::SetInclusive( + int tile_idx, + T tile_inclusive +) +``` + + +### SetPartial inline + +Update the specified tile's partial value and corresponding status. + +": "/library/api/cub::ScanTileState%3C T, false %3E"}}> +```cpp showLineNumbers={false} +template +void cub::ScanTileState::SetPartial( + int tile_idx, + T tile_partial +) +``` + + +### WaitForValid inline + +Wait for the corresponding tile to become non-invalid. + +": "/library/api/cub::ScanTileState%3C T, false %3E"}}> +```cpp showLineNumbers={false} +template +void cub::ScanTileState::WaitForValid( + int tile_idx, + StatusWord &status, + T &value, + DelayT delay = {} +) +``` + + +### LoadValid inline + +Loads and returns the tile's value. + +The returned value is undefined if either (a) the tile's status is invalid or (b) there is no memory fence between reading a non-invalid status and the call to LoadValid. + +": "/library/api/cub::ScanTileState%3C T, false %3E"}}> +```cpp showLineNumbers={false} +T cub::ScanTileState::LoadValid( + int tile_idx +) +``` + + +--- + +## Static methods + +### AllocationSize inline static constexpr + +Compute device memory needed for tile status. + +": "/library/api/cub::ScanTileState%3C T, false %3E"}}> +```cpp showLineNumbers={false} +static constexpr cudaError_t cub::ScanTileState::AllocationSize( + int num_tiles, + size_t &temp_storage_bytes +) +``` + + +**Parameters** + + +Number of tiles + + + +Size in bytes of `d_temp_storage` allocation + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `StatusValueT` | `T` | +| `StatusWord` | `unsigned int` | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `TILE_STATUS_PADDING` static constexpr | `int` | | +| `description_bytes_per_tile` static constexpr | `size_t` | | +| `payload_bytes_per_tile` static constexpr | `size_t` | | +| `d_tile_status` | `StatusWord *` | | +| `d_tile_partial` | `T *` | | +| `d_tile_inclusive` | `T *` | | diff --git a/fern/cudapages/cub/cub/cub/ShiftDigitExtractor.mdx b/fern/cudapages/cub/cub/cub/ShiftDigitExtractor.mdx new file mode 100644 index 0000000..02ebe18 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/ShiftDigitExtractor.mdx @@ -0,0 +1,82 @@ +--- +title: cub::ShiftDigitExtractor +description: "A wrapper type to extract digits." +--- + +A wrapper type to extract digits. + +Uses a combination of shift and bitwise and to extract digits. + + + + + + + + + + +**Inherits from:** `cub::BaseDigitExtractor< KeyT >` (public) + +--- + +## Constructors + +### ShiftDigitExtractor inline explicit + + +```cpp showLineNumbers={false} +cub::ShiftDigitExtractor::ShiftDigitExtractor( + ::cuda::std::uint32_t bit_start = 0, + ::cuda::std::uint32_t num_bits = 0 +) +``` + + +--- + +## Methods + +### Digit inline const + + +```cpp showLineNumbers={false} +::cuda::std::uint32_t cub::ShiftDigitExtractor::Digit( + UnsignedBits key +) const +``` + + +--- + +## Static methods + +### ProcessFloatMinusZero inline static + + +```cpp showLineNumbers={false} +static UnsignedBits cub::BaseDigitExtractor::ProcessFloatMinusZero( + UnsignedBits key +) +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `TraitsT` | `Traits< KeyT >` | +| `UnsignedBits` | `typename TraitsT::UnsignedBits` | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `bit_start` | `::cuda::std::uint32_t` | | +| `mask` | `::cuda::std::uint32_t` | | diff --git a/fern/cudapages/cub/cub/cub/SmVersionCacheTag.mdx b/fern/cudapages/cub/cub/cub/SmVersionCacheTag.mdx new file mode 100644 index 0000000..54e358f --- /dev/null +++ b/fern/cudapages/cub/cub/cub/SmVersionCacheTag.mdx @@ -0,0 +1,4 @@ +--- +title: cub::SmVersionCacheTag +description: "" +--- diff --git a/fern/cudapages/cub/cub/cub/SwizzleScanOp.mdx b/fern/cudapages/cub/cub/cub/SwizzleScanOp.mdx new file mode 100644 index 0000000..017892f --- /dev/null +++ b/fern/cudapages/cub/cub/cub/SwizzleScanOp.mdx @@ -0,0 +1,57 @@ +--- +title: cub::SwizzleScanOp +description: "Binary operator wrapper for switching non-commutative scan arguments." +--- + +Binary operator wrapper for switching non-commutative scan arguments. + + + + + + + + + + +--- + +## Constructors + +### SwizzleScanOp inline + +Constructor. + + +```cpp showLineNumbers={false} +cub::SwizzleScanOp::SwizzleScanOp( + ScanOp scan_op +) +``` + + +--- + +## Methods + +### operator() inline + +Switch the scan arguments. + + +```cpp showLineNumbers={false} +template +T cub::SwizzleScanOp::operator()( + const T &a, + const T &b +) +``` + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `scan_op` | `ScanOp` | Wrapped scan operator. | diff --git a/fern/cudapages/cub/cub/cub/TilePrefixCallbackOp.mdx b/fern/cudapages/cub/cub/cub/TilePrefixCallbackOp.mdx new file mode 100644 index 0000000..c76d316 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/TilePrefixCallbackOp.mdx @@ -0,0 +1,192 @@ +--- +title: cub::TilePrefixCallbackOp +description: "Tile status interface for reduction by key, specialized for scan status and value types that can be combined into one machine word that can be read/written coherently in a single access." +--- + +Tile status interface for reduction by key, specialized for scan status and value types that can be combined into one machine word that can be read/written coherently in a single access. + +Stateful block-scan prefix functor. Provides the running prefix for the current tile by using the callback warp to wait for aggregates/prefixes from predecessor tiles to become available. + + + + + + + + + + + + + + +Implementation detail, do not specify directly, requirements on the content of this type are subject to breaking change. + + + + + +--- + +## Constructors + +### TilePrefixCallbackOp inline + + + + + +```cpp showLineNumbers={false} +cub::TilePrefixCallbackOp::TilePrefixCallbackOp( + ScanTileStateT &tile_status, + TempStorage &temp_storage, + ScanOpT scan_op, + int tile_idx +) +``` + + + + + + +```cpp showLineNumbers={false} +cub::TilePrefixCallbackOp::TilePrefixCallbackOp( + ScanTileStateT &tile_status, + TempStorage &temp_storage, + ScanOpT scan_op +) +``` + + + + + +--- + +## Methods + +### ProcessWindow inline + +Block until all predecessors within the warp-wide window have non-invalid status. + + +```cpp showLineNumbers={false} +template +void cub::TilePrefixCallbackOp::ProcessWindow( + int predecessor_idx, + StatusWord &predecessor_status, + T &window_aggregate, + DelayT delay = {} +) +``` + + +**Parameters** + + +Preceding tile index to inspect + + + +Preceding tile status + + + +Relevant partial reduction from this window of preceding tiles + + +### operator() inline + + +```cpp showLineNumbers={false} +T cub::TilePrefixCallbackOp::operator()( + T block_aggregate +) +``` + + +### GetExclusivePrefix inline + + +```cpp showLineNumbers={false} +T cub::TilePrefixCallbackOp::GetExclusivePrefix() +``` + + +### GetInclusivePrefix inline + + +```cpp showLineNumbers={false} +T cub::TilePrefixCallbackOp::GetInclusivePrefix() +``` + + +### GetBlockAggregate inline + + +```cpp showLineNumbers={false} +T cub::TilePrefixCallbackOp::GetBlockAggregate() +``` + + +### GetTileIdx inline const + + +```cpp showLineNumbers={false} +int cub::TilePrefixCallbackOp::GetTileIdx() const +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `WarpReduceT` | `WarpReduce< T,(1<<(5))>` | +| `StatusWord` | `typename ScanTileStateT::StatusWord` | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `temp_storage` | `_TempStorage &` | Reference to a warp-reduction instance. | +| `tile_status` | `ScanTileStateT &` | Interface to tile status. | +| `scan_op` | `ScanOpT` | Binary scan operator. | +| `tile_idx` | `int` | The current tile index. | +| `exclusive_prefix` | `T` | Exclusive prefix for the tile. | +| `inclusive_prefix` | `T` | Inclusive prefix for the tile. | + +--- + +## Inner classes + +### _TempStorage + + +```cpp showLineNumbers={false} +struct cub::TilePrefixCallbackOp::_TempStorage +``` + + +| Name | Type | Description | +|---|---|---| +| `warp_reduce` | `WarpReduceT::TempStorage` | | +| `exclusive_prefix` | `T` | | +| `inclusive_prefix` | `T` | | +| `block_aggregate` | `T` | | + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::TilePrefixCallbackOp::TempStorage +``` + + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/cudapages/cub/cub/cub/WarpExchange.mdx b/fern/cudapages/cub/cub/cub/WarpExchange.mdx new file mode 100644 index 0000000..430279d --- /dev/null +++ b/fern/cudapages/cub/cub/cub/WarpExchange.mdx @@ -0,0 +1,244 @@ +--- +title: cub::WarpExchange +description: "The [WarpExchange](/library/api/cub::_warp_exchange) class provides [collective](../index.html#sec0) methods for rearranging data partitioned across a CUDA warp." +--- + +The `WarpExchange` class provides [collective](../index.html#sec0) methods for rearranging data partitioned across a CUDA warp. + +**Overview** + +- It is commonplace for a warp of threads to rearrange data items between threads. For example, the global memory accesses prefer patterns where data items are "striped" across threads (where consecutive threads access consecutive items), yet most warp-wide operations prefer a "blocked" partitioning of items across threads (where consecutive items belong to a single thread). +- `WarpExchange` supports the following types of data exchanges: +Transposing between blocked and striped arrangements +Scattering ranked items to a striped arrangement + - Transposing between [blocked](../index.html#sec5sec3) and [striped](../index.html#sec5sec3) arrangements + - Scattering ranked items to a [striped arrangement](../index.html#sec5sec3) + +**A Simple Example** + + +The code snippet below illustrates the conversion from a "blocked" to a "striped" arrangement of 64 integer items partitioned across 16 threads where each thread owns 4 items. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int *d_data, ...) +{ + constexpr int warp_threads = 16; + constexpr int block_threads = 256; + constexpr int items_per_thread = 4; + constexpr int warps_per_block = block_threads / warp_threads; + const int warp_id = static_cast(threadIdx.x) / warp_threads; + + // Specialize WarpExchange for a virtual warp of 16 threads owning 4 integer items each + using WarpExchangeT = + cub::WarpExchange; + + // Allocate shared memory for WarpExchange + __shared__ typename WarpExchangeT::TempStorage temp_storage[warps_per_block]; + + // Load a tile of data striped across threads + int thread_data[items_per_thread]; + // ... + + // Collectively exchange data into a blocked arrangement across threads + WarpExchangeT(temp_storage[warp_id]).StripedToBlocked(thread_data, thread_data); +``` + +Suppose the set of striped input `thread_data` across the block of threads is `{ [0,16,32,48], [1,17,33,49], ..., [15, 32, 47, 63] }`. The corresponding output `thread_data` in those threads will be `{ [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [60,61,62,63] }`. + + + + + + + + +The number of items partitioned onto each thread. + + + +**[optional]** The number of threads per "logical" warp (may be less than the number of hardware warp threads). Default is the warp size of the targeted CUDA compute-capability (e.g., 32 threads for SM86). Must be a power of two. + + + + + + + + +**Inherits from:** `detail::InternalWarpExchangeImpl< InputT, ITEMS_PER_THREAD, detail::warp_threads, WARP_EXCHANGE_SMEM >` (private) + +--- + +## Collective constructors + +### WarpExchange inline explicit + + + + +Collective constructor using the specified memory allocation as temporary storage. + + +```cpp showLineNumbers={false} +cub::WarpExchange::WarpExchange( + TempStorage &temp_storage +) +``` + + + + + + +```cpp showLineNumbers={false} +cub::WarpExchange::WarpExchange() = delete +``` + + + + + +--- + +## Data movement + +### BlockedToStriped inline + +Transposes data items from *blocked* arrangement to *striped* arrangement. + + +```cpp showLineNumbers={false} +template +void cub::WarpExchange::BlockedToStriped( + const InputT (&input_items)[ITEMS_PER_THREAD], + OutputT (&output_items)[ITEMS_PER_THREAD] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Parameters** + + +Items to exchange, converting between *blocked* and *striped* arrangements. + + + +Items from exchange, converting between *striped* and *blocked* arrangements. May be aliased to `input_items`. + + +### StripedToBlocked inline + +Transposes data items from *striped* arrangement to *blocked* arrangement. + + +```cpp showLineNumbers={false} +template +void cub::WarpExchange::StripedToBlocked( + const InputT (&input_items)[ITEMS_PER_THREAD], + OutputT (&output_items)[ITEMS_PER_THREAD] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Parameters** + + +Items to exchange + + + +Items from exchange. May be aliased to `input_items`. + + +### ScatterToStriped inline + + + + +Exchanges valid data items annotated by rank into *striped* arrangement. + + +```cpp showLineNumbers={false} +template +void cub::WarpExchange::ScatterToStriped( + InputT (&items)[ITEMS_PER_THREAD], + OffsetT (&ranks)[ITEMS_PER_THREAD] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +**[inferred]** Signed integer type for local offsets + + +**Parameters** + + +Items to exchange + + + +Corresponding scatter ranks + + + + + +Exchanges valid data items annotated by rank into *striped* arrangement. + + +```cpp showLineNumbers={false} +template +void cub::WarpExchange::ScatterToStriped( + const InputT (&input_items)[ITEMS_PER_THREAD], + OutputT (&output_items)[ITEMS_PER_THREAD], + OffsetT (&ranks)[ITEMS_PER_THREAD] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +**Template parameters** + + +**[inferred]** Signed integer type for local offsets + + +**Parameters** + + +Items to exchange + + + +Items from exchange. May be aliased to `input_items`. + + + +Corresponding scatter ranks + + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `InternalWarpExchange` | `detail::InternalWarpExchangeImpl< InputT, ITEMS_PER_THREAD, LOGICAL_WARP_THREADS, WARP_EXCHANGE_ALGORITHM >` | | +| `TempStorage` | `typename InternalWarpExchange::TempStorage` | The operations exposed by `WarpExchange` require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. | diff --git a/fern/cudapages/cub/cub/cub/WarpLoad.mdx b/fern/cudapages/cub/cub/cub/WarpLoad.mdx new file mode 100644 index 0000000..350cd84 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/WarpLoad.mdx @@ -0,0 +1,421 @@ +--- +title: cub::WarpLoad +description: "" +--- + +The WarpLoad class provides collective data movement methods for loading a linear segment of items from memory into a blocked arrangement across a CUDA thread warp. + +## Example + +The code snippet below illustrates the loading of a linear segment of 64 integers into a "blocked" arrangement across 16 threads where each thread owns 4 consecutive items. The load is specialized for `WARP_LOAD_TRANSPOSE`, meaning memory references are efficiently coalesced using a warp-striped access pattern (after which items are locally reordered among threads). + +The set of `thread_data` across the first logical warp of threads in those threads will be: `{ [0,1,2,3], [4,5,6,7], ..., [60,61,62,63] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int *d_data, ...) +{ + constexpr int warp_threads = 16; + constexpr int block_threads = 256; + constexpr int items_per_thread = 4; + + // Specialize WarpLoad for a warp of 16 threads owning 4 integer items each + using WarpLoadT = WarpLoad; + + constexpr int warps_in_block = block_threads / warp_threads; + constexpr int tile_size = items_per_thread * warp_threads; + const int warp_id = static_cast(threadIdx.x) / warp_threads; + + // Allocate shared memory for WarpLoad + __shared__ typename WarpLoadT::TempStorage temp_storage[warps_in_block]; + + // Load a segment of consecutive items that are blocked across threads + int thread_data[items_per_thread]; + WarpLoadT(temp_storage[warp_id]).Load(d_data + warp_id * tile_size, thread_data); +} +``` + + + + + +The data type to read into (which must be convertible from the input iterator's value type). + + + +The number of consecutive items partitioned onto each thread. + + + +**[optional]** [cub::WarpLoadAlgorithm](/library/api/cub::WarpLoadAlgorithm) tuning policy. default: [cub::WARP_LOAD_DIRECT](/library/api/cub::WARP_LOAD_DIRECT). + + + +**[optional]** The number of threads per "logical" warp (may be less than the number of hardware warp threads). Default is the warp size of the targeted CUDA compute-capability (e.g., 32 threads for SM86). Must be a power of two. + + + + + +--- + +## Collective constructors + +### WarpLoad inline + + + + +Collective constructor using a private static allocation of shared memory as temporary storage. + + +```cpp showLineNumbers={false} +cub::WarpLoad::WarpLoad() +``` + + + + + +Collective constructor using the specified memory allocation as temporary storage. + + +```cpp showLineNumbers={false} +cub::WarpLoad::WarpLoad( + TempStorage &temp_storage +) +``` + + + + + +--- + +## Data movement + +### Load inline + + + + +Load a linear segment of items from memory. + + +```cpp showLineNumbers={false} +template +void cub::WarpLoad::Load( + InputIteratorT block_itr, + InputT (&items)[ITEMS_PER_THREAD] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +The thread block's base input iterator for loading from + + + +Data to load + + +**Example** + +The set of `thread_data` across the first logical warp of threads in those threads will be: `{ [0,1,2,3], [4,5,6,7], ..., [60,61,62,63] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int *d_data, ...) +{ + constexpr int warp_threads = 16; + constexpr int block_threads = 256; + constexpr int items_per_thread = 4; + + // Specialize WarpLoad for a warp of 16 threads owning 4 integer items each + using WarpLoadT = WarpLoad; + + constexpr int warps_in_block = block_threads / warp_threads; + constexpr int tile_size = items_per_thread * warp_threads; + const int warp_id = static_cast(threadIdx.x) / warp_threads; + + // Allocate shared memory for WarpLoad + __shared__ typename WarpLoadT::TempStorage temp_storage[warps_in_block]; + + // Load a segment of consecutive items that are blocked across threads + int thread_data[items_per_thread]; + WarpLoadT(temp_storage[warp_id]).Load(d_data + warp_id * tile_size, thread_data); +} +``` + + + + +Load a linear segment of items from memory, guarded by range. + + +```cpp showLineNumbers={false} +template +void cub::WarpLoad::Load( + InputIteratorT block_itr, + InputT (&items)[ITEMS_PER_THREAD], + int valid_items +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +The thread block's base input iterator for loading from + + + +Data to load + + + +Number of valid items to load + + +**Example** + +The set of `thread_data` across the first logical warp of threads in those threads will be: `{ [0,1,2,3], [4,?,?,?], ..., [?,?,?,?] }` with only the first two threads being unmasked to load portions of valid data (and other items remaining unassigned). + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int *d_data, int valid_items, ...) +{ + constexpr int warp_threads = 16; + constexpr int block_threads = 256; + constexpr int items_per_thread = 4; + + // Specialize WarpLoad for a warp of 16 threads owning 4 integer items each + using WarpLoadT = WarpLoad; + + constexpr int warps_in_block = block_threads / warp_threads; + constexpr int tile_size = items_per_thread * warp_threads; + const int warp_id = static_cast(threadIdx.x) / warp_threads; + + // Allocate shared memory for WarpLoad + __shared__ typename WarpLoadT::TempStorage temp_storage[warps_in_block]; + + // Load a segment of consecutive items that are blocked across threads + int thread_data[items_per_thread]; + WarpLoadT(temp_storage[warp_id]).Load(d_data + warp_id * tile_size, thread_data, + valid_items); +} +``` + + + + +Load a linear segment of items from memory, guarded by range. + + +```cpp showLineNumbers={false} +template +void cub::WarpLoad::Load( + InputIteratorT block_itr, + InputT (&items)[ITEMS_PER_THREAD], + int valid_items, + DefaultT oob_default +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +The thread block's base input iterator for loading from + + + +Data to load + + + +Number of valid items to load + + + +Default value to assign out-of-bound items + + +**Example** + +out-of-bounds default is `-1`. The set of `thread_data` across the first logical warp of threads in those threads will be: `{ [0,1,2,3], [4,-1,-1,-1], ..., [-1,-1,-1,-1] }` with only the first two threads being unmasked to load portions of valid data (and other items are assigned `-1`). + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int *d_data, int valid_items, ...) +{ + constexpr int warp_threads = 16; + constexpr int block_threads = 256; + constexpr int items_per_thread = 4; + + // Specialize WarpLoad for a warp of 16 threads owning 4 integer items each + using WarpLoadT = WarpLoad; + + constexpr int warps_in_block = block_threads / warp_threads; + constexpr int tile_size = items_per_thread * warp_threads; + const int warp_id = static_cast(threadIdx.x) / warp_threads; + + // Allocate shared memory for WarpLoad + __shared__ typename WarpLoadT::TempStorage temp_storage[warps_in_block]; + + // Load a segment of consecutive items that are blocked across threads + int thread_data[items_per_thread]; + WarpLoadT(temp_storage[warp_id]).Load(d_data + warp_id * tile_size, + thread_data, + valid_items, + -1); +``` + + + + +--- + +## Utility methods + +### PrivateStorage inline + +Internal storage allocator. + + +```cpp showLineNumbers={false} +_TempStorage & cub::WarpLoad::PrivateStorage() +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `InternalLoad` | `LoadInternal< ALGORITHM, 0 >` | Internal load implementation to use. | +| `_TempStorage` | `typename InternalLoad::TempStorage` | Shared memory storage layout type. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `IS_ARCH_WARP` static constexpr | `bool` | | +| `temp_storage` | `_TempStorage &` | Thread reference to shared storage. | +| `linear_tid` | `int` | Linear thread-id. | + +--- + +## Inner classes + +### LoadInternal + + +```cpp showLineNumbers={false} +struct cub::WarpLoad::LoadInternal +``` + + +Load helper. + +### LoadInternal< WARP_LOAD_DIRECT, DUMMY > + + +```cpp showLineNumbers={false} +struct cub::WarpLoad::LoadInternal< WARP_LOAD_DIRECT, DUMMY > +``` + + +| Name | Type | Description | +|---|---|---| +| `linear_tid` | `int` | | + +### LoadInternal< WARP_LOAD_STRIPED, DUMMY > + + +```cpp showLineNumbers={false} +struct cub::WarpLoad::LoadInternal< WARP_LOAD_STRIPED, DUMMY > +``` + + +| Name | Type | Description | +|---|---|---| +| `linear_tid` | `int` | | + +### LoadInternal< WARP_LOAD_VECTORIZE, DUMMY > + + +```cpp showLineNumbers={false} +struct cub::WarpLoad::LoadInternal< WARP_LOAD_VECTORIZE, DUMMY > +``` + + +| Name | Type | Description | +|---|---|---| +| `linear_tid` | `int` | | + +### LoadInternal< WARP_LOAD_TRANSPOSE, DUMMY > + + +```cpp showLineNumbers={false} +struct cub::WarpLoad::LoadInternal< WARP_LOAD_TRANSPOSE, DUMMY > +``` + + +| Name | Type | Description | +|---|---|---| +| `temp_storage` | `_TempStorage &` | | +| `linear_tid` | `int` | | + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::WarpLoad::TempStorage +``` + + +The operations exposed by `WarpLoad` require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/cudapages/cub/cub/cub/WarpMergeSort.mdx b/fern/cudapages/cub/cub/cub/WarpMergeSort.mdx new file mode 100644 index 0000000..891f13f --- /dev/null +++ b/fern/cudapages/cub/cub/cub/WarpMergeSort.mdx @@ -0,0 +1,145 @@ +--- +title: cub::WarpMergeSort +description: "" +--- + +The WarpMergeSort class provides methods for sorting items partitioned across a CUDA warp using a merge sorting method. + +## Example + +The code snippet below illustrates a sort of 64 integer keys that are partitioned across 16 threads where each thread owns 4 consecutive items. + +`{ [0,64,1,63], [2,62,3,61], [4,60,5,59], ..., [31,34,32,33] }`. The corresponding output `thread_keys` in those threads will be `{ [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [31,32,33,34] }`. + +```cpp showLineNumbers={false} +#include // or equivalently + +struct CustomLess +{ + template + __device__ bool operator()(const DataType &lhs, const DataType &rhs) + { + return lhs < rhs; + } +}; + +__global__ void ExampleKernel(...) +{ + constexpr int warp_threads = 16; + constexpr int block_threads = 256; + constexpr int items_per_thread = 4; + constexpr int warps_per_block = block_threads / warp_threads; + const int warp_id = static_cast(threadIdx.x) / warp_threads; + + // Specialize WarpMergeSort for a virtual warp of 16 threads + // owning 4 integer items each + using WarpMergeSortT = + cub::WarpMergeSort; + + // Allocate shared memory for WarpMergeSort + __shared__ typename WarpMergeSortT::TempStorage temp_storage[warps_per_block]; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_keys[items_per_thread]; + // ... + + WarpMergeSortT(temp_storage[warp_id]).Sort(thread_keys, CustomLess()); + // ... +} +``` + + + + + +Key type + + + +The number of items per thread + + + +**[optional]** The number of threads per "logical" warp (may be less than the number of hardware warp threads). Default is the warp size of the targeted CUDA compute-capability (e.g., 32 threads for SM86). Must be a power of two. + + + +**[optional]** Value type (default: cub::NullType, which indicates a keys-only sort) + + + + + +**Inherits from:** `cub::BlockMergeSortStrategy< KeyT, NullType, detail::warp_threads, ITEMS_PER_THREAD, WarpMergeSort< KeyT, ITEMS_PER_THREAD, detail::warp_threads, NullType > >` (public) + +--- + +## Constructors + +### WarpMergeSort inline + + + + + +```cpp showLineNumbers={false} +cub::WarpMergeSort::WarpMergeSort( + typename BlockMergeSortStrategyT::TempStorage &temp_storage +) +``` + + + + + + +```cpp showLineNumbers={false} +cub::WarpMergeSort::WarpMergeSort() = delete +``` + + + + + +--- + +## Methods + +### get_member_mask inline const + + +```cpp showLineNumbers={false} +unsigned int cub::WarpMergeSort::get_member_mask() const +``` + + +### SyncImplementation inline const + + +```cpp showLineNumbers={false} +void cub::WarpMergeSort::SyncImplementation() const +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `BlockMergeSortStrategyT` | `BlockMergeSortStrategy< KeyT, ValueT, LOGICAL_WARP_THREADS, ITEMS_PER_THREAD, WarpMergeSort >` | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `IS_ARCH_WARP` static constexpr | `bool` | | +| `KEYS_ONLY` static constexpr | `bool` | | +| `TILE_SIZE` static constexpr | `int` | | +| `warp_id` | `const unsigned int` | | +| `member_mask` | `const unsigned int` | | +| `BlockMergeSortStrategyT` | `friend` | | diff --git a/fern/cudapages/cub/cub/cub/WarpReduce.mdx b/fern/cudapages/cub/cub/cub/WarpReduce.mdx new file mode 100644 index 0000000..a2ca6b8 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/WarpReduce.mdx @@ -0,0 +1,749 @@ +--- +title: cub::WarpReduce +description: "" +--- + +The `WarpReduce` class provides collective methods for computing a parallel reduction of items partitioned across a CUDA thread warp. + +![](../../img/warp_reduce_logo.png) + +## Performance considerations + +- Uses special instructions when applicable (e.g., warp `SHFL` instructions) +- Uses synchronization-free communication between warp lanes when applicable +- Incurs zero bank conflicts for most types +- Computation is slightly more efficient (i.e., having lower instruction overhead) for: + + - Summation (**vs.** generic reduction) + - The architecture's warp size is a whole multiple of `LogicalWarpThreads` + +## Example + +The code snippet below illustrates four concurrent warp sum reductions within a block of 128 threads (one per each of the 32-thread warps). + +The corresponding output `aggregate` in threads 0, 32, 64, and 96 will be `496`, `1520`, `2544`, and `3568`, respectively (and is undefined in other threads). + +The code snippet below illustrates a single warp sum reduction within a block of 128 threads. + +The corresponding output `aggregate` in thread0 will be `496` (and is undefined in other threads). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + // Allocate WarpReduce shared memory for 4 warps + __shared__ typename WarpReduce::TempStorage temp_storage[4]; + // Obtain one input item per thread + int thread_data = ... + // Return the warp-wide sums to each lane0 (threads 0, 32, 64, and 96) + int warp_id = threadIdx.x / 32; + int aggregate = WarpReduce(temp_storage[warp_id]).Sum(thread_data); +} + +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + ... + // Only the first warp performs a reduction + if (threadIdx.x < 32) + { + // Obtain one input item per thread + int thread_data = ... + // Return the warp-wide sum to lane0 + int aggregate = WarpReduce(temp_storage).Sum(thread_data); + } +} +``` + + + + + +The reduction input/output element type + + + +**[optional]** The number of threads per "logical" warp (may be less than the number of hardware warp threads). Default is the warp size of the targeted CUDA compute-capability (e.g., 32 threads for SM20). + + + + + +--- + +## Collective constructors + +### WarpReduce inline + +Collective constructor using the specified memory allocation as temporary storage. Logical warp and lane identifiers are constructed from `threadIdx.x`. + + +```cpp showLineNumbers={false} +cub::WarpReduce::WarpReduce( + TempStorage &temp_storage +) +``` + + +**Parameters** + + +Reference to memory allocation having layout type [TempStorage](/library/api/cub::WarpReduce::TempStorage) + + +--- + +## Summation reductions + +### Sum inline nodiscard + + + + +Computes a warp-wide sum in the calling warp. The output is valid in warp *lane*0. + + +```cpp showLineNumbers={false} +T cub::WarpReduce::Sum( + T input +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Example** + +The code snippet below illustrates four concurrent warp sum reductions within a block of 128 threads (one per each of the 32-thread warps). + +The corresponding output `aggregate` in threads 0, 32, 64, and 96 will `496`, `1520`, `2544`, and `3568`, respectively (and is undefined in other threads). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + // Allocate WarpReduce shared memory for 4 warps + __shared__ typename WarpReduce::TempStorage temp_storage[4]; + // Obtain one input item per thread + int thread_data = ... + // Return the warp-wide sums to each lane0 + int warp_id = threadIdx.x / 32; + int aggregate = WarpReduce(temp_storage[warp_id]).Sum(thread_data); +} +``` + + + + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::Sum( + const InputType &input +) +``` + + + + + +Computes a partially-full warp-wide sum in the calling warp. The output is valid in warp *lane*0. + +All threads across the calling warp must agree on the same value for `valid_items`. Otherwise the result is undefined. + + +```cpp showLineNumbers={false} +T cub::WarpReduce::Sum( + T input, + int valid_items +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Calling thread's input + + + +Total number of valid items in the calling thread's logical warp (may be less than `LogicalWarpThreads`) + + +**Example** + +The code snippet below illustrates a sum reduction within a single, partially-full block of 32 threads (one warp). + +The corresponding output `aggregate` in *lane*0 is `6` (and is undefined in other threads). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(int *d_data, int valid_items) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + + // Obtain one input item per thread if in range + int thread_data; + if (threadIdx.x < valid_items) + thread_data = d_data[threadIdx.x]; + + // Return the warp-wide sums to each lane0 + int aggregate = WarpReduce(temp_storage).Sum(thread_data, valid_items); +} +``` + + + + +### Max inline nodiscard + + + + + +```cpp showLineNumbers={false} +T cub::WarpReduce::Max( + T input +) +``` + + + + + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::Max( + const InputType &input +) +``` + + + + + + +```cpp showLineNumbers={false} +T cub::WarpReduce::Max( + T input, + int valid_items +) +``` + + + + + +### Min inline nodiscard + + + + + +```cpp showLineNumbers={false} +T cub::WarpReduce::Min( + T input +) +``` + + + + + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::Min( + const InputType &input +) +``` + + + + + + +```cpp showLineNumbers={false} +T cub::WarpReduce::Min( + T input, + int valid_items +) +``` + + + + + +### HeadSegmentedSum inline nodiscard + +Computes a segmented sum in the calling warp where segments are defined by head-flags. The sum of each segment is returned to the first lane in that segment (which always includes *lane*0). + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::HeadSegmentedSum( + T input, + FlagT head_flag +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Calling thread's input + + + +Head flag denoting whether or not `input` is the start of a new segment + + +**Example** + +The code snippet below illustrates a head-segmented warp sum reduction within a block of 32 threads (one warp). + +is `{0, 1, 2, 3, ..., 31` and is `{1, 0, 0, 0, 1, 0, 0, 0, ..., 1, 0, 0, 0`, respectively. The corresponding output `aggregate` in threads 0, 4, 8, etc. will be `6`, `22`, `38`, etc. (and is undefined in other threads). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + + // Obtain one input item and flag per thread + int thread_data = ... + int head_flag = ... + + // Return the warp-wide sums to each lane0 + int aggregate = WarpReduce(temp_storage).HeadSegmentedSum( + thread_data, head_flag); +} +``` + +### TailSegmentedSum inline nodiscard + +Computes a segmented sum in the calling warp where segments are defined by tail-flags. The sum of each segment is returned to the first lane in that segment (which always includes *lane*0). + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::TailSegmentedSum( + T input, + FlagT tail_flag +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Calling thread's input + + + +Head flag denoting whether or not `input` is the start of a new segment + + +**Example** + +The code snippet below illustrates a tail-segmented warp sum reduction within a block of 32 threads (one warp). + +is `{0, 1, 2, 3, ..., 31}` and is `{0, 0, 0, 1, 0, 0, 0, 1, ..., 0, 0, 0, 1}`, respectively. The corresponding output `aggregate` in threads 0, 4, 8, etc. will be `6`, `22`, `38`, etc. (and is undefined in other threads). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + + // Obtain one input item and flag per thread + int thread_data = ... + int tail_flag = ... + + // Return the warp-wide sums to each lane0 + int aggregate = WarpReduce(temp_storage).TailSegmentedSum( + thread_data, tail_flag); +``` + +--- + +## Generic reductions + +### Reduce inline nodiscard + + + + +Computes a warp-wide reduction in the calling warp using the specified binary reduction functor. The output is valid in warp *lane*0. + +Supports non-commutative reduction operators + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::Reduce( + T input, + ReductionOp reduction_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Binary reduction operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input + + + +Binary reduction operator + + +**Example** + +The code snippet below illustrates four concurrent warp max reductions within a block of 128 threads (one per each of the 32-thread warps). + +`{0, 1, 2, 3, ..., 127}`. The corresponding output `aggregate` in threads 0, 32, 64, and 96 will be `31`, `63`, `95`, and `127`, respectively (and is undefined in other threads). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for 4 warps + __shared__ typename WarpReduce::TempStorage temp_storage[4]; + + // Obtain one input item per thread + int thread_data = ... + + // Return the warp-wide reductions to each lane0 + int warp_id = threadIdx.x / 32; + int aggregate = WarpReduce(temp_storage[warp_id]).Reduce( + thread_data, cuda::maximum<>{}); +``` + + + + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::Reduce( + const InputType &input, + ReductionOp reduction_op +) +``` + + + + + +Computes a partially-full warp-wide reduction in the calling warp using the specified binary reduction functor. The output is valid in warp *lane*0. + +All threads across the calling warp must agree on the same value for `valid_items`. Otherwise the result is undefined. + +Supports non-commutative reduction operators + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::Reduce( + T input, + ReductionOp reduction_op, + int valid_items +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Binary reduction operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input + + + +Binary reduction operator + + + +Total number of valid items in the calling thread's logical warp (may be less than `LogicalWarpThreads`) + + +**Example** + +The code snippet below illustrates a max reduction within a single, partially-full block of 32 threads (one warp). + +is `4`. The corresponding output `aggregate` in thread0 is `3` (and is undefined in other threads). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(int *d_data, int valid_items) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + + // Obtain one input item per thread if in range + int thread_data; + if (threadIdx.x < valid_items) + thread_data = d_data[threadIdx.x]; + + // Return the warp-wide reductions to each lane0 + int aggregate = WarpReduce(temp_storage).Reduce( + thread_data, cuda::maximum<>{}, valid_items); +``` + + + + +### HeadSegmentedReduce inline nodiscard + +Computes a segmented reduction in the calling warp where segments are defined by head-flags. The reduction of each segment is returned to the first lane in that segment (which always includes *lane*0). + +Supports non-commutative reduction operators + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::HeadSegmentedReduce( + T input, + FlagT head_flag, + ReductionOp reduction_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Binary reduction operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input + + + +Head flag denoting whether or not `input` is the start of a new segment + + + +Reduction operator + + +**Example** + +The code snippet below illustrates a head-segmented warp max reduction within a block of 32 threads (one warp). + +is `{0, 1, 2, 3, ..., 31}` and is `{1, 0, 0, 0, 1, 0, 0, 0, ..., 1, 0, 0, 0}`, respectively. The corresponding output `aggregate` in threads 0, 4, 8, etc. will be `3`, `7`, `11`, etc. (and is undefined in other threads). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + + // Obtain one input item and flag per thread + int thread_data = ... + int head_flag = ... + + // Return the warp-wide reductions to each lane0 + int aggregate = WarpReduce(temp_storage).HeadSegmentedReduce( + thread_data, head_flag, cuda::maximum<>{}); +``` + +### TailSegmentedReduce inline nodiscard + +Computes a segmented reduction in the calling warp where segments are defined by tail-flags. The reduction of each segment is returned to the first lane in that segment (which always includes *lane*0). + +Supports non-commutative reduction operators + + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::TailSegmentedReduce( + T input, + FlagT tail_flag, + ReductionOp reduction_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Binary reduction operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input + + + +Tail flag denoting whether or not `input` is the end of the current segment + + + +Reduction operator + + +**Example** + +The code snippet below illustrates a tail-segmented warp max reduction within a block of 32 threads (one warp). + +is `{0, 1, 2, 3, ..., 31}` and is `{0, 0, 0, 1, 0, 0, 0, 1, ..., 0, 0, 0, 1}`, respectively. The corresponding output `aggregate` in threads 0, 4, 8, etc. will be `3`, `7`, `11`, etc. (and is undefined in other threads). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + + // Obtain one input item and flag per thread + int thread_data = ... + int tail_flag = ... + + // Return the warp-wide reductions to each lane0 + int aggregate = WarpReduce(temp_storage).TailSegmentedReduce( + thread_data, tail_flag, cuda::maximum<>{}); +``` + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `_TempStorage` | `typename InternalWarpReduce::TempStorage` | Shared memory storage layout type for `WarpReduce`. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `is_full_warp` static constexpr | `bool` | | +| `is_power_of_two` static constexpr | `bool` | | +| `temp_storage` | `_TempStorage &` | Shared storage reference. | + +--- + +## Inner classes + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::WarpReduce::TempStorage +``` + + +The operations exposed by `WarpReduce` require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/cudapages/cub/cub/cub/WarpScan.mdx b/fern/cudapages/cub/cub/cub/WarpScan.mdx new file mode 100644 index 0000000..c02b2f8 --- /dev/null +++ b/fern/cudapages/cub/cub/cub/WarpScan.mdx @@ -0,0 +1,1184 @@ +--- +title: cub::WarpScan +description: "" +--- + +The WarpScan class provides collective methods for computing a parallel prefix scan of items partitioned across a CUDA thread warp. + +![](../../img/warp_scan_logo.png) + +## Performance considerations + +* Uses special instructions when applicable (e.g., warp `SHFL`) +* Uses synchronization-free communication between warp lanes when applicable +* Incurs zero bank conflicts for most types +* Computation is slightly more efficient (i.e., having lower instruction overhead) for: + + * Summation (**vs.** generic scan) + * The architecture's warp size is a whole multiple of `LOGICAL_WARP_THREADS` + +## Example + +The code snippet below illustrates four concurrent warp prefix sums within a block of 128 threads (one per each of the 32-thread warps). + +`{1, 1, 1, 1, ...}`. The corresponding output `thread_data` in each of the four warps of threads will be `0, 1, 2, 3, ..., 31}`. + +The code snippet below illustrates a single warp prefix sum within a block of 128 threads. + +`{1, 1, 1, 1, ...}`. The corresponding output `thread_data` will be `{0, 1, 2, 3, ..., 31}`. + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpScan for type int + using WarpScan = cub::WarpScan; + + // Allocate WarpScan shared memory for 4 warps + __shared__ typename WarpScan::TempStorage temp_storage[4]; + + // Obtain one input item per thread + int thread_data = ... + + // Compute warp-wide prefix sums + int warp_id = threadIdx.x / 32; + WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data); +} + +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpScan for type int + using WarpScan = cub::WarpScan; + + // Allocate WarpScan shared memory for one warp + __shared__ typename WarpScan::TempStorage temp_storage; + ... + + // Only the first warp performs a prefix sum + if (threadIdx.x < 32) + { + // Obtain one input item per thread + int thread_data = ... + + // Compute warp-wide prefix sums + WarpScan(temp_storage).ExclusiveSum(thread_data, thread_data); + } +} +``` + + + + + +The scan input/output element type + + + +**[optional]** The number of threads per "logical" warp (may be less than the number of hardware warp threads). Default is the warp size associated with the CUDA Compute Capability targeted by the compiler (e.g., 32 threads for SM20). + + + + + +--- + +## Collective constructors + +### WarpScan inline + +Collective constructor using the specified memory allocation as temporary storage. + +Logical warp and lane identifiers are constructed from `threadIdx.x`. + + +```cpp showLineNumbers={false} +cub::WarpScan::WarpScan( + TempStorage &temp_storage +) +``` + + +**Parameters** + + +Reference to memory allocation having layout type [TempStorage](/library/api/cub::WarpScan::TempStorage) + + +--- + +## Inclusive prefix sums + +### InclusiveSum inline + + + + +Computes an inclusive prefix sum across the calling warp. + + +```cpp showLineNumbers={false} +void cub::WarpScan::InclusiveSum( + T input, + T &inclusive_output +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Calling thread's input item. + + + +Calling thread's output item. May be aliased with `input`. + + +**Example** + +The code snippet below illustrates four concurrent warp-wide inclusive prefix sums within a block of 128 threads (one per each of the 32-thread warps). + +`{1, 1, 1, 1, ...}`. The corresponding output `thread_data` in each of the four warps of threads will be `1, 2, 3, ..., 32}`. + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpScan for type int + using WarpScan = cub::WarpScan; + + // Allocate WarpScan shared memory for 4 warps + __shared__ typename WarpScan::TempStorage temp_storage[4]; + + // Obtain one input item per thread + int thread_data = ... + + // Compute inclusive warp-wide prefix sums + int warp_id = threadIdx.x / 32; + WarpScan(temp_storage[warp_id]).InclusiveSum(thread_data, thread_data); +} +``` + + + + +Computes an inclusive prefix sum across the calling warp. Also provides every thread with the warp-wide `warp_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +void cub::WarpScan::InclusiveSum( + T input, + T &inclusive_output, + T &warp_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item. May be aliased with `input` + + + +Warp-wide aggregate reduction of input items + + +**Example** + +The code snippet below illustrates four concurrent warp-wide inclusive prefix sums within a block of 128 threads (one per each of the 32-thread warps). + +`{1, 1, 1, 1, ...}`. The corresponding output `thread_data` in each of the four warps of threads will be `1, 2, 3, ..., 32}`. Furthermore, `warp_aggregate` for all threads in all warps will be `32`. + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpScan for type int + using WarpScan = cub::WarpScan; + + // Allocate WarpScan shared memory for 4 warps + __shared__ typename WarpScan::TempStorage temp_storage[4]; + + // Obtain one input item per thread + int thread_data = ... + + // Compute inclusive warp-wide prefix sums + int warp_aggregate; + int warp_id = threadIdx.x / 32; + WarpScan(temp_storage[warp_id]).InclusiveSum(thread_data, thread_data, warp_aggregate); +} +``` + + + + +--- + +## Exclusive prefix sums + +### ExclusiveSum inline + + + + +Computes an exclusive prefix sum across the calling warp. The value of 0 is applied as the initial value, and is assigned to `exclusive_output` in *lane*0. + + +```cpp showLineNumbers={false} +void cub::WarpScan::ExclusiveSum( + T input, + T &exclusive_output +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* Uses the identity element (zero) as the initial value. * The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Calling thread's input item. + + + +Calling thread's output item. May be aliased with `input`. + + +**Example** + +The code snippet below illustrates four concurrent warp-wide exclusive prefix sums within a block of 128 threads (one per each of the 32-thread warps). + +`{1, 1, 1, 1, ...}`. The corresponding output `thread_data` in each of the four warps of threads will be `0, 1, 2, ..., 31}`. + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpScan for type int + using WarpScan = cub::WarpScan; + + // Allocate WarpScan shared memory for 4 warps + __shared__ typename WarpScan::TempStorage temp_storage[4]; + + // Obtain one input item per thread + int thread_data = ... + + // Compute exclusive warp-wide prefix sums + int warp_id = threadIdx.x / 32; + WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data); +} +``` + + + + +Computes an exclusive prefix sum across the calling warp. The value of 0 is applied as the initial value, and is assigned to `exclusive_output` in *lane*0. Also provides every thread with the warp-wide `warp_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +void cub::WarpScan::ExclusiveSum( + T input, + T &exclusive_output, + T &warp_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* Uses the identity element (zero) as the initial value. * The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item. May be aliased with `input` + + + +Warp-wide aggregate reduction of input items + + +**Example** + +The code snippet below illustrates four concurrent warp-wide exclusive prefix sums within a block of 128 threads (one per each of the 32-thread warps). + +`{1, 1, 1, 1, ...}`. The corresponding output `thread_data` in each of the four warps of threads will be `0, 1, 2, ..., 31}`. Furthermore, `warp_aggregate` for all threads in all warps will be `32`. + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpScan for type int + using WarpScan = cub::WarpScan; + + // Allocate WarpScan shared memory for 4 warps + __shared__ typename WarpScan::TempStorage temp_storage[4]; + + // Obtain one input item per thread + int thread_data = ... + + // Compute exclusive warp-wide prefix sums + int warp_aggregate; + int warp_id = threadIdx.x / 32; + WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, + thread_data, + warp_aggregate); +``` + + + + +--- + +## Inclusive prefix scans + +### InclusiveScan inline + + + + +Computes an inclusive prefix scan using the specified binary scan functor across the calling warp. + + +```cpp showLineNumbers={false} +template +void cub::WarpScan::InclusiveScan( + T input, + T &inclusive_output, + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Binary scan operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item. May be aliased with `input` + + + +Binary scan operator + + +**Example** + +The code snippet below illustrates four concurrent warp-wide inclusive prefix max scans within a block of 128 threads (one per each of the 32-thread warps). + +`{0, -1, 2, -3, ..., 126, -127}`. The corresponding output `thread_data` in the first warp would be `0, 0, 2, 2, ..., 30, 30`, the output for the second warp would be `32, 32, 34, 34, ..., 62, 62`, etc. + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpScan for type int + using WarpScan = cub::WarpScan; + + // Allocate WarpScan shared memory for 4 warps + __shared__ typename WarpScan::TempStorage temp_storage[4]; + + // Obtain one input item per thread + int thread_data = ... + + // Compute inclusive warp-wide prefix max scans + int warp_id = threadIdx.x / 32; + WarpScan(temp_storage[warp_id]).InclusiveScan(thread_data, thread_data, cuda::maximum<>{}); +``` + + + + +Computes an inclusive prefix scan using the specified binary scan functor across the calling warp. + + +```cpp showLineNumbers={false} +template +void cub::WarpScan::InclusiveScan( + T input, + T &inclusive_output, + T initial_value, + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Binary scan operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item. May be aliased with `input` + + + +Initial value to seed the inclusive scan (uniform across warp) + + + +Binary scan operator + + + + + +Computes an inclusive prefix scan using the specified binary scan functor across the calling warp. Also provides every thread with the warp-wide `warp_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::WarpScan::InclusiveScan( + T input, + T &inclusive_output, + ScanOp scan_op, + T &warp_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Binary scan operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item. May be aliased with `input` + + + +Binary scan operator + + + +Warp-wide aggregate reduction of input items. + + +**Example** + +The code snippet below illustrates four concurrent warp-wide inclusive prefix max scans within a block of 128 threads (one per each of the 32-thread warps). + +`{0, -1, 2, -3, ..., 126, -127}`. The corresponding output `thread_data` in the first warp would be `0, 0, 2, 2, ..., 30, 30`, the output for the second warp would be `32, 32, 34, 34, ..., 62, 62`, etc. Furthermore, `warp_aggregate` would be assigned `30` for threads in the first warp, `62` for threads in the second warp, etc. + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpScan for type int + using WarpScan = cub::WarpScan; + + // Allocate WarpScan shared memory for 4 warps + __shared__ typename WarpScan::TempStorage temp_storage[4]; + + // Obtain one input item per thread + int thread_data = ... + + // Compute inclusive warp-wide prefix max scans + int warp_aggregate; + int warp_id = threadIdx.x / 32; + WarpScan(temp_storage[warp_id]).InclusiveScan( + thread_data, thread_data, cuda::maximum<>{}, warp_aggregate); +``` + + + + +Computes an inclusive prefix scan using the specified binary scan functor across the calling warp. Also provides every thread with the warp-wide `warp_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::WarpScan::InclusiveScan( + T input, + T &inclusive_output, + T initial_value, + ScanOp scan_op, + T &warp_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Binary scan operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item. May be aliased with `input` + + + +Initial value to seed the inclusive scan (uniform across warp). It is not taken into account for warp_aggregate. + + + +Binary scan operator + + + +Warp-wide aggregate reduction of input items. + + + + + +--- + +## Exclusive prefix scans + +### ExclusiveScan inline + + + + +Computes an exclusive prefix scan using the specified binary scan functor across the calling warp. Because no initial value is supplied, the `output` computed for *lane*0 is undefined. + + +```cpp showLineNumbers={false} +template +void cub::WarpScan::ExclusiveScan( + T input, + T &exclusive_output, + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Binary scan operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item. May be aliased with `input` + + + +Binary scan operator + + +**Example** + +The code snippet below illustrates four concurrent warp-wide exclusive prefix max scans within a block of 128 threads (one per each of the 32-thread warps). + +`{0, -1, 2, -3, ..., 126, -127}`. The corresponding output `thread_data` in the first warp would be `?, 0, 0, 2, ..., 28, 30`, the output for the second warp would be `?, 32, 32, 34, ..., 60, 62`, etc. (The output `thread_data` in warp *lane*0 is undefined.) + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpScan for type int + using WarpScan = cub::WarpScan; + + // Allocate WarpScan shared memory for 4 warps + __shared__ typename WarpScan::TempStorage temp_storage[4]; + + // Obtain one input item per thread + int thread_data = ... + + // Compute exclusive warp-wide prefix max scans + int warp_id = threadIdx.x / 32; + WarpScan(temp_storage[warp_id]).ExclusiveScan(thread_data, thread_data, cuda::maximum<>{}); +``` + + + + +Computes an exclusive prefix scan using the specified binary scan functor across the calling warp. + + +```cpp showLineNumbers={false} +template +void cub::WarpScan::ExclusiveScan( + T input, + T &exclusive_output, + T initial_value, + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Binary scan operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item. May be aliased with `input` + + + +Initial value to seed the exclusive scan + + + +Binary scan operator + + +**Example** + +The code snippet below illustrates four concurrent warp-wide exclusive prefix max scans within a block of 128 threads (one per each of the 32-thread warps). + +`{0, -1, 2, -3, ..., 126, -127}`. The corresponding output `thread_data` in the first warp would be `INT_MIN, 0, 0, 2, ..., 28, 30`, the output for the second warp would be `30, 32, 32, 34, ..., 60, 62`, etc. + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpScan for type int + using WarpScan = cub::WarpScan; + + // Allocate WarpScan shared memory for 4 warps + __shared__ typename WarpScan::TempStorage temp_storage[4]; + + // Obtain one input item per thread + int thread_data = ... + + // Compute exclusive warp-wide prefix max scans + int warp_id = threadIdx.x / 32; + WarpScan(temp_storage[warp_id]).ExclusiveScan(thread_data, + thread_data, + INT_MIN, + cuda::maximum<>{}); +``` + + + + +Computes an exclusive prefix scan using the specified binary scan functor across the calling warp. Because no initial value is supplied, the `output` computed for *lane*0 is undefined. Also provides every thread with the warp-wide `warp_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::WarpScan::ExclusiveScan( + T input, + T &exclusive_output, + ScanOp scan_op, + T &warp_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Binary scan operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item. May be aliased with `input` + + + +Binary scan operator + + + +Warp-wide aggregate reduction of input items + + +**Example** + +The code snippet below illustrates four concurrent warp-wide exclusive prefix max scans within a block of 128 threads (one per each of the 32-thread warps). + +`{0, -1, 2, -3, ..., 126, -127}`. The corresponding output `thread_data` in the first warp would be `?, 0, 0, 2, ..., 28, 30`, the output for the second warp would be `?, 32, 32, 34, ..., 60, 62`, etc. (The output `thread_data` in warp *lane*0 is undefined). Furthermore, `warp_aggregate` would be assigned `30` for threads in the first warp, \p 62 for threads in the second warp, etc. + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpScan for type int + using WarpScan = cub::WarpScan; + + // Allocate WarpScan shared memory for 4 warps + __shared__ typename WarpScan::TempStorage temp_storage[4]; + + // Obtain one input item per thread + int thread_data = ... + + // Compute exclusive warp-wide prefix max scans + int warp_aggregate; + int warp_id = threadIdx.x / 32; + WarpScan(temp_storage[warp_id]).ExclusiveScan(thread_data, + thread_data, + cuda::maximum<>{}, + warp_aggregate); +``` + + + + +Computes an exclusive prefix scan using the specified binary scan functor across the calling warp. Also provides every thread with the warp-wide `warp_aggregate` of all inputs. + + +```cpp showLineNumbers={false} +template +void cub::WarpScan::ExclusiveScan( + T input, + T &exclusive_output, + T initial_value, + ScanOp scan_op, + T &warp_aggregate +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Binary scan operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's output item. May be aliased with `input` + + + +Initial value to seed the exclusive scan + + + +Binary scan operator + + + +Warp-wide aggregate reduction of input items + + +**Example** + +The code snippet below illustrates four concurrent warp-wide exclusive prefix max scans within a block of 128 threads (one per each of the 32-thread warps). + +`{0, -1, 2, -3, ..., 126, -127}`. The corresponding output `thread_data` in the first warp would be `INT_MIN, 0, 0, 2, ..., 28, 30`, the output for the second warp would be `INT_MIN, 32, 32, 34, ..., 60, 62`, etc. Furthermore, `warp_aggregate` would be assigned `30` for threads in the first warp, `62` for threads in the second warp, etc. + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpScan for type int + using WarpScan = cub::WarpScan; + + // Allocate WarpScan shared memory for 4 warps + __shared__ typename WarpScan::TempStorage temp_storage[4]; + + // Obtain one input item per thread + int thread_data = ... + + // Compute exclusive warp-wide prefix max scans + int warp_aggregate; + int warp_id = threadIdx.x / 32; + WarpScan(temp_storage[warp_id]).ExclusiveScan(thread_data, + thread_data, + INT_MIN, + cuda::maximum<>{}, + warp_aggregate); +``` + + + + +--- + +## Combination (inclusive & exclusive) prefix scans + +### Scan inline + + + + +Computes both inclusive and exclusive prefix scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the `exclusive_output` computed for *lane*0 is undefined. + + +```cpp showLineNumbers={false} +template +void cub::WarpScan::Scan( + T input, + T &inclusive_output, + T &exclusive_output, + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Binary scan operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's inclusive-scan output item + + + +Calling thread's exclusive-scan output item + + + +Binary scan operator + + +**Example** + +The code snippet below illustrates four concurrent warp-wide exclusive prefix max scans within a block of 128 threads (one per each of the 32-thread warps). + +`{0, -1, 2, -3, ..., 126, -127}`. The corresponding output `inclusive_partial` in the first warp would be `0, 0, 2, 2, ..., 30, 30`, the output for the second warp would be `32, 32, 34, 34, ..., 62, 62`, etc. The corresponding output `exclusive_partial` in the first warp would be `?, 0, 0, 2, ..., 28, 30`, the output for the second warp would be `?, 32, 32, 34, ..., 60, 62`, etc. (The output `thread_data` in warp *lane*0 is undefined.) + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpScan for type int + using WarpScan = cub::WarpScan; + + // Allocate WarpScan shared memory for 4 warps + __shared__ typename WarpScan::TempStorage temp_storage[4]; + + // Obtain one input item per thread + int thread_data = ... + + // Compute exclusive warp-wide prefix max scans + int inclusive_partial, exclusive_partial; + WarpScan(temp_storage[warp_id]).Scan(thread_data, + inclusive_partial, + exclusive_partial, + cuda::maximum<>{}); +``` + + + + +Computes both inclusive and exclusive prefix scans using the specified binary scan functor across the calling warp. + + +```cpp showLineNumbers={false} +template +void cub::WarpScan::Scan( + T input, + T &inclusive_output, + T &exclusive_output, + T initial_value, + ScanOp scan_op +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Template parameters** + + +**[inferred]** Binary scan operator type having member `T operator()(const T &a, const T &b)` + + +**Parameters** + + +Calling thread's input item + + + +Calling thread's inclusive-scan output item + + + +Calling thread's exclusive-scan output item + + + +Initial value to seed the exclusive scan + + + +Binary scan operator + + +**Example** + +The code snippet below illustrates four concurrent warp-wide prefix max scans within a block of 128 threads (one per each of the 32-thread warps). + +`{0, -1, 2, -3, ..., 126, -127}`. The corresponding output `inclusive_partial` in the first warp would be `0, 0, 2, 2, ..., 30, 30`, the output for the second warp would be `32, 32, 34, 34, ..., 62, 62`, etc. The corresponding output `exclusive_partial` in the first warp would be `INT_MIN, 0, 0, 2, ..., 28, 30`, the output for the second warp would be `INT_MIN, 32, 32, 34, ..., 60, 62`, etc. + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpScan for type int + using WarpScan = cub::WarpScan; + + // Allocate WarpScan shared memory for 4 warps + __shared__ typename WarpScan::TempStorage temp_storage[4]; + + // Obtain one input item per thread + int thread_data = ... + + // Compute inclusive warp-wide prefix max scans + int warp_id = threadIdx.x / 32; + int inclusive_partial, exclusive_partial; + WarpScan(temp_storage[warp_id]).Scan(thread_data, + inclusive_partial, + exclusive_partial, + INT_MIN, + cuda::maximum<>{}); +``` + + + + +--- + +## Data exchange + +### Broadcast inline + +Broadcast the value `input` from *lane*src_lane to all lanes in the warp + + +```cpp showLineNumbers={false} +T cub::WarpScan::Broadcast( + T input, + unsigned int src_lane +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +* The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +The value to broadcast + + + +Which warp lane is to do the broadcasting + + +**Example** + +The code snippet below illustrates the warp-wide broadcasts of values from *lane*0 in each of four warps to all other threads in those warps. + +`{0, 1, 2, 3, ..., 127}`. The corresponding output `thread_data` will be `{0, 0, ..., 0}` in warp0, `{32, 32, ..., 32}` in warp1, `{64, 64, ..., 64}` in warp2, etc. + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpScan for type int + using WarpScan = cub::WarpScan; + + // Allocate WarpScan shared memory for 4 warps + __shared__ typename WarpScan::TempStorage temp_storage[4]; + + // Obtain one input item per thread + int thread_data = ... + + // Broadcast from lane0 in each warp to all other threads in the warp + int warp_id = threadIdx.x / 32; + thread_data = WarpScan(temp_storage[warp_id]).Broadcast(thread_data, 0); +``` + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `InternalWarpScan` | `::cuda::std:: _If< IS_POW_OF_TWO, detail::WarpScanShfl< T, LOGICAL_WARP_THREADS >, detail::WarpScanSmem< T, LOGICAL_WARP_THREADS > >` | Internal specialization. | +| `_TempStorage` | `typename InternalWarpScan::TempStorage` | Shared memory storage layout type for `WarpScan`. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `IS_ARCH_WARP` static constexpr | `bool` | Whether the logical warp size and the PTX warp size coincide. | +| `IS_POW_OF_TWO` static constexpr | `bool` | Whether the logical warp size is a power-of-two. | +| `IS_INTEGER` static constexpr | `bool` | Whether the data type is an integer (which has fully-associative addition). | +| `temp_storage` | `_TempStorage &` | Shared storage reference. | +| `lane_id` | `unsigned int` | | + +--- + +## Inner classes + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::WarpScan::TempStorage +``` + + +The operations exposed by `WarpScan` require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/cudapages/cub/cub/cub/WarpStore.mdx b/fern/cudapages/cub/cub/cub/WarpStore.mdx new file mode 100644 index 0000000..0fc6fcf --- /dev/null +++ b/fern/cudapages/cub/cub/cub/WarpStore.mdx @@ -0,0 +1,353 @@ +--- +title: cub::WarpStore +description: "" +--- + +The WarpStore class provides collective data movement methods for writing a blocked arrangement of items partitioned across a CUDA warp to a linear segment of memory. + +## Example + +The code snippet below illustrates the storing of a "blocked" arrangement of 64 integers across 16 threads (where each thread owns 4 consecutive items) into a linear segment of memory. The store is specialized for `WARP_STORE_TRANSPOSE`, meaning items are locally reordered among threads so that memory references will be efficiently coalesced using a warp-striped access pattern. + +`{ [0,1,2,3], [4,5,6,7], ..., [60,61,62,63] }`. The output `d_data` will be `0, 1, 2, 3, 4, 5, ...`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int *d_data, ...) +{ + constexpr int warp_threads = 16; + constexpr int block_threads = 256; + constexpr int items_per_thread = 4; + + // Specialize WarpStore for a virtual warp of 16 threads owning 4 integer items each + using WarpStoreT = WarpStore; + + constexpr int warps_in_block = block_threads / warp_threads; + constexpr int tile_size = items_per_thread * warp_threads; + const int warp_id = static_cast(threadIdx.x) / warp_threads; + + // Allocate shared memory for WarpStore + __shared__ typename WarpStoreT::TempStorage temp_storage[warps_in_block]; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Store items to linear memory + WarpStoreT(temp_storage[warp_id]).Store(d_data + warp_id * tile_size, thread_data); +} +``` + + + + + +The type of data to be written. + + + +The number of consecutive items partitioned onto each thread. + + + +**[optional]** [cub::WarpStoreAlgorithm](/library/api/cub::WarpStoreAlgorithm) tuning policy enumeration. default: [cub::WARP_STORE_DIRECT](/library/api/cub::WARP_STORE_DIRECT). + + + +**[optional]** The number of threads per "logical" warp (may be less than the number of hardware warp threads). Default is the warp size of the targeted CUDA compute-capability (e.g., 32 threads for SM86). Must be a power of two. + + + + + +--- + +## Collective constructors + +### WarpStore inline + + + + +Collective constructor using a private static allocation of shared memory as temporary storage. + + +```cpp showLineNumbers={false} +cub::WarpStore::WarpStore() +``` + + + + + +Collective constructor using the specified memory allocation as temporary storage. + + +```cpp showLineNumbers={false} +cub::WarpStore::WarpStore( + TempStorage &temp_storage +) +``` + + + + + +--- + +## Data movement + +### Store inline + + + + +Store items into a linear segment of memory. + + +```cpp showLineNumbers={false} +template +void cub::WarpStore::Store( + OutputIteratorT block_itr, + T (&items)[ITEMS_PER_THREAD] +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +The thread block's base output iterator for storing to + + + +Data to store + + +**Example** + +The code snippet below illustrates the storing of a "blocked" arrangement of 64 integers across 16 threads (where each thread owns 4 consecutive items) into a linear segment of memory. The store is specialized for `WARP_STORE_TRANSPOSE`, meaning items are locally reordered among threads so that memory references will be efficiently coalesced using a warp-striped access pattern. + +`{ [0,1,2,3], [4,5,6,7], ..., [60,61,62,63] }`. The output `d_data` will be `0, 1, 2, 3, 4, 5, ...`. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int *d_data, ...) +{ + constexpr int warp_threads = 16; + constexpr int block_threads = 256; + constexpr int items_per_thread = 4; + + // Specialize WarpStore for a virtual warp of 16 threads owning 4 integer items each + using WarpStoreT = WarpStore; + + constexpr int warps_in_block = block_threads / warp_threads; + constexpr int tile_size = items_per_thread * warp_threads; + const int warp_id = static_cast(threadIdx.x) / warp_threads; + + // Allocate shared memory for WarpStore + __shared__ typename WarpStoreT::TempStorage temp_storage[warps_in_block]; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Store items to linear memory + WarpStoreT(temp_storage[warp_id]).Store(d_data + warp_id * tile_size, thread_data); +``` + + + + +Store items into a linear segment of memory, guarded by range. + + +```cpp showLineNumbers={false} +template +void cub::WarpStore::Store( + OutputIteratorT block_itr, + T (&items)[ITEMS_PER_THREAD], + int valid_items +) +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + +**Parameters** + + +The thread block's base output iterator for storing to + + + +Data to store + + + +Number of valid items to write + + +**Example** + +The code snippet below illustrates the storing of a "blocked" arrangement of 64 integers across 16 threads (where each thread owns 4 consecutive items) into a linear segment of memory. The store is specialized for `WARP_STORE_TRANSPOSE`, meaning items are locally reordered among threads so that memory references will be efficiently coalesced using a warp-striped access pattern. + +`{ [0,1,2,3], [4,5,6,7], ..., [60,61,62,63] }` and `valid_items` is `5`. The output `d_data` will be `0, 1, 2, 3, 4, ?, ?, ...`, with only the first two threads being unmasked to store portions of valid data. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(int *d_data, int valid_items ...) +{ + constexpr int warp_threads = 16; + constexpr int block_threads = 256; + constexpr int items_per_thread = 4; + + // Specialize WarpStore for a virtual warp of 16 threads owning 4 integer items each + using WarpStoreT = WarpStore; + + constexpr int warps_in_block = block_threads / warp_threads; + constexpr int tile_size = items_per_thread * warp_threads; + const int warp_id = static_cast(threadIdx.x) / warp_threads; + + // Allocate shared memory for WarpStore + __shared__ typename WarpStoreT::TempStorage temp_storage[warps_in_block]; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Store items to linear memory + WarpStoreT(temp_storage[warp_id]).Store( + d_data + warp_id * tile_size, thread_data, valid_items); +``` + + + + +--- + +## Utility methods + +### PrivateStorage inline + + +```cpp showLineNumbers={false} +_TempStorage & cub::WarpStore::PrivateStorage() +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `InternalStore` | `StoreInternal< ALGORITHM, 0 >` | Internal load implementation to use. | +| `_TempStorage` | `typename InternalStore::TempStorage` | Shared memory storage layout type. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `IS_ARCH_WARP` static constexpr | `bool` | | +| `temp_storage` | `_TempStorage &` | | +| `linear_tid` | `int` | | + +--- + +## Inner classes + +### StoreInternal + + +```cpp showLineNumbers={false} +struct cub::WarpStore::StoreInternal +``` + + +Store helper. + +### StoreInternal< WARP_STORE_DIRECT, DUMMY > + + +```cpp showLineNumbers={false} +struct cub::WarpStore::StoreInternal< WARP_STORE_DIRECT, DUMMY > +``` + + +| Name | Type | Description | +|---|---|---| +| `linear_tid` | `int` | | + +### StoreInternal< WARP_STORE_STRIPED, DUMMY > + + +```cpp showLineNumbers={false} +struct cub::WarpStore::StoreInternal< WARP_STORE_STRIPED, DUMMY > +``` + + +| Name | Type | Description | +|---|---|---| +| `linear_tid` | `int` | | + +### StoreInternal< WARP_STORE_VECTORIZE, DUMMY > + + +```cpp showLineNumbers={false} +struct cub::WarpStore::StoreInternal< WARP_STORE_VECTORIZE, DUMMY > +``` + + +| Name | Type | Description | +|---|---|---| +| `linear_tid` | `int` | | + +### StoreInternal< WARP_STORE_TRANSPOSE, DUMMY > + + +```cpp showLineNumbers={false} +struct cub::WarpStore::StoreInternal< WARP_STORE_TRANSPOSE, DUMMY > +``` + + +| Name | Type | Description | +|---|---|---| +| `temp_storage` | `_TempStorage &` | | +| `linear_tid` | `int` | | + +### TempStorage + + +```cpp showLineNumbers={false} +struct cub::WarpStore::TempStorage +``` + + +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/cudapages/cuda/cuda/cuda/arch_traits_t.mdx b/fern/cudapages/cuda/cuda/cuda/arch_traits_t.mdx new file mode 100644 index 0000000..d1af5f3 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/arch_traits_t.mdx @@ -0,0 +1,52 @@ +--- +title: "cuda::arch_traits_t" +description: "Architecture traits This type contains information about an architecture that is constant across devices of that architecture." +--- + +Architecture traits This type contains information about an architecture that is constant across devices of that architecture. + +```cpp showLineNumbers={false} +#include +``` + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `max_threads_per_block` | `int` | | +| `max_block_dim_x` | `int` | | +| `max_block_dim_y` | `int` | | +| `max_block_dim_z` | `int` | | +| `max_grid_dim_x` | `int` | | +| `max_grid_dim_y` | `int` | | +| `max_grid_dim_z` | `int` | | +| `max_shared_memory_per_block` | `::cuda::std::size_t` | | +| `total_constant_memory` | `::cuda::std::size_t` | | +| `warp_size` | `int` | | +| `max_resident_grids` | `int` | | +| `gpu_overlap` | `bool` | | +| `can_map_host_memory` | `bool` | | +| `concurrent_kernels` | `bool` | | +| `stream_priorities_supported` | `bool` | | +| `global_l1_cache_supported` | `bool` | | +| `local_l1_cache_supported` | `bool` | | +| `max_registers_per_block` | `int` | | +| `max_registers_per_multiprocessor` | `int` | | +| `max_registers_per_thread` | `int` | | +| `arch_id` | `::cuda::arch_id` | | +| `compute_capability_major` | `int` | | +| `compute_capability_minor` | `int` | | +| `compute_capability` | `::cuda::compute_capability` | | +| `max_shared_memory_per_multiprocessor` | `::cuda::std::size_t` | | +| `max_blocks_per_multiprocessor` | `int` | | +| `max_threads_per_multiprocessor` | `int` | | +| `max_warps_per_multiprocessor` | `int` | | +| `reserved_shared_memory_per_block` | `::cuda::std::size_t` | | +| `max_shared_memory_per_block_optin` | `::cuda::std::size_t` | | +| `cluster_supported` | `bool` | | +| `redux_intrinisic` | `bool` | | +| `elect_intrinsic` | `bool` | | +| `cp_async_supported` | `bool` | | +| `tma_supported` | `bool` | | diff --git a/fern/pages/libcudacxx/empty_docstring_class.mdx b/fern/cudapages/cuda/cuda/cuda/buffer.mdx similarity index 69% rename from fern/pages/libcudacxx/empty_docstring_class.mdx rename to fern/cudapages/cuda/cuda/cuda/buffer.mdx index c29fa6b..a967d86 100644 --- a/fern/pages/libcudacxx/empty_docstring_class.mdx +++ b/fern/cudapages/cuda/cuda/cuda/buffer.mdx @@ -1,8 +1,12 @@ --- title: "cuda::buffer" -description: "A memory-safe buffer for managing typed, property-annotated device memory with stream-ordered allocation." +description: "" --- +`buffer` is a container that provides resizable typed storage allocated from a given memory resource. It handles alignment, release and growth of the allocation. The elements are initialized during construction, which may require a kernel launch. + +In addition to being type-safe, `buffer` also takes a set of properties to ensure that e.g. execution space constraints are checked at compile time. However, only stateless properties can be forwarded. To use a stateful property, implement get_property(const buffer&, Property). + ```cpp showLineNumbers={false} #include ``` @@ -11,11 +15,11 @@ description: "A memory-safe buffer for managing typed, property-annotated device -The type to be stored in the buffer. +The type to be stored in the buffer -The properties the allocated memory satisfies. +The properties the allocated memory satisfies @@ -25,12 +29,12 @@ The properties the allocated memory satisfies. ## Constructors -### Copy and move constructors +### buffer inline -inline explicit +explicit Copy-constructs from a buffer. @@ -51,7 +55,7 @@ The other buffer. -inline noexcept +noexcept Move-constructs from a buffer. @@ -70,9 +74,9 @@ The other buffer. After move construction, the other buffer can only be assigned - + -inline explicit +explicit Copy-constructs from a buffer with matching properties. @@ -92,9 +96,9 @@ The other buffer. - + -inline noexcept +noexcept Move-constructs from a buffer with matching properties. @@ -114,15 +118,8 @@ The other buffer. After move construction, the other buffer can only be assigned - - -### Resource constructors - - -inline - Constructs an empty buffer using an environment. @@ -136,20 +133,20 @@ cuda::buffer<_Tp, _Properties>::buffer( ``` - + No memory is allocated. - + **Parameters** - -The environment providing the needed information. + +The environment providing the needed information -inline explicit +explicit Constructs a buffer of size `__size` using a memory and leaves all elements uninitialized. @@ -166,9 +163,9 @@ cuda::buffer<_Tp, _Properties>::buffer( ``` - -This constructor does **NOT** initialize any elements. It is the user's responsibility to ensure that the elements within `[vec.begin(), vec.end())` are properly initialized. - + +This constructor does *NOT* initialize any elements. It is the user's responsibility to ensure that the elements within `[vec.begin(), vec.end())` are properly initialized, e.g with `cuda::std::uninitialized_copy`. At the destruction of the `buffer` all elements in the range `[vec.begin(), vec.end())` will be destroyed. + **Parameters** @@ -176,15 +173,13 @@ This constructor does **NOT** initialize any elements. It is the user's responsi The size of the buffer. - + The environment used to query the memory resource. -inline - Constructs a buffer using a memory resource and copy-constructs all elements from the forward range `[__first, __last)`. @@ -200,9 +195,9 @@ cuda::buffer<_Tp, _Properties>::buffer( ``` - -If `__first == __last` then no memory is allocated. - + +If `__first == __last` then no memory is allocated + **Parameters** @@ -214,15 +209,13 @@ The start of the input sequence. The end of the input sequence. - + The environment used to query the memory resource. -inline - Constructs a buffer using a memory resource and copy-constructs all elements from `__ilist`. @@ -237,9 +230,9 @@ cuda::buffer<_Tp, _Properties>::buffer( ``` - -If `__ilist.size() == 0` then no memory is allocated. - + +If `__ilist.size() == 0` then no memory is allocated + **Parameters** @@ -247,15 +240,13 @@ If `__ilist.size() == 0` then no memory is allocated. The initializer_list being copied into the buffer. - + The environment used to query the memory resource. -inline - Constructs a buffer using a memory resource and an input range. @@ -270,9 +261,9 @@ cuda::buffer<_Tp, _Properties>::buffer( ``` - + If `__range.size() == 0` then no memory is allocated. - + **Parameters** @@ -280,7 +271,7 @@ If `__range.size() == 0` then no memory is allocated. The input range to be moved into the buffer. - + The environment used to query the memory resource. @@ -306,350 +297,310 @@ void cuda::buffer<_Tp, _Properties>::operator=( **Parameters** -The buffer to move from. +The other buffer. After move assignment, the other buffer can only be assigned to or destroyed. --- -## Element access +## Methods -### get_unsynchronized inline noexcept +### begin inline noexcept nodiscard -nodiscard +Returns an iterator to the first element of the buffer. -Returns a reference to the `__n`'th element of the async_vector. +If the buffer is empty, the returned iterator will be equal to [end()](/libcudacxx/api/cuda::buffer::end()). ```cpp showLineNumbers={false} -reference cuda::buffer<_Tp, _Properties>::get_unsynchronized( - const size_type __n -) noexcept +iterator cuda::buffer<_Tp, _Properties>::begin() noexcept ``` -**Returns:** `reference` - -**Parameters** - - -The index of the element. - - -const nodiscard +const + +Returns an immutable iterator to the first element of the buffer. -Returns a reference to the `__n`'th element of the async_vector. +If the buffer is empty, the returned iterator will be equal to [end()](/libcudacxx/api/cuda::buffer::end()). ```cpp showLineNumbers={false} -const_reference cuda::buffer<_Tp, _Properties>::get_unsynchronized( - const size_type __n -) const noexcept +const_iterator cuda::buffer<_Tp, _Properties>::begin() const noexcept ``` -**Returns:** `const_reference` - -**Parameters** - - -The index of the element. - - -### data inline noexcept - - - - -nodiscard - -Returns a pointer to the first element of the buffer. - - -```cpp showLineNumbers={false} -pointer cuda::buffer<_Tp, _Properties>::data() noexcept -``` - - -**Returns:** `pointer` +### cbegin inline const noexcept nodiscard - - - -const nodiscard +Returns an immutable iterator to the first element of the buffer. -Returns a pointer to the first element of the buffer. +If the buffer is empty, the returned iterator will be equal to [end()](/libcudacxx/api/cuda::buffer::end()). ```cpp showLineNumbers={false} -const_pointer cuda::buffer<_Tp, _Properties>::data() const noexcept +const_iterator cuda::buffer<_Tp, _Properties>::cbegin() const noexcept ``` -**Returns:** `const_pointer` - - - - ---- - -## Iterators - -### begin inline noexcept +### end inline noexcept nodiscard -nodiscard +Returns an iterator to the element following the last element of the buffer. -Returns an iterator to the first element of the buffer. +This element acts as a placeholder; attempting to access it results in undefined behavior. ```cpp showLineNumbers={false} -iterator cuda::buffer<_Tp, _Properties>::begin() noexcept +iterator cuda::buffer<_Tp, _Properties>::end() noexcept ``` -**Returns:** `iterator` - -const nodiscard +const -Returns an immutable iterator to the first element of the buffer. +Returns an immutable iterator to the element following the last element of the buffer. + +This element acts as a placeholder; attempting to access it results in undefined behavior. ```cpp showLineNumbers={false} -const_iterator cuda::buffer<_Tp, _Properties>::begin() const noexcept +const_iterator cuda::buffer<_Tp, _Properties>::end() const noexcept ``` -**Returns:** `const_iterator` - -### cbegin inline const noexcept +### cend inline const noexcept nodiscard -nodiscard +Returns an immutable iterator to the element following the last element of the buffer. -Returns an immutable iterator to the first element of the buffer. +This element acts as a placeholder; attempting to access it results in undefined behavior. ```cpp showLineNumbers={false} -const_iterator cuda::buffer<_Tp, _Properties>::cbegin() const noexcept +const_iterator cuda::buffer<_Tp, _Properties>::cend() const noexcept ``` -**Returns:** `const_iterator` - -### end inline noexcept +### rbegin inline noexcept nodiscard -nodiscard +Returns a reverse iterator to the first element of the reversed buffer. -Returns an iterator to the element following the last element of the buffer. +It corresponds to the last element of the non-reversed buffer. If the buffer is empty, the returned iterator is equal to [rend()](/libcudacxx/api/cuda::buffer::rend()). ```cpp showLineNumbers={false} -iterator cuda::buffer<_Tp, _Properties>::end() noexcept +reverse_iterator cuda::buffer<_Tp, _Properties>::rbegin() noexcept ``` -**Returns:** `iterator` - -const nodiscard +const -Returns an immutable iterator to the element following the last element of the buffer. +Returns an immutable reverse iterator to the first element of the reversed buffer. + +It corresponds to the last element of the non-reversed buffer. If the buffer is empty, the returned iterator is equal to [rend()](/libcudacxx/api/cuda::buffer::rend()). ```cpp showLineNumbers={false} -const_iterator cuda::buffer<_Tp, _Properties>::end() const noexcept +const_reverse_iterator cuda::buffer<_Tp, _Properties>::rbegin() const noexcept ``` -**Returns:** `const_iterator` - -### cend inline const noexcept +### crbegin inline const noexcept nodiscard -nodiscard +Returns an immutable reverse iterator to the first element of the reversed buffer. -Returns an immutable iterator to the element following the last element of the buffer. +It corresponds to the last element of the non-reversed buffer. If the buffer is empty, the returned iterator is equal to [rend()](/libcudacxx/api/cuda::buffer::rend()). ```cpp showLineNumbers={false} -const_iterator cuda::buffer<_Tp, _Properties>::cend() const noexcept +const_reverse_iterator cuda::buffer<_Tp, _Properties>::crbegin() const noexcept ``` -**Returns:** `const_iterator` - -### rbegin inline noexcept +### rend inline noexcept nodiscard -nodiscard +Returns a reverse iterator to the element following the last element of the reversed buffer. -Returns a reverse iterator to the first element of the reversed buffer. +It corresponds to the element preceding the first element of the non-reversed buffer. This element acts as a placeholder, attempting to access it results in undefined behavior. ```cpp showLineNumbers={false} -reverse_iterator cuda::buffer<_Tp, _Properties>::rbegin() noexcept +reverse_iterator cuda::buffer<_Tp, _Properties>::rend() noexcept ``` -**Returns:** `reverse_iterator` - -const nodiscard +const -Returns an immutable reverse iterator to the first element of the reversed buffer. +Returns an immutable reverse iterator to the element following the last element of the reversed buffer. + +It corresponds to the element preceding the first element of the non-reversed buffer. This element acts as a placeholder, attempting to access it results in undefined behavior. ```cpp showLineNumbers={false} -const_reverse_iterator cuda::buffer<_Tp, _Properties>::rbegin() const noexcept +const_reverse_iterator cuda::buffer<_Tp, _Properties>::rend() const noexcept ``` -**Returns:** `const_reverse_iterator` - -### crbegin inline const noexcept +### crend inline const noexcept nodiscard -nodiscard +Returns an immutable reverse iterator to the element following the last element of the reversed buffer. -Returns an immutable reverse iterator to the first element of the reversed buffer. +It corresponds to the element preceding the first element of the non-reversed buffer. This element acts as a placeholder, attempting to access it results in undefined behavior. ```cpp showLineNumbers={false} -const_reverse_iterator cuda::buffer<_Tp, _Properties>::crbegin() const noexcept +const_reverse_iterator cuda::buffer<_Tp, _Properties>::crend() const noexcept ``` -**Returns:** `const_reverse_iterator` - -### rend inline noexcept +### data inline noexcept nodiscard -nodiscard +Returns a pointer to the first element of the buffer. -Returns a reverse iterator to the element following the last element of the reversed buffer. +If the buffer has not allocated memory the pointer will be null. ```cpp showLineNumbers={false} -reverse_iterator cuda::buffer<_Tp, _Properties>::rend() noexcept +pointer cuda::buffer<_Tp, _Properties>::data() noexcept ``` -**Returns:** `reverse_iterator` - -const nodiscard +const -Returns an immutable reverse iterator to the element following the last element of the reversed buffer. +Returns a pointer to the first element of the buffer. + +If the buffer has not allocated memory the pointer will be null. ```cpp showLineNumbers={false} -const_reverse_iterator cuda::buffer<_Tp, _Properties>::rend() const noexcept +const_pointer cuda::buffer<_Tp, _Properties>::data() const noexcept ``` -**Returns:** `const_reverse_iterator` - -### crend inline const noexcept +### get_unsynchronized inline noexcept nodiscard -nodiscard + + -Returns an immutable reverse iterator to the element following the last element of the reversed buffer. +Returns a reference to the `__n` 'th element of the async_vector. ```cpp showLineNumbers={false} -const_reverse_iterator cuda::buffer<_Tp, _Properties>::crend() const noexcept +reference cuda::buffer<_Tp, _Properties>::get_unsynchronized( + const size_type __n +) noexcept ``` -**Returns:** `const_reverse_iterator` + +Does not synchronize with the stored stream + ---- +**Parameters** -## Capacity + +The index of the element we want to access + -### size inline const noexcept + + -nodiscard +const -Returns the current number of elements stored in the buffer. +Returns a reference to the `__n` 'th element of the async_vector. ```cpp showLineNumbers={false} -size_type cuda::buffer<_Tp, _Properties>::size() const noexcept +const_reference cuda::buffer<_Tp, _Properties>::get_unsynchronized( + const size_type __n +) const noexcept ``` -**Returns:** `size_type` + +Does not synchronize with the stored stream + -### empty inline const noexcept +**Parameters** -nodiscard + +The index of the element we want to access + -Returns true if the buffer is empty. + + + +### size inline const noexcept nodiscard + +Returns the current number of elements stored in the buffer. ```cpp showLineNumbers={false} -bool cuda::buffer<_Tp, _Properties>::empty() const noexcept +size_type cuda::buffer<_Tp, _Properties>::size() const noexcept ``` -**Returns:** `bool` +### empty inline const noexcept nodiscard ---- +Returns true if the buffer is empty. -## Resource and stream management + +```cpp showLineNumbers={false} +bool cuda::buffer<_Tp, _Properties>::empty() const noexcept +``` + -### memory_resource inline const noexcept +### memory_resource inline const noexcept nodiscard -nodiscard +Returns a \c const reference to the any_resource that holds the memory resource used to allocate the buffer ```cpp showLineNumbers={false} -const __resource_t& cuda::buffer<_Tp, _Properties>::memory_resource() const noexcept +const __resource_t & cuda::buffer<_Tp, _Properties>::memory_resource() const noexcept ``` -**Returns:** `const __resource_t &` - -### stream inline const constexpr noexcept - -nodiscard +### stream inline constexpr const noexcept nodiscard Returns the stored stream. @@ -659,11 +610,9 @@ stream_ref cuda::buffer<_Tp, _Properties>::stream() const noexcept ``` -**Returns:** [`stream_ref`](/libcudacxx/api/cuda::stream_ref) - - -Stream used to allocate the buffer is initially stored in the buffer, but can be changed with [`set_stream`](/libcudacxx/api/cuda::buffer::set_stream). - + +Stream used to allocate the buffer is initially stored in the buffer, but can be changed with [`set_stream`](/libcudacxx/api/cuda::buffer::set_stream) + ### set_stream inline constexpr @@ -677,20 +626,16 @@ void cuda::buffer<_Tp, _Properties>::set_stream( ``` - -Always synchronizes with the old stream. - + +Always synchronizes with the old stream + **Parameters** -The new stream. +The new stream ---- - -## Modifiers - ### swap inline noexcept Swaps the contents of a buffer with those of `__other`. @@ -706,13 +651,13 @@ void cuda::buffer<_Tp, _Properties>::swap( **Parameters** -The buffer to swap with. +The other buffer. ### destroy inline - + Destroys the buffer, deallocates the buffer and destroys the memory resource. @@ -724,9 +669,9 @@ void cuda::buffer<_Tp, _Properties>::destroy( ``` - + After this explicit destroy call, the buffer can only be assigned to or destroyed. - + **Parameters** @@ -735,7 +680,7 @@ The stream to deallocate the buffer on. - + Destroys the buffer, deallocates the buffer and destroys the memory resource. @@ -745,87 +690,19 @@ void cuda::buffer<_Tp, _Properties>::destroy() ``` - -Uses the stored stream to deallocate the buffer. - + +Uses the stored stream to deallocate the buffer, equivalent to calling [buffer.destroy](/libcudacxx/api/cuda::buffer::buffer.destroy)([buffer.stream()](/libcudacxx/api/cuda::buffer::buffer.stream())) + - + After this explicit destroy call, the buffer can only be assigned to or destroyed. - + --- -## Friend functions - -### swap noexcept - - -```cpp showLineNumbers={false} -void swap( - buffer &__lhs, - buffer &__rhs -) noexcept -``` - - -**Parameters** - - -The first buffer. - - - -The second buffer. - - -### transform_launch_argument noexcept - - - - - -```cpp showLineNumbers={false} -template -::cuda::std::span<_Tp> transform_launch_argument( - ::cuda::stream_ref, - buffer &__self -) noexcept -``` - - - - - - -```cpp showLineNumbers={false} -template -::cuda::std::span transform_launch_argument( - ::cuda::stream_ref, - const buffer &__self -) noexcept -``` - - - - - -### get_property noexcept - - -```cpp showLineNumbers={false} -template -void get_property( - const buffer &, - _Property -) noexcept -``` - - ---- - ## Types ### Typedefs @@ -837,10 +714,10 @@ void get_property( | `const_reference` | `const _Tp &` | | `pointer` | `_Tp *` | | `const_pointer` | `const _Tp *` | -| `iterator` | `::cuda::heterogeneous_iterator<_Tp, _Properties...>` | -| `const_iterator` | `::cuda::heterogeneous_iterator` | -| `reverse_iterator` | `::cuda::std::reverse_iterator` | -| `const_reverse_iterator` | `::cuda::std::reverse_iterator` | +| `iterator` | `::cuda::heterogeneous_iterator< _Tp, _Properties... >` | +| `const_iterator` | `::cuda::heterogeneous_iterator< const _Tp, _Properties... >` | +| `reverse_iterator` | `::cuda::std::reverse_iterator< iterator >` | +| `const_reverse_iterator` | `::cuda::std::reverse_iterator< const_iterator >` | | `size_type` | `::cuda::std::size_t` | | `difference_type` | `::cuda::std::ptrdiff_t` | -| `properties_list` | `::cuda::mr::properties_list<_Properties...>` | +| `properties_list` | `::cuda::mr::properties_list< _Properties... >` | diff --git a/fern/cudapages/cuda/cuda/cuda/compute_capability.mdx b/fern/cudapages/cuda/cuda/cuda/compute_capability.mdx new file mode 100644 index 0000000..0d5a30f --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/compute_capability.mdx @@ -0,0 +1,207 @@ +--- +title: "cuda::compute_capability" +description: "Type representing the CUDA compute capability." +--- + +Type representing the CUDA compute capability. + +```cpp showLineNumbers={false} +#include +``` + +--- + +## Constructors + +### compute_capability constexpr noexcept + + + + + +```cpp showLineNumbers={false} +cuda::compute_capability::compute_capability() noexcept = default +``` + + + + + +inline explicit + +Constructs the object from compute capability `__cc`. + +The expected format is 10 * major + minor. + + +```cpp showLineNumbers={false} +cuda::compute_capability::compute_capability( + int __cc +) noexcept +``` + + +**Parameters** + + +Compute capability. + + + + + +inline + +Constructs the object by combining the `__major` and `__minor` compute capability. + + +```cpp showLineNumbers={false} +cuda::compute_capability::compute_capability( + int __major, + int __minor +) noexcept +``` + + +**Parameters** + + +The major compute capability. + + + +The minor compute capability. Must be less than 10. + + + + + +inline explicit + +Constructs the object from the architecture id. + + +```cpp showLineNumbers={false} +cuda::compute_capability::compute_capability( + arch_id __arch_id +) noexcept +``` + + +**Parameters** + + +The architecture id. + + + + + + +```cpp showLineNumbers={false} +cuda::compute_capability::compute_capability( + const compute_capability & +) noexcept = default +``` + + + + + +--- + +## Assignment operators + +### operator= constexpr noexcept + + +```cpp showLineNumbers={false} +compute_capability & cuda::compute_capability::operator=( + const compute_capability &__other +) noexcept = default +``` + + +--- + +## Methods + +### get inline constexpr const noexcept nodiscard + +Gets the stored compute capability. + + +```cpp showLineNumbers={false} +int cuda::compute_capability::get() const noexcept +``` + + +**Returns:** The stored compute capability in format 10 * major + minor. + +### major inline constexpr const noexcept nodiscard + +Gets the major compute capability. + + +```cpp showLineNumbers={false} +int cuda::compute_capability::major() const noexcept +``` + + + +This symbol is deprecated because it collides with major(...) macro defined in <sys/sysmacros.h> and will be removed in next major release. Use cc.major_cap() instead. + + +**Returns:** Major compute capability. + +### major_cap inline constexpr const noexcept nodiscard + +Gets the major compute capability. + + +```cpp showLineNumbers={false} +int cuda::compute_capability::major_cap() const noexcept +``` + + +**Returns:** Major compute capability. + +### minor inline constexpr const noexcept nodiscard + +Gets the minor compute capability. + + +```cpp showLineNumbers={false} +int cuda::compute_capability::minor() const noexcept +``` + + + +This symbol is deprecated because it collides with minor(...) macro defined in <sys/sysmacros.h> and will be removed in next major release. Use cc.minor_cap() instead. + + +**Returns:** Minor compute capability. The value is always less than 10. + +### minor_cap inline constexpr const noexcept nodiscard + +Gets the minor compute capability. + + +```cpp showLineNumbers={false} +int cuda::compute_capability::minor_cap() const noexcept +``` + + +**Returns:** Minor compute capability. The value is always less than 10. + +### operator int inline constexpr explicit const noexcept + +Conversion operator to `int`. + + +```cpp showLineNumbers={false} +cuda::compute_capability::operator int() const noexcept +``` + + +**Returns:** The stored compute capability in format 10 * major + minor. diff --git a/fern/cudapages/cuda/cuda/cuda/constant_iterator.mdx b/fern/cudapages/cuda/cuda/cuda/constant_iterator.mdx new file mode 100644 index 0000000..608063b --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/constant_iterator.mdx @@ -0,0 +1,255 @@ +--- +title: "cuda::constant_iterator" +description: "The [`constant_iterator`](/libcudacxx/api/cuda::constant_iterator) class represents an iterator in an infinite sequence of repeated values." +--- + +The `constant_iterator` class represents an iterator in an infinite sequence of repeated values. + +This iterator is useful for creating a range filled with the same value without explicitly storing it in memory. Using `constant_iterator` saves both memory capacity and bandwidth. + +The following code snippet demonstrates how to create a `constant_iterator` whose [`value_type`](/libcudacxx/api/cuda::constant_iterator::value_type) is `int` and whose value is `10`. + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include + +cuda::constant_iterator iter(10); + +*iter; // returns 10 +iter[0]; // returns 10 +iter[1]; // returns 10 +iter[13]; // returns 10 + +// and so on... +``` + + + + + +The value type of the `constant_iterator`. + + + +The index type of the `constant_iterator`. It can optionally be specified, but must satisfy **integer-like** + + + + + +--- + +## Constructors + +### constant_iterator inline constexpr noexcept + + + + + +```cpp showLineNumbers={false} +template +cuda::constant_iterator<_Tp, _Index>::constant_iterator() noexcept(::cuda::std::is_nothrow_default_constructible_v< _Tp2 >) +``` + + + + + +Creates a `constant_iterator` from a value. + +The index is set to zero + + +```cpp showLineNumbers={false} +cuda::constant_iterator<_Tp, _Index>::constant_iterator( + _Tp __value +) noexcept(::cuda::std::is_nothrow_move_constructible_v< _Tp >) +``` + + +**Parameters** + + +The value to store in the `constant_iterator` + + + + + +explicit + +Creates `constant_iterator` from a value and an index. + + +```cpp showLineNumbers={false} +template +cuda::constant_iterator<_Tp, _Index>::constant_iterator( + _Tp __value, + _Index2 __index +) noexcept(::cuda::std::is_nothrow_move_constructible_v< _Tp >) +``` + + +**Parameters** + + +The value to store in the `constant_iterator` + + + +The index in the sequence represented by this `constant_iterator` + + + + + +--- + +## Methods + +### index inline constexpr const noexcept nodiscard + +Returns a the current index. + + +```cpp showLineNumbers={false} +difference_type cuda::constant_iterator<_Tp, _Index>::index() const noexcept +``` + + +### operator* inline constexpr const noexcept nodiscard + +Returns a const reference to the stored value. + + +```cpp showLineNumbers={false} +const _Tp & cuda::constant_iterator<_Tp, _Index>::operator*() const noexcept +``` + + +### operator[] inline constexpr const noexcept nodiscard + +Returns a const reference to the stored value. + + +```cpp showLineNumbers={false} +const _Tp & cuda::constant_iterator<_Tp, _Index>::operator[]( + difference_type +) const noexcept +``` + + +### operator++ inline constexpr noexcept + + + + +Increments the stored index. + + +```cpp showLineNumbers={false} +constant_iterator & cuda::constant_iterator<_Tp, _Index>::operator++() noexcept +``` + + + + + +Increments the stored index. + + +```cpp showLineNumbers={false} +constant_iterator cuda::constant_iterator<_Tp, _Index>::operator++( + int +) noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Tp >) +``` + + + + + +### operator-- inline constexpr noexcept + + + + +Decrements the stored index. + + +```cpp showLineNumbers={false} +constant_iterator & cuda::constant_iterator<_Tp, _Index>::operator--() noexcept +``` + + + + + +Decrements the stored index. + + +```cpp showLineNumbers={false} +constant_iterator cuda::constant_iterator<_Tp, _Index>::operator--( + int +) noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Tp >) +``` + + + + + +### operator+= inline constexpr noexcept + +Advances a `constant_iterator` by a given number of elements. + + +```cpp showLineNumbers={false} +constant_iterator & cuda::constant_iterator<_Tp, _Index>::operator+=( + difference_type __n +) noexcept +``` + + +**Parameters** + + +The amount of elements to advance + + +### operator-= inline constexpr noexcept + +Decrements a `constant_iterator` by a given number of elements. + + +```cpp showLineNumbers={false} +constant_iterator & cuda::constant_iterator<_Tp, _Index>::operator-=( + difference_type __n +) noexcept +``` + + +**Parameters** + + +The amount of elements to decrement + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `iterator_concept` | `::cuda::std::random_access_iterator_tag` | +| `iterator_category` | `::cuda::std::random_access_iterator_tag` | +| `value_type` | `_Tp` | +| `difference_type` | `::cuda::std::ptrdiff_t` | +| `reference` | `_Tp` | +| `pointer` | `void` | diff --git a/fern/cudapages/cuda/cuda/cuda/copy_configuration.mdx b/fern/cudapages/cuda/cuda/cuda/copy_configuration.mdx new file mode 100644 index 0000000..8939919 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/copy_configuration.mdx @@ -0,0 +1,20 @@ +--- +title: "cuda::copy_configuration" +description: "Configuration for copy_bytes." +--- + +Configuration for copy_bytes. + +```cpp showLineNumbers={false} +#include +``` + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `src_location_hint` | `memory_location` | Source memory location hint for copy_bytes, used only for managed memory. | +| `dst_location_hint` | `memory_location` | Destination memory location hint for copy_bytes, used only for managed memory. | +| `src_access_order` | `source_access_order` | Source access order for copy_bytes. | diff --git a/fern/cudapages/cuda/cuda/cuda/counting_iterator.mdx b/fern/cudapages/cuda/cuda/cuda/counting_iterator.mdx new file mode 100644 index 0000000..ad3924d --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/counting_iterator.mdx @@ -0,0 +1,232 @@ +--- +title: "cuda::counting_iterator" +description: "A [`counting_iterator`](/libcudacxx/api/cuda::counting_iterator) represents an iterator into a range of sequentially increasing values." +--- + +A `counting_iterator` represents an iterator into a range of sequentially increasing values. + +This iterator is useful for creating a range filled with a sequence without explicitly storing it in memory. Using `counting_iterator` saves memory capacity and bandwidth. + +The following code snippet demonstrates how to create a `counting_iterator` whose [`value_type`](/libcudacxx/api/cuda::counting_iterator::value_type) is `int` + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include +... +// create iterators +cuda::counting_iterator first(10); +cuda::counting_iterator last = first + 3; + +first[0] // returns 10 +first[1] // returns 11 +first[100] // returns 110 + +// sum of [first, last) +std::reduce(first, last); // returns 33 (i.e. 10 + 11 + 12) + +// initialize vector to [0,1,2,..] +cuda::counting_iterator iter(0); +std::vector vec(500); +std::copy(iter, iter + vec.size(), vec.begin()); +``` + + + + + +The value type of the `counting_iterator`. + + + + + +**Inherits from:** `__counting_iterator_category< _Start >` (public) + +--- + +## Constructors + +### counting_iterator inline constexpr noexcept + + + + + +```cpp showLineNumbers={false} +template +cuda::counting_iterator<_Start,,>::counting_iterator() noexcept(::cuda::std::is_nothrow_default_constructible_v< _Start2 >) +``` + + + + + +explicit + +Creates a `counting_iterator` from an initial value. + + +```cpp showLineNumbers={false} +cuda::counting_iterator<_Start,,>::counting_iterator( + _Start __value +) noexcept(::cuda::std::is_nothrow_move_constructible_v< _Start >) +``` + + +**Parameters** + + +The value to store in the `counting_iterator` + + + + + +--- + +## Methods + +### operator* inline constexpr const noexcept nodiscard + +Returns the value currently stored in the `counting_iterator`. + + +```cpp showLineNumbers={false} +_Start cuda::counting_iterator<_Start,,>::operator*() const noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Start >) +``` + + +### operator[] inline constexpr const noexcept nodiscard + +Returns the value currently stored in the `counting_iterator` advanced by a number of steps. + + +```cpp showLineNumbers={false} +template +_Start2 cuda::counting_iterator<_Start,,>::operator[]( + difference_type __n +) const noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Start2 > &&noexcept(::cuda::std::declval< const _Start2 & >()+__n)) +``` + + +**Parameters** + + +The amount of elements to advance + + +### operator++ inline constexpr noexcept + + + + +Increments the stored value. + + +```cpp showLineNumbers={false} +counting_iterator & cuda::counting_iterator<_Start,,>::operator++() noexcept(++::cuda::std::declval< _Start & >()) +``` + + + + + +Increments the stored value. + + +```cpp showLineNumbers={false} +auto cuda::counting_iterator<_Start,,>::operator++( + int +) noexcept(noexcept(++::cuda::std::declval< _Start & >()) &&::cuda::std::is_nothrow_copy_constructible_v< _Start >) +``` + + + + + +### operator-- inline constexpr noexcept + + + + +Decrements the stored value. + + +```cpp showLineNumbers={false} +template +counting_iterator & cuda::counting_iterator<_Start,,>::operator--() noexcept(--::cuda::std::declval< _Start2 & >()) +``` + + + + + +Decrements the stored value. + + +```cpp showLineNumbers={false} +template +counting_iterator cuda::counting_iterator<_Start,,>::operator--( + int +) noexcept(noexcept(--::cuda::std::declval< _Start2 & >()) &&::cuda::std::is_nothrow_copy_constructible_v< _Start >) +``` + + + + + +### operator+= inline constexpr noexcept + +Increments the stored value by a given number of elements. + + +```cpp showLineNumbers={false} +counting_iterator & cuda::counting_iterator<_Start,,>::operator+=( + difference_type __n +) noexcept(::cuda::std::__integer_like< _Start >) +``` + + +**Parameters** + + +The number of elements to increment + + +### operator-= inline constexpr noexcept + +Decrements the stored value by a given number of elements. + + +```cpp showLineNumbers={false} +template +counting_iterator & cuda::counting_iterator<_Start,,>::operator-=( + difference_type __n +) noexcept(::cuda::std::__integer_like< _Start2 >) +``` + + +**Parameters** + + +The amount of elements to decrement + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `iterator_concept` | `::cuda::std::conditional_t< __advanceable< _Start >, ::cuda::std::random_access_iterator_tag, ::cuda::std::conditional_t< __decrementable< _Start >, ::cuda::std::bidirectional_iterator_tag, ::cuda::std::conditional_t<::cuda::std::incrementable< _Start >, ::cuda::std::forward_iterator_tag, ::cuda::std::input_iterator_tag > > >` | +| `value_type` | `_Start` | +| `difference_type` | `_IotaDiffT< _Start >` | +| `reference` | `_Start` | +| `pointer` | `void` | diff --git a/fern/cudapages/cuda/cuda/cuda/device_attributes/compute_capability_t.mdx b/fern/cudapages/cuda/cuda/cuda/device_attributes/compute_capability_t.mdx new file mode 100644 index 0000000..00f387c --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/device_attributes/compute_capability_t.mdx @@ -0,0 +1,32 @@ +--- +title: "cuda::device_attributes::compute_capability_t" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + +--- + +## Methods + +### operator() inline const nodiscard + + +```cpp showLineNumbers={false} +type cuda::device_attributes::compute_capability_t::operator()( + device_ref __dev_id +) const +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `::cuda::compute_capability` | diff --git a/fern/cudapages/cuda/cuda/cuda/device_memory_pool.mdx b/fern/cudapages/cuda/cuda/cuda/device_memory_pool.mdx new file mode 100644 index 0000000..bdbb564 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/device_memory_pool.mdx @@ -0,0 +1,140 @@ +--- +title: "cuda::device_memory_pool" +description: "" +--- + +`device_memory_pool` allocates device memory using `cudaMallocFromPoolAsync / cudaFreeAsync +<https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY__POOLS.html>`__ +for allocation/deallocation. When constructed it creates an underlying \c cudaMemPool_t with the location type set to \c cudaMemLocationTypeDevice and owns it. + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `cuda::device_memory_pool_ref` (public) + +--- + +## Constructors + +### device_memory_pool inline + + + + +Constructs a `device_memory_pool` with the optionally specified initial pool size and release threshold. + +If the pool size grows beyond the release threshold, unused memory held by the pool will be released at the next synchronization event. + + +```cpp showLineNumbers={false} +cuda::device_memory_pool::device_memory_pool( + ::cuda::device_ref __device_id, + memory_pool_properties __properties = {} +) +``` + + +**Throws:** `cuda_error` if the CUDA version does not support `cudaMallocAsync`. + +**Parameters** + + +The device id of the device the stream pool is constructed on. + + + +Optional, additional properties of the pool to be created. + + + + + +noexcept + + +```cpp showLineNumbers={false} +cuda::device_memory_pool::device_memory_pool( + ::cudaMemPool_t __pool +) noexcept +``` + + + + + + +```cpp showLineNumbers={false} +cuda::device_memory_pool::device_memory_pool( + const device_memory_pool & +) = delete +``` + + + + + +### Destructor + +### ~device_memory_pool inline noexcept + + +```cpp showLineNumbers={false} +cuda::device_memory_pool::~device_memory_pool() noexcept +``` + + +--- + +## Assignment operators + +### operator= + + +```cpp showLineNumbers={false} +device_memory_pool & cuda::device_memory_pool::operator=( + const device_memory_pool & +) = delete +``` + + +--- + +## Methods + +### as_ref inline noexcept nodiscard + +Returns a [`device_memory_pool_ref`](/libcudacxx/api/cuda::device_memory_pool_ref) for this `device_memory_pool`. + +We return by reference to ensure that we can subsequently convert to a resource_ref + + +```cpp showLineNumbers={false} +device_memory_pool_ref & cuda::device_memory_pool::as_ref() noexcept +``` + + +--- + +## Static methods + +### from_native_handle inline static noexcept + + +```cpp showLineNumbers={false} +static device_memory_pool cuda::device_memory_pool::from_native_handle( + ::cudaMemPool_t __pool +) noexcept +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `reference_type` | `device_memory_pool_ref` | +| `default_queries` | `::cuda::mr::properties_list<::cuda::mr::device_accessible >` | diff --git a/fern/cudapages/cuda/cuda/cuda/device_memory_pool_ref.mdx b/fern/cudapages/cuda/cuda/cuda/device_memory_pool_ref.mdx new file mode 100644 index 0000000..b64f489 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/device_memory_pool_ref.mdx @@ -0,0 +1,68 @@ +--- +title: "cuda::device_memory_pool_ref" +description: "" +--- + +`device_memory_pool_ref` allocates device memory using `cudaMallocFromPoolAsync / cudaFreeAsync +<https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY__POOLS.html>`__ +for allocation/deallocation. A `device_memory_pool_ref` is a thin wrapper around a \c cudaMemPool_t with the location type set to \c cudaMemLocationTypeDevice. + +.. warning:: + + `device_memory_pool_ref` does not own the pool and it is the responsibility of the user to ensure that the lifetime of the pool exceeds the lifetime of the `device_memory_pool_ref`. + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `cuda::__memory_pool_base` (public) + +--- + +## Constructors + +### device_memory_pool_ref inline explicit noexcept + + + + +Constructs the `device_memory_pool_ref` from a `cudaMemPool_t`. + + +```cpp showLineNumbers={false} +cuda::device_memory_pool_ref::device_memory_pool_ref( + ::cudaMemPool_t __pool +) noexcept +``` + + +**Parameters** + + +The `cudaMemPool_t` used to allocate memory. + + + + + +The following overloads are deleted to prevent misuse: + + +```cpp showLineNumbers={false} +cuda::device_memory_pool_ref::device_memory_pool_ref(int) = delete; +cuda::device_memory_pool_ref::device_memory_pool_ref(::cuda::std::nullptr_t) = delete; +``` + + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `default_queries` | `::cuda::mr::properties_list<::cuda::mr::device_accessible >` | diff --git a/fern/cudapages/cuda/cuda/cuda/device_ref.mdx b/fern/cudapages/cuda/cuda/cuda/device_ref.mdx new file mode 100644 index 0000000..beefd1b --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/device_ref.mdx @@ -0,0 +1,155 @@ +--- +title: "cuda::device_ref" +description: "A non-owning representation of a CUDA device." +--- + +A non-owning representation of a CUDA device. + +```cpp showLineNumbers={false} +#include +``` + +--- + +## Constructors + +### device_ref inline constexpr noexcept + +Create a `device_ref` object from a native device ordinal. + + +```cpp showLineNumbers={false} +cuda::device_ref::device_ref( + int __id +) noexcept +``` + + +--- + +## Methods + +### get inline constexpr const noexcept nodiscard + +Retrieve the native ordinal of the `device_ref`. + + +```cpp showLineNumbers={false} +int cuda::device_ref::get() const noexcept +``` + + +**Returns:** int The native device ordinal held by the `device_ref` object + +### attribute inline const nodiscard + + + + +Retrieve the specified attribute for the device. + + +```cpp showLineNumbers={false} +template +auto cuda::device_ref::attribute( + _Attr __attr +) const +``` + + +**Throws:** `cuda_error` if the attribute query fails + +**Parameters** + + +The attribute to query. See `device::attrs` for the available attributes. + + +**See also:** +device::attrs + + + + + +This is an overloaded member function, provided for convenience. It differs from the above function only in what argument(s) it accepts. + + +```cpp showLineNumbers={false} +template <::cudaDeviceAttr _Attr> +auto cuda::device_ref::attribute() const +``` + + + + + +### operator memory_location inline const noexcept nodiscard + +Retrieve the memory location of this device. + + +```cpp showLineNumbers={false} +cuda::device_ref::operator memory_location() const noexcept +``` + + +**Returns:** The memory location of this device + +### init inline const + +Initializes the primary context of the device. + + +```cpp showLineNumbers={false} +void cuda::device_ref::init() const +``` + + +### name inline const nodiscard + +Retrieve the name of this device. + + +```cpp showLineNumbers={false} +cuda::std::string_view cuda::device_ref::name() const +``` + + +**Returns:** String view containing the name of this device. + +### has_peer_access_to inline const nodiscard + +Queries if its possible for this device to directly access specified device's memory. + +If this function returns true, device supplied to this call can be passed into enable_peer_access on memory resource or pool that manages memory on this device. It will make allocations from that pool accessible by this device. + + +```cpp showLineNumbers={false} +bool cuda::device_ref::has_peer_access_to( + device_ref __other_dev +) const +``` + + +**Returns:** true if its possible for this device to access the specified device's memory + +**Parameters** + + +Device to query the peer access + + +### peers inline const nodiscard + +Retrieve `device_ref`s that are peers of this device. + +The device on which this API is called is not included in the vector. + + +```cpp showLineNumbers={false} +cuda::std::span cuda::device_ref::peers() const +``` + + +**Throws:** `cuda_error` if any peer access query fails diff --git a/fern/cudapages/cuda/cuda/cuda/discard_iterator.mdx b/fern/cudapages/cuda/cuda/cuda/discard_iterator.mdx new file mode 100644 index 0000000..01d78da --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/discard_iterator.mdx @@ -0,0 +1,264 @@ +--- +title: "cuda::discard_iterator" +description: "[`discard_iterator`](/libcudacxx/api/cuda::discard_iterator) is an iterator which represents a special kind of pointer that ignores values written to it upon dereference." +--- + +`discard_iterator` is an iterator which represents a special kind of pointer that ignores values written to it upon dereference. + +This iterator is useful for ignoring the output of certain algorithms without wasting memory capacity or bandwidth. `discard_iterator` may also be used to count the size of an algorithm's output which may not be known a priori. + +The following code snippet demonstrates how to use `discard_iterator` to ignore one of the output ranges of reduce_by_key + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include +#include +#include + +int main() +{ + thrust::device_vector keys{1, 3, 3, 3, 2, 2, 1}; + thrust::device_vector values{9, 8, 7, 6, 5, 4, 3}; + + thrust::device_vector result(4); + + // we are only interested in the reduced values + // use discard_iterator to ignore the output keys + thrust::reduce_by_key(keys.begin(), keys.end(), + values.begin(), + cuda::discard_iterator{}, + result.begin()); + + // result is now [9, 21, 9, 3] + + return 0; +} +``` + +--- + +## Constructors + +### discard_iterator constexpr + + + + +Default constructs a `discard_iterator` at index zero. + + +```cpp showLineNumbers={false} +cuda::discard_iterator::discard_iterator() = default +``` + + + + + +inline noexcept + +Constructs a `discard_iterator` with a given index. + + +```cpp showLineNumbers={false} +template +cuda::discard_iterator::discard_iterator( + _Integer __index +) noexcept +``` + + +**Parameters** + + +The index used for the discard iterator + + + + + +--- + +## Methods + +### index inline constexpr const noexcept nodiscard + +Returns the stored index. + + +```cpp showLineNumbers={false} +difference_type cuda::discard_iterator::index() const noexcept +``` + + +### operator* inline constexpr const noexcept nodiscard + +Dereferences the `discard_iterator` returning a proxy that discards all values that are assigned to it. + + +```cpp showLineNumbers={false} +__discard_proxy cuda::discard_iterator::operator*() const noexcept +``` + + +### operator[] inline constexpr const noexcept nodiscard + +Subscipts the `discard_iterator` returning a proxy that discards all values that are assigned to it. + + +```cpp showLineNumbers={false} +__discard_proxy cuda::discard_iterator::operator[]( + difference_type +) const noexcept +``` + + +### operator++ inline constexpr noexcept + + + + +Increments the stored index. + + +```cpp showLineNumbers={false} +discard_iterator & cuda::discard_iterator::operator++() noexcept +``` + + + + + +Increments the stored index. + + +```cpp showLineNumbers={false} +discard_iterator cuda::discard_iterator::operator++( + int +) noexcept +``` + + + + + +### operator-- inline constexpr noexcept + + + + +Decrements the stored index. + + +```cpp showLineNumbers={false} +discard_iterator & cuda::discard_iterator::operator--() noexcept +``` + + + + + +Decrements the stored index. + + +```cpp showLineNumbers={false} +discard_iterator cuda::discard_iterator::operator--( + int +) noexcept +``` + + + + + +### operator+ inline constexpr const noexcept nodiscard + +Returns a copy of this `discard_iterator` advanced by a number of elements. + + +```cpp showLineNumbers={false} +discard_iterator cuda::discard_iterator::operator+( + difference_type __n +) const noexcept +``` + + +**Parameters** + + +The number of elements to advance + + +### operator+= inline constexpr noexcept + +Advances the index of this `discard_iterator` by a number of elements. + + +```cpp showLineNumbers={false} +discard_iterator & cuda::discard_iterator::operator+=( + difference_type __n +) noexcept +``` + + +**Parameters** + + +The number of elements to advance + + +### operator- inline constexpr const noexcept nodiscard + +Returns a copy of this `discard_iterator` decremented by a number of elements. + + +```cpp showLineNumbers={false} +discard_iterator cuda::discard_iterator::operator-( + difference_type __n +) const noexcept +``` + + +**Parameters** + + +The number of elements to decrement + + +### operator-= inline constexpr noexcept + +Decrements the index of the `discard_iterator` by a number of elements. + + +```cpp showLineNumbers={false} +discard_iterator & cuda::discard_iterator::operator-=( + difference_type __n +) noexcept +``` + + +**Parameters** + + +The number of elements to decrement + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `iterator_concept` | `::cuda::std::random_access_iterator_tag` | +| `iterator_category` | `::cuda::std::random_access_iterator_tag` | +| `difference_type` | `::cuda::std::ptrdiff_t` | +| `value_type` | `void` | +| `pointer` | `void` | +| `reference` | `void` | diff --git a/fern/cudapages/cuda/cuda/cuda/event.mdx b/fern/cudapages/cuda/cuda/cuda/event.mdx new file mode 100644 index 0000000..dd614f1 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/event.mdx @@ -0,0 +1,342 @@ +--- +title: "cuda::event" +description: "An owning wrapper for an untimed `cudaEvent_t`." +--- + +An owning wrapper for an untimed `cudaEvent_t`. + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `cuda::event_ref` (public) + +--- + +## Constructors + +### event inline + + + + +explicit + +Construct a new `event` object with timing disabled, and record the event in the specified stream. + + +```cpp showLineNumbers={false} +cuda::event::event( + stream_ref __stream, + event_flags __flags = event_flags::none +) +``` + + +**Throws:** `cuda_error` if the event creation fails. + + + + +explicit + +Construct a new `event` object with timing disabled. + +The event can only be recorded on streams from the specified device. + + +```cpp showLineNumbers={false} +cuda::event::event( + device_ref __device, + event_flags __flags = event_flags::none +) +``` + + +**Throws:** `cuda_error` if the event creation fails. + + + + +constexpr explicit noexcept + +Construct a new `event` object into the moved-from state. + + +```cpp showLineNumbers={false} +cuda::event::event( + no_init_t +) noexcept +``` + + + +[`get()`](/libcudacxx/api/cuda::event_ref::get()) returns `cudaEvent_t()`. + + + + + +constexpr noexcept + +Move-construct a new `event` object. + + +```cpp showLineNumbers={false} +cuda::event::event( + event &&__other +) noexcept +``` + + + +`__other` is in a moved-from state. + + + + + +constexpr explicit noexcept + + +```cpp showLineNumbers={false} +cuda::event::event( + ::cudaEvent_t __evnt +) noexcept +``` + + + + + +explicit + + +```cpp showLineNumbers={false} +cuda::event::event( + stream_ref __stream, + unsigned __flags +) +``` + + + + + +explicit + + +```cpp showLineNumbers={false} +cuda::event::event( + device_ref __device, + unsigned __flags +) +``` + + + + + + +```cpp showLineNumbers={false} +cuda::event::event( + const event & +) = delete +``` + + + + + +### Destructor + +### ~event inline + +Destroy the `event` object. + + +```cpp showLineNumbers={false} +cuda::event::~event() +``` + + + +If the event fails to be destroyed, the error is silently ignored. + + +--- + +## Assignment operators + +### operator= inline noexcept + + + + +Move-assign an `event` object. + + +```cpp showLineNumbers={false} +event & cuda::event::operator=( + event &&__other +) noexcept +``` + + + +`__other` is in a moved-from state. + + + + + + +```cpp showLineNumbers={false} +event & cuda::event::operator=( + const event & +) = delete +``` + + + + + +--- + +## Methods + +### release inline noexcept nodiscard + +Retrieve the native `cudaEvent_t` handle and give up ownership. + + +```cpp showLineNumbers={false} +::cudaEvent_t cuda::event::release() noexcept +``` + + + +The event object is in a moved-from state. + + +**Returns:** cudaEvent_t The native handle being held by the `event` object. + +### record inline const + +Records an event on the specified stream. + + +```cpp showLineNumbers={false} +void cuda::event::record( + stream_ref __stream +) const +``` + + +**Throws:** `cuda_error` if the event record fails + +### sync inline const + +Synchronizes the event. + + +```cpp showLineNumbers={false} +void cuda::event::sync() const +``` + + +**Throws:** `cuda_error` if waiting for the event fails + +### is_done inline const nodiscard + +Checks if all the work in the stream prior to the record of the event has completed. + +If is_done returns true, calling [sync()](/libcudacxx/api/cuda::event_ref::sync()) on this event will return immediately + + +```cpp showLineNumbers={false} +bool cuda::event::is_done() const +``` + + +**Throws:** `cuda_error` if the event query fails + +### get inline const noexcept nodiscard + +Retrieve the native `cudaEvent_t` handle. + + +```cpp showLineNumbers={false} +::cudaEvent_t cuda::event::get() const noexcept +``` + + +**Returns:** cudaEvent_t The native handle being held by the [event_ref](/libcudacxx/api/cuda::event_ref) object. + +### operator bool inline constexpr explicit const noexcept nodiscard + +Checks if the [`event_ref`](/libcudacxx/api/cuda::event_ref) is valid. + + +```cpp showLineNumbers={false} +cuda::event::operator bool() const noexcept +``` + + +**Returns:** true if the [`event_ref`](/libcudacxx/api/cuda::event_ref) is valid, false otherwise. + +--- + +## Static methods + +### from_native_handle inline static noexcept nodiscard + + + + +Construct an `event` object from a native `cudaEvent_t` handle. + + +```cpp showLineNumbers={false} +static event cuda::event::from_native_handle( + ::cudaEvent_t __evnt +) noexcept +``` + + + +The constructed `event` object takes ownership of the native handle. + + +**Returns:** event The constructed `event` object + +**Parameters** + + +The native handle + + + + + +The following overloads are deleted to prevent misuse: + + +```cpp showLineNumbers={false} +static event cuda::event::from_native_handle(int) = delete; +static event cuda::event::from_native_handle(::cuda::std::nullptr_t) = delete; +``` + + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `value_type` | `::cudaEvent_t` | diff --git a/fern/cudapages/cuda/cuda/cuda/event_ref.mdx b/fern/cudapages/cuda/cuda/cuda/event_ref.mdx new file mode 100644 index 0000000..66b544d --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/event_ref.mdx @@ -0,0 +1,132 @@ +--- +title: "cuda::event_ref" +description: "An non-owning wrapper for an untimed `cudaEvent_t`." +--- + +An non-owning wrapper for an untimed `cudaEvent_t`. + +```cpp showLineNumbers={false} +#include +``` + +--- + +## Constructors + +### event_ref inline constexpr noexcept + + + + +Construct a new `event_ref` object from a `cudaEvent_t`. + +This constructor provides an implicit conversion from `cudaEvent_t` + + +```cpp showLineNumbers={false} +cuda::event_ref::event_ref( + ::cudaEvent_t __evnt +) noexcept +``` + + + +: It is the callers responsibility to ensure the `event_ref` does not outlive the event denoted by the `cudaEvent_t` handle. + + + +[`get()`](/libcudacxx/api/cuda::event_ref::get())` == __evnt` + + + + + +The following overloads are deleted to prevent misuse: + + +```cpp showLineNumbers={false} +cuda::event_ref::event_ref(int) = delete; +cuda::event_ref::event_ref(::cuda::std::nullptr_t) = delete; +``` + + + + + +--- + +## Methods + +### record inline const + +Records an event on the specified stream. + + +```cpp showLineNumbers={false} +void cuda::event_ref::record( + stream_ref __stream +) const +``` + + +**Throws:** `cuda_error` if the event record fails + +### sync inline const + +Synchronizes the event. + + +```cpp showLineNumbers={false} +void cuda::event_ref::sync() const +``` + + +**Throws:** `cuda_error` if waiting for the event fails + +### is_done inline const nodiscard + +Checks if all the work in the stream prior to the record of the event has completed. + +If is_done returns true, calling [sync()](/libcudacxx/api/cuda::event_ref::sync()) on this event will return immediately + + +```cpp showLineNumbers={false} +bool cuda::event_ref::is_done() const +``` + + +**Throws:** `cuda_error` if the event query fails + +### get inline const noexcept nodiscard + +Retrieve the native `cudaEvent_t` handle. + + +```cpp showLineNumbers={false} +::cudaEvent_t cuda::event_ref::get() const noexcept +``` + + +**Returns:** cudaEvent_t The native handle being held by the `event_ref` object. + +### operator bool inline constexpr explicit const noexcept nodiscard + +Checks if the `event_ref` is valid. + + +```cpp showLineNumbers={false} +cuda::event_ref::operator bool() const noexcept +``` + + +**Returns:** true if the `event_ref` is valid, false otherwise. + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `value_type` | `::cudaEvent_t` | diff --git a/fern/cudapages/cuda/cuda/cuda/get_stream_t.mdx b/fern/cudapages/cuda/cuda/cuda/get_stream_t.mdx new file mode 100644 index 0000000..2a469b8 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/get_stream_t.mdx @@ -0,0 +1,92 @@ +--- +title: "cuda::get_stream_t" +description: "[`get_stream`](/libcudacxx/api/cuda::get_stream) is a customization point object that queries a type `T` for an associated stream" +--- + +[`get_stream`](/libcudacxx/api/cuda::get_stream) is a customization point object that queries a type `T` for an associated stream + +```cpp showLineNumbers={false} +#include +``` + +--- + +## Methods + +### operator() inline const noexcept nodiscard + + + + + +```cpp showLineNumbers={false} +::cuda::stream_ref cuda::get_stream_t::operator()( + ::cudaStream_t __stream +) const noexcept +``` + + + + + + +```cpp showLineNumbers={false} +template +::cuda::stream_ref cuda::get_stream_t::operator()( + const _Tp &__t +) const noexcept(static_cast<::cuda::stream_ref >(__t)) +``` + + + + + + +```cpp showLineNumbers={false} +template +::cuda::stream_ref cuda::get_stream_t::operator()( + const _Tp &__t +) const noexcept(__t.stream()) +``` + + + + + + +```cpp showLineNumbers={false} +template +::cuda::stream_ref cuda::get_stream_t::operator()( + const _Tp &__t +) const noexcept(__t.get_stream()) +``` + + + + + + +```cpp showLineNumbers={false} +template +::cuda::stream_ref cuda::get_stream_t::operator()( + const _Env &__env +) const noexcept +``` + + + + + +--- + +## Static methods + +### query inline static constexpr noexcept nodiscard + + +```cpp showLineNumbers={false} +static constexpr bool cuda::get_stream_t::query( + ::cuda::std::execution::forwarding_query_t +) noexcept +``` + diff --git a/fern/cudapages/cuda/cuda/cuda/has_property.mdx b/fern/cudapages/cuda/cuda/cuda/has_property.mdx new file mode 100644 index 0000000..fdaa128 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/has_property.mdx @@ -0,0 +1,33 @@ +--- +title: "cuda::has_property" +description: "The `has_property` concept verifies that a Resource satisfies a given Property." +--- + +C++20 concept + +The `has_property` concept verifies that a Resource satisfies a given Property. + + +```cpp showLineNumbers={false} +template +concept has_property = /* see description */; +``` + + + + + + + + + + + + + + +--- + +## Description + +For \c has_property we require the following free function to be callable diff --git a/fern/cudapages/cuda/cuda/cuda/has_property_with.mdx b/fern/cudapages/cuda/cuda/cuda/has_property_with.mdx new file mode 100644 index 0000000..4ce85b8 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/has_property_with.mdx @@ -0,0 +1,36 @@ +--- +title: "cuda::has_property_with" +description: "The `has_property_with` concept verifies that a Resource satisfies a given stateful Property." +--- + +C++20 concept + +The `has_property_with` concept verifies that a Resource satisfies a given stateful Property. + + +```cpp showLineNumbers={false} +template +concept has_property_with = /* see description */; +``` + + + + + + + + + + + + + + + + + +--- + +## Description + +For \c has_property_with we require the following free function to be callable and its return type to exactly match the `value_type` of the Property diff --git a/fern/cudapages/cuda/cuda/cuda/heterogeneous_iterator.mdx b/fern/cudapages/cuda/cuda/cuda/heterogeneous_iterator.mdx new file mode 100644 index 0000000..0eacc5b --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/heterogeneous_iterator.mdx @@ -0,0 +1,272 @@ +--- +title: "cuda::heterogeneous_iterator" +description: "" +--- + +`heterogeneous_iterator` provides a type safe access over heterogeneous memory. Depending on whether the memory is tagged as host-accessible and / or device-accessible the iterator restricts memory access. All operations that do not require memory access are always available on host and device. + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + +The properties that the `heterogeneous_iterator` is tagged with. + + + + + +**Inherits from:** `cuda::__heterogeneous_iterator_access< ::cuda::std::remove_const_t< _CvTp >, ::cuda::std::is_const_v< _CvTp > ? __is_heterogeneous_const_iter::__yes :__is_heterogeneous_const_iter::__no, ::cuda::mr::__memory_accessability_from_properties< _Properties... >::value >` (public) + +--- + +## Constructors + +### heterogeneous_iterator + + + + + +```cpp showLineNumbers={false} +cuda::heterogeneous_iterator<_CvTp, _Properties>::heterogeneous_iterator() = default +``` + + + + + +inline constexpr noexcept + +Construct a `heterogeneous_iterator` from a pointer to the underlying memory. + + +```cpp showLineNumbers={false} +cuda::heterogeneous_iterator<_CvTp, _Properties>::heterogeneous_iterator( + pointer __ptr +) noexcept +``` + + + + + +inline constexpr noexcept + +Constructs an immutable `heterogeneous_iterator` from a mutable one. + + +```cpp showLineNumbers={false} +template +cuda::heterogeneous_iterator<_CvTp, _Properties>::heterogeneous_iterator( + heterogeneous_iterator<_OtherTp, _Properties...> __other +) noexcept +``` + + +**Parameters** + + +The mutable `heterogeneous_iterator` + + + + + +--- + +## Methods + +### operator++ inline constexpr noexcept + + + + +Increment of a `heterogeneous_iterator`. + + +```cpp showLineNumbers={false} +heterogeneous_iterator & cuda::heterogeneous_iterator<_CvTp, _Properties>::operator++() noexcept +``` + + +**Returns:** The `heterogeneous_iterator` pointing to the next element + + + + +Post-increment of a `heterogeneous_iterator`. + + +```cpp showLineNumbers={false} +heterogeneous_iterator cuda::heterogeneous_iterator<_CvTp, _Properties>::operator++( + int +) noexcept +``` + + +**Returns:** A copy of the `heterogeneous_iterator` pointing to the next element + + + + +### operator-- inline constexpr noexcept + + + + +Decrement of a `heterogeneous_iterator`. + + +```cpp showLineNumbers={false} +heterogeneous_iterator & cuda::heterogeneous_iterator<_CvTp, _Properties>::operator--() noexcept +``` + + +**Returns:** The `heterogeneous_iterator` pointing to the previous element + + + + +Post-decrement of a `heterogeneous_iterator`. + + +```cpp showLineNumbers={false} +heterogeneous_iterator cuda::heterogeneous_iterator<_CvTp, _Properties>::operator--( + int +) noexcept +``` + + +**Returns:** A copy of the `heterogeneous_iterator` pointing to the previous element + + + + +### operator+= inline constexpr noexcept + +Advance a `heterogeneous_iterator`. + + +```cpp showLineNumbers={false} +heterogeneous_iterator & cuda::heterogeneous_iterator<_CvTp, _Properties>::operator+=( + const difference_type __count +) noexcept +``` + + +**Returns:** The `heterogeneous_iterator` advanced by `__count` + +**Parameters** + + +The number of elements to advance. + + +### operator+ inline constexpr const noexcept nodiscard + +Advance a `heterogeneous_iterator`. + + +```cpp showLineNumbers={false} +heterogeneous_iterator cuda::heterogeneous_iterator<_CvTp, _Properties>::operator+( + const difference_type __count +) const noexcept +``` + + +**Returns:** A copy of this `heterogeneous_iterator` advanced by `__count` + +**Parameters** + + +The number of elements to advance. + + +### operator-= inline constexpr noexcept + +Advance a `heterogeneous_iterator` by the negative value of `__count`. + + +```cpp showLineNumbers={false} +heterogeneous_iterator & cuda::heterogeneous_iterator<_CvTp, _Properties>::operator-=( + const difference_type __count +) noexcept +``` + + +**Returns:** The `heterogeneous_iterator` advanced by the negative value of `__count` + +**Parameters** + + +The number of elements to advance. + + +### operator- inline constexpr const noexcept nodiscard + + + + +Advance a `heterogeneous_iterator` by the negative value of `__count`. + + +```cpp showLineNumbers={false} +heterogeneous_iterator cuda::heterogeneous_iterator<_CvTp, _Properties>::operator-( + const difference_type __count +) const noexcept +``` + + +**Returns:** A copy of this `heterogeneous_iterator` advanced by the negative value of `__count` + +**Parameters** + + +The number of elements to advance. + + + + + +Distance between two `heterogeneous_iterator`. + + +```cpp showLineNumbers={false} +difference_type cuda::heterogeneous_iterator<_CvTp, _Properties>::operator-( + const heterogeneous_iterator &__other +) const noexcept +``` + + +**Returns:** The distance between the two elements the `heterogeneous_iterator` point to + +**Parameters** + + +The other `heterogeneous_iterator`. + + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `iterator_concept` | `::cuda::std::contiguous_iterator_tag` | +| `iterator_category` | `::cuda::std::random_access_iterator_tag` | +| `value_type` | `::cuda::std::remove_const_t< _CvTp >` | +| `difference_type` | `::cuda::std::ptrdiff_t` | +| `pointer` | `_CvTp *` | +| `reference` | `_CvTp &` | diff --git a/fern/cudapages/cuda/cuda/cuda/managed_memory_pool.mdx b/fern/cudapages/cuda/cuda/cuda/managed_memory_pool.mdx new file mode 100644 index 0000000..2a5f10b --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/managed_memory_pool.mdx @@ -0,0 +1,133 @@ +--- +title: "cuda::managed_memory_pool" +description: "" +--- + +`managed_memory_pool` allocates managed memory using `cudaMallocFromPoolAsync / cudaFreeAsync +<https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY__POOLS.html>`__ +for allocation/deallocation. When constructed it creates an underlying \c cudaMemPool_t with the allocation type set to \c cudaMemAllocationTypeManaged and owns it. + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `cuda::managed_memory_pool_ref` (public) + +--- + +## Constructors + +### managed_memory_pool inline + + + + +Constructs a `managed_memory_pool` with optional properties. + +Properties include the initial pool size and the release threshold. If the pool size grows beyond the release threshold, unused memory held by the pool will be released at the next synchronization event. + + +```cpp showLineNumbers={false} +cuda::managed_memory_pool::managed_memory_pool( + memory_pool_properties __properties = {} +) +``` + + +**Parameters** + + +Optional, additional properties of the pool to be created. + + + + + +noexcept + + +```cpp showLineNumbers={false} +cuda::managed_memory_pool::managed_memory_pool( + ::cudaMemPool_t __pool +) noexcept +``` + + + + + + +```cpp showLineNumbers={false} +cuda::managed_memory_pool::managed_memory_pool( + const managed_memory_pool & +) = delete +``` + + + + + +### Destructor + +### ~managed_memory_pool inline noexcept + + +```cpp showLineNumbers={false} +cuda::managed_memory_pool::~managed_memory_pool() noexcept +``` + + +--- + +## Assignment operators + +### operator= + + +```cpp showLineNumbers={false} +managed_memory_pool & cuda::managed_memory_pool::operator=( + const managed_memory_pool & +) = delete +``` + + +--- + +## Methods + +### as_ref inline noexcept nodiscard + +Returns a [`managed_memory_pool_ref`](/libcudacxx/api/cuda::managed_memory_pool_ref) for this `managed_memory_pool`. + +We return by reference to ensure that we can subsequently convert to a resource_ref + + +```cpp showLineNumbers={false} +managed_memory_pool_ref & cuda::managed_memory_pool::as_ref() noexcept +``` + + +--- + +## Static methods + +### from_native_handle inline static noexcept + + +```cpp showLineNumbers={false} +static managed_memory_pool cuda::managed_memory_pool::from_native_handle( + ::cudaMemPool_t __pool +) noexcept +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `reference_type` | `managed_memory_pool_ref` | +| `default_queries` | `::cuda::mr::properties_list<::cuda::mr::device_accessible, ::cuda::mr::host_accessible >` | diff --git a/fern/cudapages/cuda/cuda/cuda/managed_memory_pool_ref.mdx b/fern/cudapages/cuda/cuda/cuda/managed_memory_pool_ref.mdx new file mode 100644 index 0000000..a46969f --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/managed_memory_pool_ref.mdx @@ -0,0 +1,50 @@ +--- +title: "cuda::managed_memory_pool_ref" +description: "" +--- + +`managed_memory_pool_ref` allocates managed memory using `cudaMallocFromPoolAsync / cudaFreeAsync +<https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY__POOLS.html>`__ +for allocation/deallocation. A `managed_memory_pool_ref` is a thin wrapper around a \c cudaMemPool_t with the allocation type set to \c cudaMemAllocationTypeManaged. + +.. warning:: + + `managed_memory_pool_ref` does not own the pool and it is the responsibility of the user to ensure that the lifetime of the pool exceeds the lifetime of the `managed_memory_pool_ref`. + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `cuda::__memory_pool_base` (public) + +--- + +## Constructors + +### managed_memory_pool_ref inline explicit noexcept + +Constructs the `managed_memory_pool_ref` from a `cudaMemPool_t`. + + +```cpp showLineNumbers={false} +cuda::managed_memory_pool_ref::managed_memory_pool_ref( + ::cudaMemPool_t __pool +) noexcept +``` + + +**Parameters** + + +The `cudaMemPool_t` used to allocate memory. + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `default_queries` | `::cuda::mr::properties_list<::cuda::mr::device_accessible, ::cuda::mr::host_accessible >` | diff --git a/fern/cudapages/cuda/cuda/cuda/memory_pool_properties.mdx b/fern/cudapages/cuda/cuda/cuda/memory_pool_properties.mdx new file mode 100644 index 0000000..7811a78 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/memory_pool_properties.mdx @@ -0,0 +1,23 @@ +--- +title: "cuda::memory_pool_properties" +description: "[`memory_pool_properties`](/libcudacxx/api/cuda::memory_pool_properties) is a type that can controls memory pool to control the creation options." +--- + +`memory_pool_properties` is a type that can controls memory pool to control the creation options. + +Compared to attributes, properties can not be set after the pool is created. + +```cpp showLineNumbers={false} +#include +``` + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `initial_pool_size` | `size_t` | | +| `release_threshold` | `size_t` | | +| `allocation_handle_type` | `cudaMemAllocationHandleType` | | +| `max_pool_size` | `size_t` | | diff --git a/fern/cudapages/cuda/cuda/cuda/mr/basic_any_resource.mdx b/fern/cudapages/cuda/cuda/cuda/mr/basic_any_resource.mdx new file mode 100644 index 0000000..fe8ce07 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/mr/basic_any_resource.mdx @@ -0,0 +1,591 @@ +--- +title: "cuda::mr::basic_any_resource" +description: "" +--- + +`basic_any_resource` wraps any given resource that satisfies the required properties. It owns the contained resource, taking care of construction / destruction. This makes it especially suited for use in e.g. container types that need to ensure that the lifetime of the container exceeds the lifetime of the memory resource used to allocate the storage + +`basic_any_resource` models the `cuda::std::regular` concept. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +[any_synchronous_resource](/library/api/cuda::mr::any_synchronous_resource), +[any_resource](/library/api/cuda::mr::any_resource), +[synchronous_resource_ref](/library/api/cuda::mr::synchronous_resource_ref), +[resource_ref](/library/api/cuda::mr::resource_ref) + + + + + +Either [`_ResourceKind::_Synchronous`](/library/api/cuda::mr::_Synchronous) for [`any_synchronous_resource`](/library/api/cuda::mr::any_synchronous_resource), or [`_ResourceKind::_Asynchronous`](/library/api/cuda::mr::_Asynchronous) for [`any_resource`](/library/api/cuda::mr::any_resource). + + + +A pack of property types that a memory resource must provide in order to be storable in instances of this `basic_any_resource` type. + + + + + +--- + +## Constructors + +### basic_any_resource + + + + +noexcept + +Constructs a `basic_any_resource` with no value. + + +```cpp showLineNumbers={false} +cuda::mr::basic_any_resource<_Kind, _Properties>::basic_any_resource() noexcept +``` + + + +[`has_value()`](/library/api/cuda::mr::basic_any_resource::has_value()) is `false` + + + + + +noexcept + +Move constructs a `basic_any_resource`. + + +```cpp showLineNumbers={false} +cuda::mr::basic_any_resource<_Kind, _Properties>::basic_any_resource( + basic_any_resource &&__other +) noexcept +``` + + + +[`has_value()`](/library/api/cuda::mr::basic_any_resource::has_value()) is `true` if `__other` had a value prior to the move, and `false` otherwise. `__other.has_value()` is `false`. + + + + + +Copy constructs a `basic_any_resource`. + + +```cpp showLineNumbers={false} +cuda::mr::basic_any_resource<_Kind, _Properties>::basic_any_resource( + const basic_any_resource &__other +) +``` + + + +[`has_value()`](/library/api/cuda::mr::basic_any_resource::has_value()) is the same as `__other.has_value()`. + + + + + +Constructs a `basic_any_resource` from a type that satisfies the `resource` concept. + +and that supports all of the specified properties. + + +```cpp showLineNumbers={false} +template +cuda::mr::basic_any_resource<_Kind, _Properties>::basic_any_resource( + _Resource __res +) +``` + + + +[`has_value()`](/library/api/cuda::mr::basic_any_resource::has_value()) is `true` + + + +`_Resource` is not a specialization of `basic_any_resource` or [`basic_resource_ref`](/library/api/cuda::mr::basic_resource_ref), or a type derived from such. + + + +[`synchronous_resource_with`](/library/api/cuda::mr::synchronous_resource_with)`<_Resource, _Properties...>` is `true`. + + + +If `_Kind` is [`_ResourceKind::_Asynchronous`](/library/api/cuda::mr::_Asynchronous), [`resource_with`](/library/api/cuda::mr::resource_with)`<_Resource, _Properties...>` is `true`. + + +**Parameters** + + +The resource to be wrapped by the `basic_any_resource`. + + + + + +Conversion from a type-erased resource with a superset of the required properties. + + +```cpp showLineNumbers={false} +template <_ResourceKind _OtherKind, class... _OtherProperties> +cuda::mr::basic_any_resource<_Kind, _Properties>::basic_any_resource( + basic_any_resource<_OtherKind, _OtherProperties...> __res +) +``` + + + +[`has_value()`](/library/api/cuda::mr::basic_any_resource::has_value()) is equal to `__res.has_value()` + + + +`_OtherKind` is equal to either `_Kind` or [`_ResourceKind::_Asynchronous`](/library/api/cuda::mr::_Asynchronous). + + + +The set `_OtherProperties...` is a superset of `_Properties...`. + + +**Parameters** + + +The object to copy from. + + + + + +Deep copy from a type-erased resource reference with a superset of the required properties. + +The object to which `__res` refers is copied into `*this`. + + +```cpp showLineNumbers={false} +template <_ResourceKind _OtherKind, class... _OtherProperties> +cuda::mr::basic_any_resource<_Kind, _Properties>::basic_any_resource( + basic_resource_ref<_OtherKind, _OtherProperties...> __res +) +``` + + + +`_OtherKind` is equal to either `_Kind` or [`_ResourceKind::_Asynchronous`](/library/api/cuda::mr::_Asynchronous). + + + +The set `_OtherProperties...` is a superset of `_Properties...`. + + +**Parameters** + + +The reference to copy from. + + + + + +--- + +## Assignment operators + +### operator= + + + + +noexcept + +Move assigns a `basic_any_resource`. + + +```cpp showLineNumbers={false} +basic_any_resource & cuda::mr::basic_any_resource<_Kind, _Properties>::operator=( + basic_any_resource &&__other +) noexcept +``` + + + +[`has_value()`](/library/api/cuda::mr::basic_any_resource::has_value()) is `true` if `__other` had a value prior to the move, and `false` otherwise. + + + +`__other.has_value()` is `false`. + + + + + +Copy assigns a `basic_any_resource`. + + +```cpp showLineNumbers={false} +basic_any_resource & cuda::mr::basic_any_resource<_Kind, _Properties>::operator=( + const basic_any_resource &__other +) +``` + + + +[`has_value()`](/library/api/cuda::mr::basic_any_resource::has_value()) is the same as `__other.has_value()`. + + + + + +Assigns from a type that satisfies the `resource` concept and that supports all of the specified properties. + + +```cpp showLineNumbers={false} +template +basic_any_resource & cuda::mr::basic_any_resource<_Kind, _Properties>::operator=( + _Resource __res +) +``` + + + +[`has_value()`](/library/api/cuda::mr::basic_any_resource::has_value()) is `true` + + + +`_Resource` is not a specialization of `basic_any_resource` or [`basic_resource_ref`](/library/api/cuda::mr::basic_resource_ref), or a type derived from such. + + + +[`synchronous_resource_with`](/library/api/cuda::mr::synchronous_resource_with)`<_Resource, _Properties...>` is `true`. + + + +If `_Kind` is [`_ResourceKind::_Asynchronous`](/library/api/cuda::mr::_Asynchronous), [`resource_with`](/library/api/cuda::mr::resource_with)`<_Resource, _Properties...>` is `true`. + + +**Parameters** + + +The resource to be wrapped within the `basic_any_resource` + + + + + +Assignment from a type-erased resource with a superset of the required properties. + + +```cpp showLineNumbers={false} +template <_ResourceKind _OtherKind, class... _OtherProperties> +basic_any_resource & cuda::mr::basic_any_resource<_Kind, _Properties>::operator=( + basic_any_resource<_OtherKind, _OtherProperties...> __res +) +``` + + + +[`has_value()`](/library/api/cuda::mr::basic_any_resource::has_value()) is equal to `__res.has_value()`. + + + +`_OtherKind` is equal to either `_Kind` or [`_ResourceKind::_Asynchronous`](/library/api/cuda::mr::_Asynchronous). + + + +The set `_OtherProperties...` is a superset of `_Properties...`. + + +**Parameters** + + +The object to copy from. + + + + + +Deep copy from a type-erased resource reference with a superset of the required properties. + + +```cpp showLineNumbers={false} +template <_ResourceKind _OtherKind, class... _OtherProperties> +basic_any_resource & cuda::mr::basic_any_resource<_Kind, _Properties>::operator=( + basic_resource_ref<_OtherKind, _OtherProperties...> __res +) +``` + + + +[`has_value()`](/library/api/cuda::mr::basic_any_resource::has_value()) is `true`. + + + +`_OtherKind` is equal to either `_Kind` or [`_ResourceKind::_Asynchronous`](/library/api/cuda::mr::_Asynchronous). + + + +The set `_OtherProperties...` is a superset of `_Properties...`. + + +**Parameters** + + +The type-erased resource reference to copy from. + + + + + +--- + +## Methods + +### operator== const nodiscard + + + + +Equality comparison between two type-erased memory resource. + + +```cpp showLineNumbers={false} +template <_ResourceKind _OtherKind, class... _OtherProperties> +bool cuda::mr::basic_any_resource<_Kind, _Properties>::operator==( + const basic_any_resource<_OtherKind, _OtherProperties...> &__rhs +) const +``` + + + +`_OtherKind` is equal to either `_Kind` or [`_ResourceKind::_Asynchronous`](/library/api/cuda::mr::_Asynchronous). + + + +The set `_Properties...` is equal to the set `_OtherProperties...`. + + +**Returns:** `true` if both resources hold objects of the same type and those objects compare equal, and `false` otherwise. + +**Parameters** + + +The type-erased resource to compare with `*this`. + + + + + +Equality comparison between `*this` and a type-erased resource reference. + + +```cpp showLineNumbers={false} +template <_ResourceKind _OtherKind, class... _OtherProperties> +bool cuda::mr::basic_any_resource<_Kind, _Properties>::operator==( + const basic_resource_ref<_OtherKind, _OtherProperties...> &__rhs +) const +``` + + + +`_OtherKind` is equal to either `_Kind` or [`_ResourceKind::_Asynchronous`](/library/api/cuda::mr::_Asynchronous). + + + +The set `_Properties...` is equal to the set `_OtherProperties...`. + + +**Returns:** `true` if `__rhs` refers to an object of the same type as that wrapped by `*this` and those objects compare equal; `false` otherwise. + +**Parameters** + + +The type-erased resource reference to compare with `*this`. + + + + + +### allocate_sync nodiscard + +Calls [`allocate_sync`](/library/api/cuda::mr::basic_any_resource::allocate_sync) on the wrapped object with the specified arguments. + + +```cpp showLineNumbers={false} +void * cuda::mr::basic_any_resource<_Kind, _Properties>::allocate_sync( + size_t __size, + size_t __align = alignof(cuda::std::max_align_t) +) +``` + + + +[`has_value()`](/library/api/cuda::mr::basic_any_resource::has_value()) is `true`. + + +**Returns:** `obj.allocate_sync(__size, __align)`, where `obj` is the wrapped object. + +### deallocate_sync + +Calls [`deallocate_sync`](/library/api/cuda::mr::basic_any_resource::deallocate_sync) on the wrapped object with the specified arguments. + + +```cpp showLineNumbers={false} +void cuda::mr::basic_any_resource<_Kind, _Properties>::deallocate_sync( + void *__pv, + size_t __size, + size_t __align = alignof(cuda::std::max_align_t) +) +``` + + + +[`has_value()`](/library/api/cuda::mr::basic_any_resource::has_value()) is `true`. + + + +`__pv` must be a pointer that was previously returned by a call to `allocate` on the object wrapped by `*this`. + + +**Returns:** `obj.deallocate_sync(__pv, __size, __align)`, where `obj` is the wrapped object. + +### allocate nodiscard + + + + +Calls [`allocate`](/library/api/cuda::mr::basic_any_resource::allocate) on the wrapped object with the specified arguments. + + +```cpp showLineNumbers={false} +void * cuda::mr::basic_any_resource<_Kind, _Properties>::allocate( + cuda::stream_ref __stream, + size_t __size, + size_t __align +) +``` + + + +The returned pointer is not valid until `__stream` has been synchronized. + + + +`_Kind` is [`_ResourceKind::_Asynchronous`](/library/api/cuda::mr::_Asynchronous). + + + +[`has_value()`](/library/api/cuda::mr::basic_any_resource::has_value()) is `true`. + + +**Returns:** `obj.allocate(__stream, __size, __align)`, where `obj` is the wrapped object. + + + + +Equivalent to `allocate(__stream, __size, +alignof(::cuda::std::max_align_t))`. + + +```cpp showLineNumbers={false} +void * cuda::mr::basic_any_resource<_Kind, _Properties>::allocate( + cuda::stream_ref __stream, + size_t __size +) +``` + + + + + +### deallocate + + + + +Calls [`deallocate`](/library/api/cuda::mr::basic_any_resource::deallocate) on the wrapped object with the specified arguments. + + +```cpp showLineNumbers={false} +void cuda::mr::basic_any_resource<_Kind, _Properties>::deallocate( + cuda::stream_ref __stream, + void *__pv, + size_t __size, + size_t __align +) +``` + + + +`_Kind` is [`_ResourceKind::_Asynchronous`](/library/api/cuda::mr::_Asynchronous). + + + +[`has_value()`](/library/api/cuda::mr::basic_any_resource::has_value()) is `true`. + + + +`__pv` must be a pointer that was previously returned by a call to `allocate` on the object wrapped by `*this`. + + +**Returns:** `obj.deallocate(__stream, __pv, __size, __align)`, where `obj` is the wrapped object. + + + + +Equivalent to `deallocate(__stream, __pv, __size, +alignof(::cuda::std::max_align_t), __stream)`. + + +```cpp showLineNumbers={false} +void cuda::mr::basic_any_resource<_Kind, _Properties>::deallocate( + cuda::stream_ref __stream, + void *__pv, + size_t __size +) +``` + + + + + +### has_value const noexcept nodiscard + +Checks if `*this` holds a value. + + +```cpp showLineNumbers={false} +bool cuda::mr::basic_any_resource<_Kind, _Properties>::has_value() const noexcept +``` + + +**Returns:** `true` if `*this` holds a value; `false` otherwise. + +### reset noexcept + +Resets `*this` to the empty state. + + +```cpp showLineNumbers={false} +void cuda::mr::basic_any_resource<_Kind, _Properties>::reset() noexcept +``` + + + +[`has_value()`](/library/api/cuda::mr::basic_any_resource::has_value())` == false` + + +### type const noexcept nodiscard + + +```cpp showLineNumbers={false} +const cuda::std::type_info & cuda::mr::basic_any_resource<_Kind, _Properties>::type() const noexcept +``` + + +**Returns:** A reference to the `type_info` object for the wrapped resource, or `typeid(void)` if [`has_value()`](/library/api/cuda::mr::basic_any_resource::has_value()) is `false`. diff --git a/fern/cudapages/cuda/cuda/cuda/mr/basic_resource_ref.mdx b/fern/cudapages/cuda/cuda/cuda/mr/basic_resource_ref.mdx new file mode 100644 index 0000000..757cb31 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/mr/basic_resource_ref.mdx @@ -0,0 +1,375 @@ +--- +title: "cuda::mr::basic_resource_ref" +description: "Type erased wrapper around a reference to an object that satisfies the `resource` concept and that provides the requested `_Properties`." +--- + +Type erased wrapper around a reference to an object that satisfies the `resource` concept and that provides the requested `_Properties`. + +`basic_resource_ref` models the `cuda::std::copyable` and `cuda::std::equality_comparable` concepts. + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + +The properties that any resource wrapped within the `basic_resource_ref` needs to provide. + + + + + +--- + +## Constructors + +### basic_resource_ref + + + + +Copy constructs a `basic_resource_ref`. + + +```cpp showLineNumbers={false} +cuda::mr::basic_resource_ref<_Kind, _Properties>::basic_resource_ref( + const basic_resource_ref &__other +) +``` + + + +`*this` and `__other` both refer to the same resource object. + + + + + +Constructs a `basic_resource_ref` from a reference to a type that satisfies the `resource` concept and that supports all of the specified properties. + + +```cpp showLineNumbers={false} +template +cuda::mr::basic_resource_ref<_Kind, _Properties>::basic_resource_ref( + _Resource &__res +) +``` + + + +[`synchronous_resource_with`](/library/api/cuda::mr::synchronous_resource_with)`<_Resource, _Properties...>` is `true`. + + + +If `_Kind` is [`_ResourceKind::_Asynchronous`](/library/api/cuda::mr::_Asynchronous), [`resource_with`](/library/api/cuda::mr::resource_with)`<_Resource, _Properties...>` is `true`. + + + +If `__res` refers to a specialization of [`basic_any_resource`](/library/api/cuda::mr::basic_any_resource) or a type derived from such, `__res.has_value()` is `true`. + + +**Parameters** + + +The resource reference to be wrapped. + + + + + +Conversion from type-erased resource reference with a superset of the required properties. + + +```cpp showLineNumbers={false} +template <_ResourceKind _OtherKind, class... _OtherProperties> +cuda::mr::basic_resource_ref<_Kind, _Properties>::basic_resource_ref( + basic_resource_ref<_OtherKind, _OtherProperties...> __res +) +``` + + + +`*this` and `__res` both refer to the same resource object. + + + +`_OtherKind` is equal to either `_Kind` or [`_ResourceKind::_Asynchronous`](/library/api/cuda::mr::_Asynchronous). + + + +The set `_OtherProperties...` is a superset of `_Properties...`. + + +**Parameters** + + +The other type-erased resource reference to copy from. + + + + + +--- + +## Assignment operators + +### operator= + + + + +Rebinds `*this` to refer to the object to which `__other` refers. + + +```cpp showLineNumbers={false} +basic_resource_ref & cuda::mr::basic_resource_ref<_Kind, _Properties>::operator=( + const basic_resource_ref &__other +) +``` + + + +`*this` and `__other` both refer to the same resource object. + + + + + +Rebinds the wrapped reference to an object whose type satisfies the `resource` concept and that supports all of the specified properties. + + +```cpp showLineNumbers={false} +template +basic_resource_ref & cuda::mr::basic_resource_ref<_Kind, _Properties>::operator=( + _Resource &__res +) +``` + + + +[`synchronous_resource_with`](/library/api/cuda::mr::synchronous_resource_with)`<_Resource, _Properties...>` is `true`. + + + +If `_Kind` is [`_ResourceKind::_Asynchronous`](/library/api/cuda::mr::_Asynchronous), [`synchronous_resource_with`](/library/api/cuda::mr::synchronous_resource_with)`<_Resource, _Properties...>` is `true`. + + + +If `__res` refers to a specialization of [`basic_any_resource`](/library/api/cuda::mr::basic_any_resource) or a type derived from such, `__res.has_value()` is `true`. + + +**Parameters** + + +The reference to the resource to be wrapped by the `basic_resource_ref`. + + + + + +Rebinds `*this` to refer to the object to which `__other` refers. + + +```cpp showLineNumbers={false} +template <_ResourceKind _OtherKind, class... _OtherProperties> +basic_resource_ref & cuda::mr::basic_resource_ref<_Kind, _Properties>::operator=( + basic_resource_ref<_OtherKind, _OtherProperties...> __res +) +``` + + + +`*this` and `__res` both refer to the same resource object. + + + +`_OtherKind` is equal to either `_Kind` or [`_ResourceKind::_Asynchronous`](/library/api/cuda::mr::_Asynchronous). + + + +The set `_OtherProperties...` is a superset of `_Properties...`. + + +**Parameters** + + +The other type-erased resource reference to copy from. + + + + + +--- + +## Methods + +### operator== const nodiscard + +Equality comparison between two type-erased resource references. + + +```cpp showLineNumbers={false} +template <_ResourceKind _OtherKind, class... _OtherProperties> +bool cuda::mr::basic_resource_ref<_Kind, _Properties>::operator==( + const basic_resource_ref<_OtherKind, _OtherProperties...> &__rhs +) const +``` + + + +`_OtherKind` is equal to either `_Kind` or [`_ResourceKind::_Asynchronous`](/library/api/cuda::mr::_Asynchronous). + + + +The set `_Properties...` is equal to the set `_OtherProperties...`. + + +**Returns:** `true` if both resources refer to objects of the same type and those objects compare equal. Otherwise, returns `false`. + +**Parameters** + + +The other type-erased resource reference. + + +### allocate_sync nodiscard + +Calls [`allocate_sync`](/library/api/cuda::mr::basic_resource_ref::allocate_sync) on the wrapped reference with the specified arguments. + + +```cpp showLineNumbers={false} +void * cuda::mr::basic_resource_ref<_Kind, _Properties>::allocate_sync( + size_t __size, + size_t __align = alignof(cuda::std::max_align_t) +) +``` + + +**Returns:** `obj.allocate_sync(__size, __align)`, where `obj` is the wrapped reference. + +### deallocate_sync + +Calls [`deallocate_sync`](/library/api/cuda::mr::basic_resource_ref::deallocate_sync) on the wrapped reference with the specified arguments. + + +```cpp showLineNumbers={false} +void cuda::mr::basic_resource_ref<_Kind, _Properties>::deallocate_sync( + void *__pv, + size_t __size, + size_t __align = alignof(cuda::std::max_align_t) +) +``` + + + +`__pv` must be a pointer that was previously returned by a call to `allocate` on the object referenced by `*this`. + + +**Returns:** `obj.deallocate_sync(__pv, __size, __align)`, where `obj` is the wrapped reference. + +### allocate nodiscard + + + + +Calls [`allocate`](/library/api/cuda::mr::basic_resource_ref::allocate) on the wrapped reference with the specified arguments. + + +```cpp showLineNumbers={false} +void * cuda::mr::basic_resource_ref<_Kind, _Properties>::allocate( + cuda::stream_ref __stream, + size_t __size, + size_t __align +) +``` + + + +The returned pointer is not valid until `__stream` has been synchronized. + + + +`_Kind` is [`_ResourceKind::_Asynchronous`](/library/api/cuda::mr::_Asynchronous). + + +**Returns:** `obj.allocate(__stream, __size, __align)`, where `obj` is the wrapped reference. + + + + +Equivalent to `allocate(__stream, __size, +alignof(::cuda::std::max_align_t))`. + + +```cpp showLineNumbers={false} +void * cuda::mr::basic_resource_ref<_Kind, _Properties>::allocate( + cuda::stream_ref __stream, + size_t __size +) +``` + + + + + +### deallocate + + + + +Calls [`deallocate`](/library/api/cuda::mr::basic_resource_ref::deallocate) on the wrapped reference with the specified arguments. + + +```cpp showLineNumbers={false} +void cuda::mr::basic_resource_ref<_Kind, _Properties>::deallocate( + cuda::stream_ref __stream, + void *__pv, + size_t __size, + size_t __align +) +``` + + + +`_Kind` is [`_ResourceKind::_Asynchronous`](/library/api/cuda::mr::_Asynchronous). + + + +`__pv` must be a pointer that was previously returned by a call to `allocate` on the object referenced by `*this`. + + +**Returns:** `obj.deallocate(__stream, __pv, __size, __align)`, where `obj` is the wrapped reference. + + + + +Equivalent to `deallocate(__stream, __pv, __size, +alignof(::cuda::std::max_align_t), __stream)`. + + +```cpp showLineNumbers={false} +void cuda::mr::basic_resource_ref<_Kind, _Properties>::deallocate( + cuda::stream_ref __stream, + void *__pv, + size_t __size +) +``` + + + + + +### type const noexcept nodiscard + + +```cpp showLineNumbers={false} +const cuda::std::type_info & cuda::mr::basic_resource_ref<_Kind, _Properties>::type() const noexcept +``` + + +**Returns:** A reference to the `type_info` object for the type of the object to which `*this` refers. diff --git a/fern/cudapages/cuda/cuda/cuda/mr/device_accessible.mdx b/fern/cudapages/cuda/cuda/cuda/mr/device_accessible.mdx new file mode 100644 index 0000000..c7aade1 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/mr/device_accessible.mdx @@ -0,0 +1,10 @@ +--- +title: "cuda::mr::device_accessible" +description: "The [device_accessible](/library/api/cuda::mr::device_accessible) property signals that the allocated memory is device accessible." +--- + +The `device_accessible` property signals that the allocated memory is device accessible. + +```cpp showLineNumbers={false} +#include +``` diff --git a/fern/cudapages/cuda/cuda/cuda/mr/host_accessible.mdx b/fern/cudapages/cuda/cuda/cuda/mr/host_accessible.mdx new file mode 100644 index 0000000..733e9f2 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/mr/host_accessible.mdx @@ -0,0 +1,10 @@ +--- +title: "cuda::mr::host_accessible" +description: "The [device_accessible](/library/api/cuda::mr::device_accessible) property signals that the allocated memory is host accessible." +--- + +The [device_accessible](/library/api/cuda::mr::device_accessible) property signals that the allocated memory is host accessible. + +```cpp showLineNumbers={false} +#include +``` diff --git a/fern/cudapages/cuda/cuda/cuda/mr/legacy_managed_memory_resource.mdx b/fern/cudapages/cuda/cuda/cuda/mr/legacy_managed_memory_resource.mdx new file mode 100644 index 0000000..b5a4792 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/mr/legacy_managed_memory_resource.mdx @@ -0,0 +1,120 @@ +--- +title: "cuda::mr::legacy_managed_memory_resource" +description: "`managed_memory_resource` uses `cudaMallocManaged` / `cudaFree` for allocation / deallocation." +--- + +`managed_memory_resource` uses `cudaMallocManaged` / `cudaFree` for allocation / deallocation. + +```cpp showLineNumbers={false} +#include +``` + +--- + +## Constructors + +### legacy_managed_memory_resource inline constexpr noexcept + +Construct a new `legacy_managed_memory_resource`. + + +```cpp showLineNumbers={false} +cuda::mr::legacy_managed_memory_resource::legacy_managed_memory_resource( + const unsigned int __flags = cudaMemAttachGlobal, + device_ref __device = {0} +) noexcept +``` + + + +Synchronous allocations in CUDA are tied to a device, even if not located in device memory. This constructor takes an optional device argument to specify the device that should be tied to allocations for the resource. This association has the effect of initializing that device and the memory being implicitly freed if the device is reset. + + +--- + +## Methods + +### allocate_sync inline nodiscard + +Allocate CUDA unified memory of size at least `__bytes`. + + +```cpp showLineNumbers={false} +void * cuda::mr::legacy_managed_memory_resource::allocate_sync( + const size_t __bytes, + const size_t __alignment = ::cuda::mr::default_cuda_malloc_alignment +) +``` + + +**Returns:** Pointer to the newly allocated memory + +**Throws:** `std::invalid_argument` in case of invalid alignment or `cuda::cuda_error` of the returned error code. + +**Parameters** + + +The size in bytes of the allocation. + + + +The requested alignment of the allocation. + + +### deallocate_sync inline noexcept + +Deallocate memory pointed to by `__ptr`. + + +```cpp showLineNumbers={false} +void cuda::mr::legacy_managed_memory_resource::deallocate_sync( + void *__ptr, + const size_t __bytes, + const size_t __alignment = ::cuda::mr::default_cuda_malloc_alignment +) noexcept +``` + + +**Parameters** + + +Pointer to be deallocated. Must have been allocated through a call to `allocate` or [`allocate_sync`](/library/api/cuda::mr::legacy_managed_memory_resource::allocate_sync) + + + +The number of bytes that was passed to the allocation call that returned `__ptr`. + + + +The alignment that was passed to the allocation call that returned `__ptr`. + + +### operator== inline constexpr const noexcept nodiscard + +Equality comparison with another `managed_memory_resource`. + + +```cpp showLineNumbers={false} +bool cuda::mr::legacy_managed_memory_resource::operator==( + legacy_managed_memory_resource const &__other +) const noexcept +``` + + +**Returns:** Whether both `managed_memory_resource` were constructed with the same flags. + +**Parameters** + + +The other `managed_memory_resource`. + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `default_queries` | `::cuda::mr::properties_list<::cuda::mr::device_accessible, ::cuda::mr::host_accessible >` | diff --git a/fern/cudapages/cuda/cuda/cuda/mr/legacy_pinned_memory_resource.mdx b/fern/cudapages/cuda/cuda/cuda/mr/legacy_pinned_memory_resource.mdx new file mode 100644 index 0000000..8c01ace --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/mr/legacy_pinned_memory_resource.mdx @@ -0,0 +1,117 @@ +--- +title: "cuda::mr::legacy_pinned_memory_resource" +description: "[legacy_pinned_memory_resource](/library/api/cuda::mr::legacy_pinned_memory_resource) uses `cudaMallocHost` / `cudaFreeAsync` for allocation / deallocation." +--- + +`legacy_pinned_memory_resource` uses `cudaMallocHost` / `cudaFreeAsync` for allocation / deallocation. + +```cpp showLineNumbers={false} +#include +``` + + +This memory resource will be deprecated in the future. For CUDA 12.6 and above, use `cuda::pinned_memory_resource` instead, which is the long-term replacement. + + +--- + +## Constructors + +### legacy_pinned_memory_resource inline constexpr noexcept + +Construct a new `legacy_pinned_memory_resource`. + + +```cpp showLineNumbers={false} +cuda::mr::legacy_pinned_memory_resource::legacy_pinned_memory_resource( + ::cuda::device_ref __device = {0} +) noexcept +``` + + + +Synchronous allocations in CUDA are tied to a device, even if not located in device memory. This constructor takes an optional device argument to specify the device that should be tied to allocations for the resource. This association has the effect of initializing that device and the memory being implicitly freed if the device is reset. + + +--- + +## Methods + +### allocate_sync inline nodiscard + +Allocate host memory of size at least `__bytes`. + + +```cpp showLineNumbers={false} +void * cuda::mr::legacy_pinned_memory_resource::allocate_sync( + const size_t __bytes, + const size_t __alignment = ::cuda::mr::default_cuda_malloc_alignment +) +``` + + +**Returns:** Pointer to the newly allocated memory + +**Throws:** `std::invalid_argument` in case of invalid alignment or `cuda::cuda_error` of the returned error code. + +**Parameters** + + +The size in bytes of the allocation. + + + +The requested alignment of the allocation. + + +### deallocate_sync inline noexcept + +Deallocate memory pointed to by `__ptr`. + + +```cpp showLineNumbers={false} +void cuda::mr::legacy_pinned_memory_resource::deallocate_sync( + void *__ptr, + const size_t __bytes, + const size_t __alignment = ::cuda::mr::default_cuda_malloc_alignment +) noexcept +``` + + +**Parameters** + + +Pointer to be deallocated. Must have been allocated through a call to [`allocate_sync`](/library/api/cuda::mr::legacy_pinned_memory_resource::allocate_sync). + + + +The number of bytes that was passed to the allocation call that returned `__ptr`. + + + +The alignment that was passed to the allocation call that returned `__ptr`. + + +### operator== inline constexpr const noexcept nodiscard + +Equality comparison with another `legacy_pinned_memory_resource`. + + +```cpp showLineNumbers={false} +bool cuda::mr::legacy_pinned_memory_resource::operator==( + legacy_pinned_memory_resource const & +) const noexcept +``` + + +**Returns:** Whether both `legacy_pinned_memory_resource` were constructed with the same flags. + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `default_queries` | `::cuda::mr::properties_list<::cuda::mr::device_accessible, ::cuda::mr::host_accessible >` | diff --git a/fern/cudapages/cuda/cuda/cuda/mr/properties_list.mdx b/fern/cudapages/cuda/cuda/cuda/mr/properties_list.mdx new file mode 100644 index 0000000..5c4bf6e --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/mr/properties_list.mdx @@ -0,0 +1,45 @@ +--- +title: "cuda::mr::properties_list" +description: "A type representing a list of memory resource properties." +--- + +A type representing a list of memory resource properties. + +```cpp showLineNumbers={false} +#include +``` + + + + + +The properties to be included in the list It has a member template [`rebind`](/library/api/cuda::mr::properties_list::rebind) that allows constructing a type by combining a template and type arguments with the properties from this list. The properties are appended after the type arguments in the resulting type. + + + + + +--- + +## Static methods + +### has_property inline static constexpr + + +```cpp showLineNumbers={false} +template +static constexpr bool cuda::mr::properties_list<_Properties>::has_property( + _QueryProperty +) +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `rebind` | `_Fn< _ExtraArgs..., _Properties... >` | A type alias for a type template instantiated with the properties from this list appended to the type arguments. | diff --git a/fern/cudapages/cuda/cuda/cuda/mr/resource.mdx b/fern/cudapages/cuda/cuda/cuda/mr/resource.mdx new file mode 100644 index 0000000..322624d --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/mr/resource.mdx @@ -0,0 +1,39 @@ +--- +title: "cuda::mr::resource" +description: "The `resource` concept verifies that a type Resource satisfies the basic requirements of a memory resource and additionally supports stream ordered allocations." +--- + +C++20 concept + +The `resource` concept verifies that a type Resource satisfies the basic requirements of a memory resource and additionally supports stream ordered allocations. + + +```cpp showLineNumbers={false} +template +concept resource = /* see description */; +``` + + + + + + +The type that should implement the resource concept + + + + + +--- + +## Description + +We require that an resource supports the following interface + + - `allocate(size_t bytes, size_t alignment)` + - `deallocate(void* ptr, size_t bytes, size_t alignment)` + - `T() == T()` + - `T() != T()` + + - `allocate(cuda::stream_ref stream, size_t bytes, size_t alignment)` + - `deallocate( cuda::stream_ref stream, void* ptr, size_t bytes, size_t alignment)` diff --git a/fern/cudapages/cuda/cuda/cuda/mr/resource_with.mdx b/fern/cudapages/cuda/cuda/cuda/mr/resource_with.mdx new file mode 100644 index 0000000..f6d8218 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/mr/resource_with.mdx @@ -0,0 +1,27 @@ +--- +title: "cuda::mr::resource_with" +description: "The `resource_with` concept verifies that a type Resource satisfies the [`resource`](/library/api/cuda::mr::resource) concept and also satisfies all the provided Properties." +--- + +C++20 concept + +The `resource_with` concept verifies that a type Resource satisfies the [`resource`](/library/api/cuda::mr::resource) concept and also satisfies all the provided Properties. + + +```cpp showLineNumbers={false} +template +concept resource_with = /* see description */; +``` + + + + + + + + + + + + + diff --git a/fern/cudapages/cuda/cuda/cuda/mr/shared_resource.mdx b/fern/cudapages/cuda/cuda/cuda/mr/shared_resource.mdx new file mode 100644 index 0000000..cd7c091 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/mr/shared_resource.mdx @@ -0,0 +1,435 @@ +--- +title: "cuda::mr::shared_resource" +description: "" +--- + +`shared_resource` holds a reference counted instance of a memory resource. This allows the user to pass a resource around with reference semantics while avoiding lifetime issues. + +@note `shared_resource` satisfies the `cuda::mr::resource` concept iff \tparam _Resource satisfies it. @tparam _Resource The resource type to hold. + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +**Inherits from:** `cuda::mr::__copy_default_queries< _Resource >` (public) + +--- + +## Constructors + +### shared_resource inline + + + + +explicit + +Constructs a `shared_resource` referring to an object of type `_Resource` that has been constructed with arguments `__args`. + +The `_Resource` object is dynamically allocated with `new`. + + +```cpp showLineNumbers={false} +template +cuda::mr::shared_resource<_Resource>::shared_resource( + ::cuda::std::in_place_type_t<_Resource>, + _Args &&... __args +) +``` + + +**Parameters** + + +The arguments to be passed to the `_Resource` constructor. + + + + + +noexcept + +Copy-constructs a `shared_resource` object resulting in an copy that shares ownership of the wrapped resource with `__other`. + + +```cpp showLineNumbers={false} +cuda::mr::shared_resource<_Resource>::shared_resource( + const shared_resource &__other +) noexcept +``` + + +**Parameters** + + +The `shared_resource` object to copy from. + + + + + +noexcept + +Move-constructs a `shared_resource` assuming ownership of the resource stored in `__other`. + + +```cpp showLineNumbers={false} +cuda::mr::shared_resource<_Resource>::shared_resource( + shared_resource &&__other +) noexcept +``` + + + +`__other` is left in a valid but unspecified state. + + +**Parameters** + + +The `shared_resource` object to move from. + + + + + +### Destructor + +### ~shared_resource inline + +Releases the reference held by this `shared_resource` object. + +If this is the last reference to the wrapped resource, the resource is deleted. + + +```cpp showLineNumbers={false} +cuda::mr::shared_resource<_Resource>::~shared_resource() +``` + + +--- + +## Assignment operators + +### operator= inline noexcept + + + + +Copy-assigns from `__other`. + +Self-assignment is a no-op. Otherwise, the reference held by this `shared_resource` object is released and a new reference is acquired to the wrapped resource of `__other`, if any. + + +```cpp showLineNumbers={false} +shared_resource & cuda::mr::shared_resource<_Resource>::operator=( + const shared_resource &__other +) noexcept +``` + + +**Parameters** + + +The `shared_resource` object to copy from. + + + + + +Move-assigns from `__other`. + +Self-assignment is a no-op. Otherwise, the reference held by this `shared_resource` object is released, while the reference held by `__other` is transferred to this object. + + +```cpp showLineNumbers={false} +shared_resource & cuda::mr::shared_resource<_Resource>::operator=( + shared_resource &&__other +) noexcept +``` + + + +`__other` is left in a valid but unspecified state. + + +**Parameters** + + +The `shared_resource` object to move from. + + + + + +--- + +## Methods + +### swap inline noexcept + +Swaps a `shared_resource` with another one. + + +```cpp showLineNumbers={false} +void cuda::mr::shared_resource<_Resource>::swap( + shared_resource &__other +) noexcept +``` + + +**Parameters** + + +The other `shared_resource`. + + +### get inline noexcept nodiscard + + + + +Returns a reference to the stored resource. + + +```cpp showLineNumbers={false} +_Resource & cuda::mr::shared_resource<_Resource>::get() noexcept +``` + + +**Returns:** A reference to the stored resource. + + + + +const + +Returns a const reference to the stored resource. + + +```cpp showLineNumbers={false} +const _Resource & cuda::mr::shared_resource<_Resource>::get() const noexcept +``` + + +**Returns:** A const reference to the stored resource. + + + + +### operator-> inline noexcept nodiscard + + + + +Returns a pointer to the stored resource. + + +```cpp showLineNumbers={false} +_Resource * cuda::mr::shared_resource<_Resource>::operator->() noexcept +``` + + +**Returns:** A pointer to the stored resource. + + + + +const + +Returns a const pointer to the stored resource. + + +```cpp showLineNumbers={false} +const _Resource * cuda::mr::shared_resource<_Resource>::operator->() const noexcept +``` + + +**Returns:** A const pointer to the stored resource. + + + + +### operator* inline noexcept nodiscard + + + + +Returns a reference to the stored resource. + + +```cpp showLineNumbers={false} +_Resource & cuda::mr::shared_resource<_Resource>::operator*() noexcept +``` + + +**Returns:** A reference to the stored resource. + + + + +const + +Returns a const reference to the stored resource. + + +```cpp showLineNumbers={false} +const _Resource & cuda::mr::shared_resource<_Resource>::operator*() const noexcept +``` + + +**Returns:** A const reference to the stored resource. + + + + +### allocate_sync inline nodiscard + +Allocate memory of size at least `__bytes` using the stored resource. + + +```cpp showLineNumbers={false} +void * cuda::mr::shared_resource<_Resource>::allocate_sync( + size_t __bytes, + size_t __alignment = alignof(::cuda::std::max_align_t) +) +``` + + +**Returns:** Pointer to the newly allocated memory + +**Parameters** + + +The size in bytes of the allocation. + + + +The requested alignment of the allocation. + + +### deallocate_sync inline noexcept + +Deallocate memory pointed to by `__ptr` using the stored resource. + + +```cpp showLineNumbers={false} +void cuda::mr::shared_resource<_Resource>::deallocate_sync( + void *__ptr, + size_t __bytes, + size_t __alignment = alignof(::cuda::std::max_align_t) +) noexcept +``` + + +**Parameters** + + +Pointer to be deallocated. Must have been allocated through a call to [`allocate`](/library/api/cuda::mr::shared_resource::allocate) or [`allocate_sync`](/library/api/cuda::mr::shared_resource::allocate_sync) + + + +The number of bytes that was passed to the allocation call that returned `__ptr`. + + + +The alignment that was passed to the allocation call that returned `__ptr`. + + +### allocate inline nodiscard + +Enqueues an allocation of memory of size at least `__bytes` using the wrapped resource. + +The allocation is performed asynchronously on stream `__stream`. + + +```cpp showLineNumbers={false} +template +void * cuda::mr::shared_resource<_Resource>::allocate( + ::cuda::stream_ref __stream, + size_t __bytes, + size_t __alignment +) +``` + + + +The caller is responsible for ensuring that the memory is not accessed until the operation has completed. + + + +`_Resource` must satisfy `resource`. + + +**Returns:** Pointer to the newly allocated memory. + +**Parameters** + + +The size in bytes of the allocation. + + + +The requested alignment of the allocation. + + +### deallocate inline noexcept + +Enqueues the deallocation of memory pointed to by `__ptr`. + +The deallocation is performed asynchronously on stream `__stream`. + + +```cpp showLineNumbers={false} +template +void cuda::mr::shared_resource<_Resource>::deallocate( + ::cuda::stream_ref __stream, + void *__ptr, + size_t __bytes, + size_t __alignment +) noexcept +``` + + + +The caller is responsible for ensuring that the memory is not accessed after the operation has completed. + + + +`_Resource` must satisfy `resource`. + + +**Parameters** + + +Pointer to be deallocated. Must have been allocated through a call to [`allocate`](/library/api/cuda::mr::shared_resource::allocate) or [`allocate_sync`](/library/api/cuda::mr::shared_resource::allocate_sync) + + + +The number of bytes that was passed to the allocation call that returned `__ptr`. + + + +The alignment that was passed to the allocation call that returned `__ptr`. + + +--- + +## Inner classes + +### _Control_block + + +```cpp showLineNumbers={false} +struct cuda::mr::shared_resource::_Control_block +``` + diff --git a/fern/cudapages/cuda/cuda/cuda/mr/synchronous_resource.mdx b/fern/cudapages/cuda/cuda/cuda/mr/synchronous_resource.mdx new file mode 100644 index 0000000..3743179 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/mr/synchronous_resource.mdx @@ -0,0 +1,36 @@ +--- +title: "cuda::mr::synchronous_resource" +description: "The `synchronous_resource` concept verifies that a type Resource satisfies the basic requirements of a memory resource." +--- + +C++20 concept + +The `synchronous_resource` concept verifies that a type Resource satisfies the basic requirements of a memory resource. + + +```cpp showLineNumbers={false} +template +concept synchronous_resource = /* see description */; +``` + + + + + + +The type that should implement the synchronous resource concept + + + + + +--- + +## Description + +We require that a resource supports the following interface + + - `allocate(size_t bytes, size_t alignment)` + - `deallocate(void* ptr, size_t bytes, size_t alignment)` + - `T() == T()` + - `T() != T()` diff --git a/fern/cudapages/cuda/cuda/cuda/mr/synchronous_resource_adapter.mdx b/fern/cudapages/cuda/cuda/cuda/mr/synchronous_resource_adapter.mdx new file mode 100644 index 0000000..e34ea13 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/mr/synchronous_resource_adapter.mdx @@ -0,0 +1,146 @@ +--- +title: "cuda::mr::synchronous_resource_adapter" +description: "Adapter that allows a synchronous resource to be used as a resource It examines the resource for the presence of the allocate and deallocate members." +--- + +Adapter that allows a synchronous resource to be used as a resource It examines the resource for the presence of the allocate and deallocate members. + +If they are present, it passes through the allocate and deallocate calls to the contained resource. Otherwise, it uses the allocate_sync and deallocate_sync members (with proper synchronization in case of deallocate). + +```cpp showLineNumbers={false} +#include +``` + + +This adapter takes ownership of the contained resource. + + + + + + +The type of the resource to be adapted + + + + + +**Inherits from:** `cuda::mr::__copy_default_queries< _Resource >` (public), `cuda::forward_property< synchronous_resource_adapter< _Resource >, _Resource >` (public) + +--- + +## Constructors + +### synchronous_resource_adapter inline noexcept + + + + + +```cpp showLineNumbers={false} +cuda::mr::synchronous_resource_adapter<_Resource>::synchronous_resource_adapter( + const _Resource &__resource +) noexcept +``` + + + + + + +```cpp showLineNumbers={false} +cuda::mr::synchronous_resource_adapter<_Resource>::synchronous_resource_adapter( + _Resource &&__resource +) noexcept +``` + + + + + +--- + +## Methods + +### allocate inline nodiscard + + +```cpp showLineNumbers={false} +void * cuda::mr::synchronous_resource_adapter<_Resource>::allocate( + const ::cuda::stream_ref __stream, + const size_t __bytes, + const size_t __alignment +) +``` + + +### allocate_sync inline nodiscard + + +```cpp showLineNumbers={false} +void * cuda::mr::synchronous_resource_adapter<_Resource>::allocate_sync( + const size_t __bytes, + const size_t __alignment +) +``` + + +### deallocate inline noexcept + + +```cpp showLineNumbers={false} +void cuda::mr::synchronous_resource_adapter<_Resource>::deallocate( + const ::cuda::stream_ref __stream, + void *__ptr, + const size_t __bytes, + const size_t __alignment +) noexcept +``` + + +### deallocate_sync inline noexcept + + +```cpp showLineNumbers={false} +void cuda::mr::synchronous_resource_adapter<_Resource>::deallocate_sync( + void *__ptr, + const size_t __bytes, + const size_t __alignment +) noexcept +``` + + +### operator== inline const noexcept nodiscard + + +```cpp showLineNumbers={false} +bool cuda::mr::synchronous_resource_adapter<_Resource>::operator==( + const synchronous_resource_adapter &__rhs +) const noexcept +``` + + +### upstream_resource inline noexcept + + + + + +```cpp showLineNumbers={false} +_Resource & cuda::mr::synchronous_resource_adapter<_Resource>::upstream_resource() noexcept +``` + + + + + +const + + +```cpp showLineNumbers={false} +const _Resource & cuda::mr::synchronous_resource_adapter<_Resource>::upstream_resource() const noexcept +``` + + + + diff --git a/fern/cudapages/cuda/cuda/cuda/mr/synchronous_resource_with.mdx b/fern/cudapages/cuda/cuda/cuda/mr/synchronous_resource_with.mdx new file mode 100644 index 0000000..bbc20c6 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/mr/synchronous_resource_with.mdx @@ -0,0 +1,27 @@ +--- +title: "cuda::mr::synchronous_resource_with" +description: "The `resource_with` concept verifies that a type Resource satisfies the [`synchronous_resource`](/library/api/cuda::mr::synchronous_resource) concept and also satisfies all the provided Properties." +--- + +C++20 concept + +The `resource_with` concept verifies that a type Resource satisfies the [`synchronous_resource`](/library/api/cuda::mr::synchronous_resource) concept and also satisfies all the provided Properties. + + +```cpp showLineNumbers={false} +template +concept synchronous_resource_with = /* see description */; +``` + + + + + + + + + + + + + diff --git a/fern/cudapages/cuda/cuda/cuda/permutation_iterator.mdx b/fern/cudapages/cuda/cuda/cuda/permutation_iterator.mdx new file mode 100644 index 0000000..1a74268 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/permutation_iterator.mdx @@ -0,0 +1,367 @@ +--- +title: "cuda::permutation_iterator" +description: "[`permutation_iterator`](/libcudacxx/api/cuda::permutation_iterator) is an iterator which represents a pointer into a reordered view of a given range." +--- + +`permutation_iterator` is an iterator which represents a pointer into a reordered view of a given range. + +`permutation_iterator` is an imprecise name; the reordered view need not be a strict permutation. This iterator is useful for fusing a scatter or gather operation with other algorithms. + +This iterator takes two arguments: + +- an iterator to the range `V` on which the "permutation" will be applied, referred to as `iter` below +- an iterator to a range of indices defining the reindexing scheme that determines how the elements of `V` will be permuted, referred to as `index` below + +Note that `permutation_iterator` is not limited to strict permutations of the given range `V`. The distance between begin and end of the reindexing iterators is allowed to be smaller compared to the size of the range `V`, in which case the `permutation_iterator` only provides a "permutation" of a subset of `V`. The indices do not need to be unique. In this same context, it must be noted that the past-the-end `permutation_iterator` is completely defined by means of the past-the-end iterator to the indices. + +The following code snippet demonstrates how to create a `permutation_iterator` which represents a reordering of the contents of a `device_vector`. + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include +#include +... +thrust::device_vector values{10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f, 70.0f, 80.0f}; +thrust::device_vector indices{2, 6, 1, 3}; + +using ElementIterator = thrust::device_vector::iterator; +using IndexIterator = thrust::device_vector::iterator; + +cuda::permutation_iterator iter(values.begin(), indices.begin()); + +*iter; // returns 30.0f; +iter[0]; // returns 30.0f; +iter[1]; // returns 70.0f; +iter[2]; // returns 20.0f; +iter[3]; // returns 40.0f; + +// iter[4] is an out-of-bounds error + +*iter = -1.0f; // sets values[2] to -1.0f; +iter[0] = -1.0f; // sets values[2] to -1.0f; +iter[1] = -1.0f; // sets values[6] to -1.0f; +iter[2] = -1.0f; // sets values[1] to -1.0f; +iter[3] = -1.0f; // sets values[3] to -1.0f; + +// values is now {10, -1, -1, -1, 50, 60, -1, 80} +``` + + + + + + + + + + + + + +--- + +## Constructors + +### permutation_iterator constexpr + + + + +Ensure that the user passes an iterator to something interger_like. + +Ensure that the index [value_type](/libcudacxx/api/cuda::permutation_iterator::value_type) is convertible to [difference_type](/libcudacxx/api/cuda::permutation_iterator::difference_type) To actually use operator+ we need the index iterator to be random access To actually use operator+ we need the base iterator to be random access + +Default constructs an `permutation_iterator` with a value initialized iterator and index + + +```cpp showLineNumbers={false} +cuda::permutation_iterator<_Iter, _Index>::permutation_iterator() = default +``` + + + + + +inline noexcept + +Constructs an `permutation_iterator` from an iterator and an optional index. + + +```cpp showLineNumbers={false} +cuda::permutation_iterator<_Iter, _Index>::permutation_iterator( + _Iter __iter, + _Index __index +) noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Iter > &&::cuda::std::is_nothrow_copy_constructible_v< _Index >) +``` + + +**Parameters** + + +The iterator to to index from + + + +The iterator with the permutations + + + + + +--- + +## Methods + +### base inline constexpr noexcept nodiscard + + + + +Extracts the stored base iterator `iter`. + + +```cpp showLineNumbers={false} +_Iter cuda::permutation_iterator<_Iter, _Index>::base() && noexcept(::cuda::std::is_nothrow_move_constructible_v< _Iter >) +``` + + + + + +const + +Returns a const reference to the stored base iterator `iter`. + + +```cpp showLineNumbers={false} +const _Iter & cuda::permutation_iterator<_Iter, _Index>::base() const & noexcept +``` + + + + + +### index inline constexpr const noexcept nodiscard + +Returns the current index. + + +```cpp showLineNumbers={false} +difference_type cuda::permutation_iterator<_Iter, _Index>::index() const noexcept +``` + + +**Returns:** Equivalent to `*index` + +### operator* inline constexpr noexcept nodiscard + + + + +Dereferences the `permutation_iterator`. + + +```cpp showLineNumbers={false} +decltype( + auto +) noexcept(__iter_[static_cast< __iter_difference_t >(*__index_)]) +``` + + +**Returns:** Equivalent to `iter[*index]` + + + + +const + +Dereferences the `permutation_iterator`. + + +```cpp showLineNumbers={false} +template +decltype( + auto +) const noexcept(__iter_[static_cast< __iter_difference_t >(*__index_)]) +``` + + +**Returns:** Equivalent to `iter[*index]` + + + + +### operator[] inline constexpr noexcept nodiscard + + + + +Subscripts the `permutation_iterator` by an offset. + + +```cpp showLineNumbers={false} +decltype( + auto +) noexcept(__iter_[static_cast< __iter_difference_t >(__index_[__n])]) +``` + + +**Returns:** Equivalent to `iter[`[`index`](/libcudacxx/api/cuda::permutation_iterator::index)`[__n]]` + +**Parameters** + + +The additional offset + + + + + +const + +Subscripts the `permutation_iterator` by an offset. + + +```cpp showLineNumbers={false} +template +decltype( + auto +) const noexcept(__iter_[static_cast< __iter_difference_t >(__index_[__n])]) +``` + + +**Returns:** Equivalent to `iter[`[`index`](/libcudacxx/api/cuda::permutation_iterator::index)`[__n]]` + +**Parameters** + + +The additional offset + + + + + +### operator++ inline constexpr noexcept + + + + +Increments the `permutation_iterator`. + + +```cpp showLineNumbers={false} +permutation_iterator & cuda::permutation_iterator<_Iter, _Index>::operator++() noexcept(++__index_) +``` + + +**Returns:** Equivalent to `++`[`index`](/libcudacxx/api/cuda::permutation_iterator::index) + + + + +Increments the `permutation_iterator`. + + +```cpp showLineNumbers={false} +permutation_iterator cuda::permutation_iterator<_Iter, _Index>::operator++( + int +) noexcept(noexcept(++__index_) &&::cuda::std::is_nothrow_copy_constructible_v< _Iter > &&::cuda::std::is_nothrow_copy_constructible_v< _Index >) +``` + + +**Returns:** Equivalent to `index++` + + + + +### operator-- inline constexpr noexcept + + + + +Increments the `permutation_iterator`. + + +```cpp showLineNumbers={false} +permutation_iterator & cuda::permutation_iterator<_Iter, _Index>::operator--() noexcept(--__index_) +``` + + +**Returns:** Equivalent to `--`[`index`](/libcudacxx/api/cuda::permutation_iterator::index) + + + + +Increments the `permutation_iterator`. + + +```cpp showLineNumbers={false} +permutation_iterator cuda::permutation_iterator<_Iter, _Index>::operator--( + int +) noexcept(noexcept(--__index_) &&::cuda::std::is_nothrow_copy_constructible_v< _Iter > &&::cuda::std::is_nothrow_copy_constructible_v< _Index >) +``` + + +**Returns:** Equivalent to `index++` + + + + +### operator+= inline constexpr noexcept + +Advances the `permutation_iterator` by a given number of elements. + + +```cpp showLineNumbers={false} +permutation_iterator & cuda::permutation_iterator<_Iter, _Index>::operator+=( + difference_type __n +) noexcept(__index_+=__n) +``` + + +**Returns:** Equivalent to [`index`](/libcudacxx/api/cuda::permutation_iterator::index)` + __n` + +**Parameters** + + +The number of elements to advance + + +### operator-= inline constexpr noexcept + +Decrements the `permutation_iterator` by a given number of elements. + + +```cpp showLineNumbers={false} +permutation_iterator & cuda::permutation_iterator<_Iter, _Index>::operator-=( + difference_type __n +) noexcept(__index_ -=__n) +``` + + +**Returns:** Equivalent to [`index`](/libcudacxx/api/cuda::permutation_iterator::index)` - __n` + +**Parameters** + + +The number of elements to decrement + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `iterator_type` | `_Iter` | +| `iterator_concept` | `::cuda::std::random_access_iterator_tag` | +| `iterator_category` | `::cuda::std::random_access_iterator_tag` | +| `value_type` | `::cuda::std::iter_value_t< _Iter >` | +| `difference_type` | `::cuda::std::iter_difference_t< _Index >` | diff --git a/fern/cudapages/cuda/cuda/cuda/pinned_memory_pool.mdx b/fern/cudapages/cuda/cuda/cuda/pinned_memory_pool.mdx new file mode 100644 index 0000000..8e940c6 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/pinned_memory_pool.mdx @@ -0,0 +1,167 @@ +--- +title: "cuda::pinned_memory_pool" +description: "" +--- + +`pinned_memory_pool` allocates pinned memory using `cudaMallocFromPoolAsync / cudaFreeAsync +<https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY__POOLS.html>`__ +for allocation/deallocation. When constructed it creates an underlying \c cudaMemPool_t with the location type set to \c cudaMemLocationTypeHost or \c cudaMemLocationTypeHostNuma and owns it. + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `cuda::pinned_memory_pool_ref` (public) + +--- + +## Constructors + +### pinned_memory_pool inline + + + + +Constructs a `pinned_memory_pool` with optional properties. + +Properties include the initial pool size and the release threshold. If the pool size grows beyond the release threshold, unused memory held by the pool will be released at the next synchronization event. + + +```cpp showLineNumbers={false} +cuda::pinned_memory_pool::pinned_memory_pool( + memory_pool_properties __properties = {} +) +``` + + + +Memory from this pool is accessible from all devices right away, which differs from the default behavior of pinned memory pools where memory is not accessible from devices until `cudaMemPoolSetAccess` is called. + + +**Parameters** + + +Optional, additional properties of the pool to be created. + + + + + +Constructs a `pinned_memory_pool` with the specified NUMA node id and optional properties. + +Properties include the initial pool size and the release threshold. If the pool size grows beyond the release threshold, unused memory held by the pool will be released at the next synchronization event. + + +```cpp showLineNumbers={false} +cuda::pinned_memory_pool::pinned_memory_pool( + int __numa_id, + memory_pool_properties __properties = {} +) +``` + + + +Memory from this pool is accessible from all devices right away, which differs from the default behavior of pinned memory pools where memory is not accessible from devices until `cudaMemPoolSetAccess` is called. + + +**Parameters** + + +The NUMA node id of the NUMA node the pool is constructed on. + + + +Optional, additional properties of the pool to be created. + + + + + +noexcept + + +```cpp showLineNumbers={false} +cuda::pinned_memory_pool::pinned_memory_pool( + ::cudaMemPool_t __pool +) noexcept +``` + + + + + + +```cpp showLineNumbers={false} +cuda::pinned_memory_pool::pinned_memory_pool( + const pinned_memory_pool & +) = delete +``` + + + + + +### Destructor + +### ~pinned_memory_pool inline noexcept + + +```cpp showLineNumbers={false} +cuda::pinned_memory_pool::~pinned_memory_pool() noexcept +``` + + +--- + +## Assignment operators + +### operator= + + +```cpp showLineNumbers={false} +pinned_memory_pool & cuda::pinned_memory_pool::operator=( + const pinned_memory_pool & +) = delete +``` + + +--- + +## Methods + +### as_ref inline noexcept + +Returns a [`pinned_memory_pool_ref`](/libcudacxx/api/cuda::pinned_memory_pool_ref) for this `pinned_memory_pool`. + +We return by reference to ensure that we can subsequently convert to a resource_ref + + +```cpp showLineNumbers={false} +pinned_memory_pool_ref & cuda::pinned_memory_pool::as_ref() noexcept +``` + + +--- + +## Static methods + +### from_native_handle inline static noexcept + + +```cpp showLineNumbers={false} +static pinned_memory_pool cuda::pinned_memory_pool::from_native_handle( + ::cudaMemPool_t __pool +) noexcept +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `reference_type` | `pinned_memory_pool_ref` | +| `default_queries` | `::cuda::mr::properties_list<::cuda::mr::device_accessible, ::cuda::mr::host_accessible >` | diff --git a/fern/cudapages/cuda/cuda/cuda/pinned_memory_pool_ref.mdx b/fern/cudapages/cuda/cuda/cuda/pinned_memory_pool_ref.mdx new file mode 100644 index 0000000..9355da4 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/pinned_memory_pool_ref.mdx @@ -0,0 +1,50 @@ +--- +title: "cuda::pinned_memory_pool_ref" +description: "" +--- + +`pinned_memory_pool_ref` allocates pinned memory using `cudaMallocFromPoolAsync / cudaFreeAsync +<https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY__POOLS.html>`__ +for allocation/deallocation. A `pinned_memory_pool_ref` is a thin wrapper around a \c cudaMemPool_t with the location type set to \c cudaMemLocationTypeHost or \c cudaMemLocationTypeHostNuma. + +.. warning:: + + `pinned_memory_pool_ref` does not own the pool and it is the responsibility of the user to ensure that the lifetime of the pool exceeds the lifetime of the `pinned_memory_pool_ref`. + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `cuda::__memory_pool_base` (public) + +--- + +## Constructors + +### pinned_memory_pool_ref inline explicit noexcept + +Constructs the `pinned_memory_pool_ref` from a `cudaMemPool_t`. + + +```cpp showLineNumbers={false} +cuda::pinned_memory_pool_ref::pinned_memory_pool_ref( + ::cudaMemPool_t __pool +) noexcept +``` + + +**Parameters** + + +The `cudaMemPool_t` used to allocate memory. + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `default_queries` | `::cuda::mr::properties_list<::cuda::mr::device_accessible, ::cuda::mr::host_accessible >` | diff --git a/fern/cudapages/cuda/cuda/cuda/property_with_value.mdx b/fern/cudapages/cuda/cuda/cuda/property_with_value.mdx new file mode 100644 index 0000000..9409884 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/property_with_value.mdx @@ -0,0 +1,24 @@ +--- +title: "cuda::property_with_value" +description: "The `property_with_value` concept verifies that a Property is stateful and signals this through the `value_type` alias." +--- + +C++20 concept + +The `property_with_value` concept verifies that a Property is stateful and signals this through the `value_type` alias. + + +```cpp showLineNumbers={false} +template +concept property_with_value = /* see description */; +``` + + + + + + + + + + diff --git a/fern/cudapages/cuda/cuda/cuda/shuffle_iterator.mdx b/fern/cudapages/cuda/cuda/cuda/shuffle_iterator.mdx new file mode 100644 index 0000000..8fe36eb --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/shuffle_iterator.mdx @@ -0,0 +1,259 @@ +--- +title: "cuda::shuffle_iterator" +description: "[`shuffle_iterator`](/libcudacxx/api/cuda::shuffle_iterator) is an iterator which generates a sequence of integral values representing a random permutation." +--- + +`shuffle_iterator` is an iterator which generates a sequence of integral values representing a random permutation. + +`shuffle_iterator` is an iterator which generates a sequence of values representing a random permutation. This iterator is useful for working with random permutations of a range without explicitly storing them in memory. The shuffle iterator is also useful for sampling from a range by selecting only a subset of the elements in the permutation. + +The following code snippet demonstrates how to create a `shuffle_iterator` which generates a random permutation of the range[0, 4) + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include +... +// create a shuffle_iterator +cuda::shuffle_iterator iterator{cuda::random_bijection{4, cuda::std::minstd_rand(0xDEADBEEF)}}; +// iterator[0] returns 1 +// iterator[1] returns 3 +// iterator[2] returns 2 +// iterator[3] returns 0 +``` + + + + + +The type of the index to shuffle. Defaults to uint64_t + + + + + + + + +--- + +## Constructors + +### shuffle_iterator constexpr noexcept + + + + + +```cpp showLineNumbers={false} +cuda::shuffle_iterator<_IndexType, _Bijection>::shuffle_iterator() noexcept = default +``` + + + + + +inline + +Constructs a `shuffle_iterator` from a given bijection and an optional start position. + + +```cpp showLineNumbers={false} +cuda::shuffle_iterator<_IndexType, _Bijection>::shuffle_iterator( + _Bijection __bijection, + value_type __start = 0 +) noexcept(::cuda::std::is_nothrow_move_constructible_v< _Bijection >) +``` + + +**Parameters** + + +The bijection representing the shuffled integer sequence + + + +The position of the iterator in the shuffled integer sequence + + + + + +inline explicit + +Constructs a `shuffle_iterator` by constructing the bijection function in place and an optional start position. + + +```cpp showLineNumbers={false} +template +cuda::shuffle_iterator<_IndexType, _Bijection>::shuffle_iterator( + value_type __num_elements, + _RNG &&__gen, + value_type __start = 0 +) noexcept(::cuda::std::is_nothrow_constructible_v< _Bijection, value_type, _RNG >) +``` + + +**Parameters** + + +The size of the bijection sequence + + + +The random number generator to initialize the bijection + + + +The optional stating index of the `shuffle_iterator` in the bijection sequence + + + + + +--- + +## Methods + +### operator* inline constexpr const noexcept nodiscard + +Dereferences the `shuffle_iterator` by invoking the bijection with the stored index. + + +```cpp showLineNumbers={false} +value_type cuda::shuffle_iterator<_IndexType, _Bijection>::operator*() const noexcept(__bijection_(0)) +``` + + +### operator[] inline constexpr const noexcept nodiscard + +Subscripts the `shuffle_iterator` by invoking the bijection with the stored index advanced by a given number of elements. + + +```cpp showLineNumbers={false} +value_type cuda::shuffle_iterator<_IndexType, _Bijection>::operator[]( + difference_type __n +) const noexcept(__bijection_(0)) +``` + + +**Parameters** + + +The additional number of elements + + +### operator++ inline constexpr noexcept + + + + +Increments the [`permutation_iterator`](/libcudacxx/api/cuda::permutation_iterator). + + +```cpp showLineNumbers={false} +shuffle_iterator & cuda::shuffle_iterator<_IndexType, _Bijection>::operator++() noexcept +``` + + + + + +Increments the [`permutation_iterator`](/libcudacxx/api/cuda::permutation_iterator). + + +```cpp showLineNumbers={false} +shuffle_iterator cuda::shuffle_iterator<_IndexType, _Bijection>::operator++( + int +) noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Bijection >) +``` + + + + + +### operator-- inline constexpr noexcept + + + + +Decrements the [`permutation_iterator`](/libcudacxx/api/cuda::permutation_iterator). + + +```cpp showLineNumbers={false} +shuffle_iterator & cuda::shuffle_iterator<_IndexType, _Bijection>::operator--() noexcept +``` + + + + + +nodiscard + +Decrements the [`permutation_iterator`](/libcudacxx/api/cuda::permutation_iterator). + + +```cpp showLineNumbers={false} +shuffle_iterator cuda::shuffle_iterator<_IndexType, _Bijection>::operator--( + int +) noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Bijection >) +``` + + + + + +### operator+= inline constexpr noexcept + +Advances the [`permutation_iterator`](/libcudacxx/api/cuda::permutation_iterator) by a given number of elements. + + +```cpp showLineNumbers={false} +shuffle_iterator & cuda::shuffle_iterator<_IndexType, _Bijection>::operator+=( + difference_type __n +) noexcept +``` + + +**Parameters** + + +The number of elements to advance + + +### operator-= inline constexpr noexcept + +Decrements the [`permutation_iterator`](/libcudacxx/api/cuda::permutation_iterator) by a given number of elements. + + +```cpp showLineNumbers={false} +shuffle_iterator & cuda::shuffle_iterator<_IndexType, _Bijection>::operator-=( + difference_type __n +) noexcept +``` + + +**Parameters** + + +The number of elements to decrement + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `iterator_category` | `::cuda::std::random_access_iterator_tag` | +| `iterator_concept` | `::cuda::std::random_access_iterator_tag` | +| `value_type` | `_IndexType` | +| `difference_type` | `::cuda::std::make_signed_t< value_type >` | +| `reference` | `_IndexType` | +| `pointer` | `void` | diff --git a/fern/cudapages/cuda/cuda/cuda/std/pointer_traits.mdx b/fern/cudapages/cuda/cuda/cuda/std/pointer_traits.mdx new file mode 100644 index 0000000..0ad238a --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/std/pointer_traits.mdx @@ -0,0 +1,56 @@ +--- +title: "cuda::std::pointer_traits<::cuda::heterogeneous_iterator< _Tp, _Properties... > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + +--- + +## Static methods + +### to_address inline static constexpr noexcept nodiscard + +Retrieve the address of the element pointed at by an [heterogeneous_iterator](/libcudacxx/api/cuda::heterogeneous_iterator). + + >": "/libcudacxx/api/cuda::std::pointer_traits%3C::cuda::heterogeneous_iterator%3C _Tp, _Properties... %3E %3E"}}> +```cpp showLineNumbers={false} +static constexpr element_type * cuda::std::pointer_traits<::cuda::heterogeneous_iterator<_Tp, _Properties...>>::to_address( + const pointer __iter +) noexcept +``` + + +**Returns:** A pointer to the element pointed to by the [heterogeneous_iterator](/libcudacxx/api/cuda::heterogeneous_iterator) + +**Parameters** + + +A [heterogeneous_iterator](/libcudacxx/api/cuda::heterogeneous_iterator). + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `pointer` | `::cuda::heterogeneous_iterator< _Tp, _Properties... >` | +| `element_type` | `_Tp` | +| `difference_type` | `::cuda::std::ptrdiff_t` | diff --git a/fern/cudapages/cuda/cuda/cuda/stream.mdx b/fern/cudapages/cuda/cuda/cuda/stream.mdx new file mode 100644 index 0000000..1865562 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/stream.mdx @@ -0,0 +1,445 @@ +--- +title: "cuda::stream" +description: "An owning wrapper for cudaStream_t." +--- + +An owning wrapper for cudaStream_t. + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `cuda::stream_ref` (public) + +--- + +## Constructors + +### stream inline + + + + +explicit + +Constructs a stream on a specified device and with specified priority. + +Priority is defaulted to [stream::default_priority](/libcudacxx/api/cuda::stream::default_priority) + + +```cpp showLineNumbers={false} +cuda::stream::stream( + device_ref __dev, + int __priority = default_priority +) +``` + + +**Throws:** `cuda_error` if stream creation fails + + + + +explicit noexcept + +Construct a new `stream` object into the moved-from state. + + +```cpp showLineNumbers={false} +cuda::stream::stream( + no_init_t +) noexcept +``` + + + +[`stream()`](/libcudacxx/api/cuda::stream::stream()) returns an invalid stream handle + + + + + +noexcept + +Move-construct a new `stream` object. + + +```cpp showLineNumbers={false} +cuda::stream::stream( + stream &&__other +) noexcept +``` + + + +`__other` is in moved-from state. + + + + + +explicit + + +```cpp showLineNumbers={false} +cuda::stream::stream( + ::cudaStream_t __handle +) +``` + + + + + + +```cpp showLineNumbers={false} +cuda::stream::stream( + const stream & +) = delete +``` + + + + + +### Destructor + +### ~stream inline + +Destroy the `stream` object. + + +```cpp showLineNumbers={false} +cuda::stream::~stream() +``` + + + +If the stream fails to be destroyed, the error is silently ignored. + + +--- + +## Assignment operators + +### operator= inline noexcept + + + + +Move-assign a `stream` object. + + +```cpp showLineNumbers={false} +stream & cuda::stream::operator=( + stream &&__other +) noexcept +``` + + + +`__other` is in a moved-from state. + + + + + + +```cpp showLineNumbers={false} +stream & cuda::stream::operator=( + const stream & +) = delete +``` + + + + + +--- + +## Methods + +### release inline nodiscard + +Retrieve the native `cudaStream_t` handle and give up ownership. + + +```cpp showLineNumbers={false} +::cudaStream_t cuda::stream::release() +``` + + + +The stream object is in a moved-from state. + + +**Returns:** cudaStream_t The native handle being held by the `stream` object. + +### get inline constexpr const noexcept nodiscard + +Returns the wrapped `cudaStream_t` handle. + + +```cpp showLineNumbers={false} +value_type cuda::stream::get() const noexcept +``` + + +### sync inline const + +Synchronizes the wrapped stream. + + +```cpp showLineNumbers={false} +void cuda::stream::sync() const +``` + + +**Throws:** `cuda::cuda_error` if synchronization fails. + +### wait inline const + + + + +Deprecated. + +Use [sync()](/libcudacxx/api/cuda::stream_ref::sync()) instead. + + +```cpp showLineNumbers={false} +void cuda::stream::wait() const +``` + + + +Use [sync()](/libcudacxx/api/cuda::stream_ref::sync()) instead. + + + + + +Make all future work submitted into this stream depend on completion of the specified event. + + +```cpp showLineNumbers={false} +void cuda::stream::wait( + event_ref __ev +) const +``` + + +**Throws:** `cuda_error` if inserting the dependency fails + +**Parameters** + + +Event that this stream should wait for + + + + + +Make all future work submitted into this stream depend on completion of all work from the specified stream. + + +```cpp showLineNumbers={false} +void cuda::stream::wait( + stream_ref __other +) const +``` + + +**Throws:** `cuda_error` if inserting the dependency fails + +**Parameters** + + +Stream that this stream should wait for + + + + + +### is_done inline const nodiscard + +Queries if all operations on the stream have completed. + + +```cpp showLineNumbers={false} +bool cuda::stream::is_done() const +``` + + +**Returns:** `true` if all operations have completed, or `false` if not. + +**Throws:** `cuda::cuda_error` if the query fails. + +### ready inline const nodiscard + +Queries if all operations on the wrapped stream have completed. + + +```cpp showLineNumbers={false} +bool cuda::stream::ready() const +``` + + +**Returns:** `true` if all operations have completed, or `false` if not. + +**Throws:** `cuda::cuda_error` if the query fails. + +### priority inline const nodiscard + +Queries the priority of the wrapped stream. + + +```cpp showLineNumbers={false} +int cuda::stream::priority() const +``` + + +**Returns:** value representing the priority of the wrapped stream. + +**Throws:** `cuda::cuda_error` if the query fails. + +### id inline const nodiscard + +Get the unique ID of the stream. + +Stream handles are sometimes reused, but ID is guaranteed to be unique. + + +```cpp showLineNumbers={false} +stream_id cuda::stream::id() const +``` + + +**Returns:** The unique ID of the stream + +**Throws:** `cuda_error` if the ID query fails + +### record_event inline const nodiscard + +Create a new event and record it into this stream. + + +```cpp showLineNumbers={false} +event cuda::stream::record_event( + event_flags __flags = event_flags::none +) const +``` + + +**Returns:** A new event that was recorded into this stream + +**Throws:** `cuda_error` if event creation or record failed + +### record_timed_event inline const nodiscard + +Create a new timed event and record it into this stream. + + +```cpp showLineNumbers={false} +timed_event cuda::stream::record_timed_event( + event_flags __flags = event_flags::none +) const +``` + + +**Returns:** A new timed event that was recorded into this stream + +**Throws:** `cuda_error` if event creation or record failed + +### device inline const nodiscard + +Get device under which this stream was created. + +Note: In case of a stream created under a `green_context` the device on which that `green_context` was created is returned + + +```cpp showLineNumbers={false} +device_ref cuda::stream::device() const +``` + + +**Throws:** `cuda_error` if device check fails + +### query inline constexpr const noexcept nodiscard + +Queries the [`stream_ref`](/libcudacxx/api/cuda::stream_ref) for itself. + +This makes [`stream_ref`](/libcudacxx/api/cuda::stream_ref) usable in places where we expect an environment with a [`get_stream_t`](/libcudacxx/api/cuda::get_stream_t) query + + +```cpp showLineNumbers={false} +stream_ref cuda::stream::query( + const ::cuda::get_stream_t & +) const noexcept +``` + + +--- + +## Static methods + +### from_native_handle inline static nodiscard + + + + +Construct an `stream` object from a native `cudaStream_t` handle. + + +```cpp showLineNumbers={false} +static stream cuda::stream::from_native_handle( + ::cudaStream_t __handle +) +``` + + + +The constructed `stream` object takes ownership of the native handle. + + +**Returns:** stream The constructed `stream` object + +**Parameters** + + +The native handle + + + + + +The following overloads are deleted to prevent misuse: + + +```cpp showLineNumbers={false} +static stream cuda::stream::from_native_handle(int) = delete; +static stream cuda::stream::from_native_handle(::cuda::std::nullptr_t) = delete; +static stream cuda::stream::from_native_handle(invalid_stream_t) = delete; +``` + + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `value_type` | `::cudaStream_t` | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `default_priority` static constexpr | `int` | | diff --git a/fern/cudapages/cuda/cuda/cuda/stream_ref.mdx b/fern/cudapages/cuda/cuda/cuda/stream_ref.mdx new file mode 100644 index 0000000..3477f1b --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/stream_ref.mdx @@ -0,0 +1,306 @@ +--- +title: "cuda::stream_ref" +description: "A non-owning wrapper for a `cudaStream_t`." +--- + +A non-owning wrapper for a `cudaStream_t`. + +```cpp showLineNumbers={false} +#include +``` + +--- + +## Constructors + +### stream_ref + + + + +Constructs a `stream_ref` of the "default" CUDA stream. + +For behavior of the default stream, + + +```cpp showLineNumbers={false} +cuda::stream_ref::stream_ref() = default +``` + + +**See also:** +//! [https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html](https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html) + + + + + +inline constexpr noexcept + +Constructs a `stream_ref` from a `cudaStream_t` handle. + +This constructor provides implicit conversion from `cudaStream_t`. + + +```cpp showLineNumbers={false} +cuda::stream_ref::stream_ref( + value_type __stream_ +) noexcept +``` + + + +: It is the callers responsibility to ensure the `stream_ref` does not outlive the stream identified by the `cudaStream_t` handle. + + + + + +inline explicit noexcept + +Constructs a `stream_ref` from the [`cuda::invalid_stream_t`](/libcudacxx/api/cuda::invalid_stream_t). + + +```cpp showLineNumbers={false} +cuda::stream_ref::stream_ref( + invalid_stream_t +) noexcept +``` + + + +Any CUDA APIs called on the created object will result in an CUDA error. + + + + + +The following overloads are deleted to prevent misuse: + + +```cpp showLineNumbers={false} +cuda::stream_ref::stream_ref(int) = delete; +cuda::stream_ref::stream_ref(::cuda::std::nullptr_t) = delete; +``` + + + + + +--- + +## Methods + +### get inline constexpr const noexcept nodiscard + +Returns the wrapped `cudaStream_t` handle. + + +```cpp showLineNumbers={false} +value_type cuda::stream_ref::get() const noexcept +``` + + +### sync inline const + +Synchronizes the wrapped stream. + + +```cpp showLineNumbers={false} +void cuda::stream_ref::sync() const +``` + + +**Throws:** `cuda::cuda_error` if synchronization fails. + +### wait inline const + + + + +Deprecated. + +Use [sync()](/libcudacxx/api/cuda::stream_ref::sync()) instead. + + +```cpp showLineNumbers={false} +void cuda::stream_ref::wait() const +``` + + + +Use [sync()](/libcudacxx/api/cuda::stream_ref::sync()) instead. + + + + + +Make all future work submitted into this stream depend on completion of the specified event. + + +```cpp showLineNumbers={false} +void cuda::stream_ref::wait( + event_ref __ev +) const +``` + + +**Throws:** `cuda_error` if inserting the dependency fails + +**Parameters** + + +Event that this stream should wait for + + + + + +Make all future work submitted into this stream depend on completion of all work from the specified stream. + + +```cpp showLineNumbers={false} +void cuda::stream_ref::wait( + stream_ref __other +) const +``` + + +**Throws:** `cuda_error` if inserting the dependency fails + +**Parameters** + + +Stream that this stream should wait for + + + + + +### is_done inline const nodiscard + +Queries if all operations on the stream have completed. + + +```cpp showLineNumbers={false} +bool cuda::stream_ref::is_done() const +``` + + +**Returns:** `true` if all operations have completed, or `false` if not. + +**Throws:** `cuda::cuda_error` if the query fails. + +### ready inline const nodiscard + +Queries if all operations on the wrapped stream have completed. + + +```cpp showLineNumbers={false} +bool cuda::stream_ref::ready() const +``` + + +**Returns:** `true` if all operations have completed, or `false` if not. + +**Throws:** `cuda::cuda_error` if the query fails. + +### priority inline const nodiscard + +Queries the priority of the wrapped stream. + + +```cpp showLineNumbers={false} +int cuda::stream_ref::priority() const +``` + + +**Returns:** value representing the priority of the wrapped stream. + +**Throws:** `cuda::cuda_error` if the query fails. + +### id inline const nodiscard + +Get the unique ID of the stream. + +Stream handles are sometimes reused, but ID is guaranteed to be unique. + + +```cpp showLineNumbers={false} +stream_id cuda::stream_ref::id() const +``` + + +**Returns:** The unique ID of the stream + +**Throws:** `cuda_error` if the ID query fails + +### record_event inline const nodiscard + +Create a new event and record it into this stream. + + +```cpp showLineNumbers={false} +event cuda::stream_ref::record_event( + event_flags __flags = event_flags::none +) const +``` + + +**Returns:** A new event that was recorded into this stream + +**Throws:** `cuda_error` if event creation or record failed + +### record_timed_event inline const nodiscard + +Create a new timed event and record it into this stream. + + +```cpp showLineNumbers={false} +timed_event cuda::stream_ref::record_timed_event( + event_flags __flags = event_flags::none +) const +``` + + +**Returns:** A new timed event that was recorded into this stream + +**Throws:** `cuda_error` if event creation or record failed + +### device inline const nodiscard + +Get device under which this stream was created. + +Note: In case of a stream created under a `green_context` the device on which that `green_context` was created is returned + + +```cpp showLineNumbers={false} +device_ref cuda::stream_ref::device() const +``` + + +**Throws:** `cuda_error` if device check fails + +### query inline constexpr const noexcept nodiscard + +Queries the `stream_ref` for itself. + +This makes `stream_ref` usable in places where we expect an environment with a [`get_stream_t`](/libcudacxx/api/cuda::get_stream_t) query + + +```cpp showLineNumbers={false} +stream_ref cuda::stream_ref::query( + const ::cuda::get_stream_t & +) const noexcept +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `value_type` | `::cudaStream_t` | diff --git a/fern/cudapages/cuda/cuda/cuda/strided_iterator.mdx b/fern/cudapages/cuda/cuda/cuda/strided_iterator.mdx new file mode 100644 index 0000000..e296c7e --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/strided_iterator.mdx @@ -0,0 +1,341 @@ +--- +title: "cuda::strided_iterator" +description: "A [`strided_iterator`](/libcudacxx/api/cuda::strided_iterator) wraps another iterator and advances it by a specified stride each time it is incremented or decremented." +--- + +A `strided_iterator` wraps another iterator and advances it by a specified stride each time it is incremented or decremented. + +```cpp showLineNumbers={false} +#include +``` + + + + + +A random access iterator + + + +Either an [integer-like](https://eel.is/c++draft/iterator.concept.winc#4) or an [integral-constant-like](https://eel.is/c++draft/views.contiguous#concept:integral-constant-like) specifying the stride + + + + + +--- + +## Constructors + +### strided_iterator inline constexpr noexcept + + + + +value-initializes both the base iterator and stride + + +```cpp showLineNumbers={false} +template +cuda::strided_iterator<_Iter, _Stride>::strided_iterator() noexcept(::cuda::std::is_nothrow_default_constructible_v< _Iter2 > &&::cuda::std::is_nothrow_default_constructible_v< _Stride2 >) +``` + + + +_Iter must be default initializable because it is a random_access_iterator and thereby semiregular _Stride must be integer-like or integral_constant_like which requires default constructability + + + + + +explicit + +Constructs a `strided_iterator` from a base iterator. + + +```cpp showLineNumbers={false} +template +cuda::strided_iterator<_Iter, _Stride>::strided_iterator( + _Iter __iter +) noexcept(::cuda::std::is_nothrow_move_constructible_v< _Iter > &&::cuda::std::is_nothrow_default_constructible_v< _Stride2 >) +``` + + + +We cannot construct a `strided_iterator` with an [integer-like](https://eel.is/c++draft/iterator.concept.winc#4) stride, because that would value construct to 0 and incrementing the iterator would do nothing. + + +**Parameters** + + +The base iterator + + + + + +explicit + +Constructs a `strided_iterator` from a base iterator and a stride. + + +```cpp showLineNumbers={false} +cuda::strided_iterator<_Iter, _Stride>::strided_iterator( + _Iter __iter, + _Stride __stride +) noexcept(::cuda::std::is_nothrow_move_constructible_v< _Iter > &&::cuda::std::is_nothrow_move_constructible_v< _Stride >) +``` + + +**Parameters** + + +The base iterator + + + +The new stride + + + + + +--- + +## Methods + +### base inline constexpr noexcept nodiscard + + + + +Extracts the stored iterator. + + +```cpp showLineNumbers={false} +_Iter cuda::strided_iterator<_Iter, _Stride>::base() && noexcept(::cuda::std::is_nothrow_move_constructible_v< _Iter >) +``` + + + + + +const + +Returns a const reference to the stored iterator. + + +```cpp showLineNumbers={false} +const _Iter & cuda::strided_iterator<_Iter, _Stride>::base() const & noexcept +``` + + + + + +### stride inline constexpr const noexcept nodiscard + +Returns the current stride as an integral value. + + +```cpp showLineNumbers={false} +difference_type cuda::strided_iterator<_Iter, _Stride>::stride() const noexcept(__noexcept_stride) +``` + + +### operator* inline constexpr noexcept nodiscard + + + + +Dereferences the stored base iterator. + + +```cpp showLineNumbers={false} +decltype( + auto +) noexcept(*::cuda::std::declval< _Iter & >()) +``` + + + + + +const + +Dereferences the stored base iterator. + + +```cpp showLineNumbers={false} +template +decltype( + auto +) const noexcept(*::cuda::std::declval< const _Iter2 & >()) +``` + + + + + +### operator[] inline constexpr noexcept nodiscard + + + + +Subscripts the stored base iterator with a given offset times the stride. + + +```cpp showLineNumbers={false} +decltype( + auto +) noexcept(__noexcept_stride &&noexcept(::cuda::std::declval< _Iter & >()[__n])) +``` + + +**Parameters** + + +The offset + + + + + +const + +Subscripts the stored base iterator with a given offset times the stride. + + +```cpp showLineNumbers={false} +template +decltype( + auto +) const noexcept(__noexcept_stride &&noexcept(::cuda::std::declval< const _Iter2 & >()[__n])) +``` + + +**Parameters** + + +The offset + + + + + +### operator++ inline constexpr noexcept + + + + +Increments the stored base iterator by the stride. + + +```cpp showLineNumbers={false} +strided_iterator & cuda::strided_iterator<_Iter, _Stride>::operator++() noexcept(__noexcept_stride &&noexcept(::cuda::std::declval< _Iter & >()+=1)) +``` + + + + + +Increments the stored base iterator by the stride. + + +```cpp showLineNumbers={false} +auto cuda::strided_iterator<_Iter, _Stride>::operator++( + int +) noexcept(noexcept(__noexcept_stride &&noexcept(::cuda::std::declval< _Iter & >()+=1)) &&::cuda::std::is_nothrow_copy_constructible_v< _Iter > &&::cuda::std::is_nothrow_copy_constructible_v< _Stride >) +``` + + + + + +### operator-- inline constexpr noexcept + + + + +Decrements the stored base iterator by the stride. + + +```cpp showLineNumbers={false} +strided_iterator & cuda::strided_iterator<_Iter, _Stride>::operator--() noexcept(__noexcept_stride &&noexcept(::cuda::std::declval< _Iter & >() -=1)) +``` + + + + + +Decrements the stored base iterator by the stride. + + +```cpp showLineNumbers={false} +strided_iterator cuda::strided_iterator<_Iter, _Stride>::operator--( + int +) noexcept(noexcept(__noexcept_stride &&noexcept(::cuda::std::declval< _Iter & >() -=1)) &&::cuda::std::is_nothrow_copy_constructible_v< _Iter > &&::cuda::std::is_nothrow_copy_constructible_v< _Stride >) +``` + + + + + +### operator+= inline constexpr noexcept + +Advances a `strided_iterator` by a given number of steps. + + +```cpp showLineNumbers={false} +strided_iterator & cuda::strided_iterator<_Iter, _Stride>::operator+=( + difference_type __n +) noexcept(__noexcept_stride &&noexcept(::cuda::std::declval< _Iter & >()+=1)) +``` + + + +Increments the base iterator by `__n` times the stride + + +**Parameters** + + +The number of steps to increment + + +### operator-= inline constexpr noexcept + +Decrements a `strided_iterator` by a given number of steps. + + +```cpp showLineNumbers={false} +strided_iterator & cuda::strided_iterator<_Iter, _Stride>::operator-=( + difference_type __n +) noexcept(__noexcept_stride &&noexcept(::cuda::std::declval< _Iter & >() -=1)) +``` + + + +Decrements the base iterator by `__n` times the stride + + +**Parameters** + + +The number of steps to decrement + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `iterator_concept` | `::cuda::std::random_access_iterator_tag` | +| `iterator_category` | `::cuda::std::random_access_iterator_tag` | +| `value_type` | `::cuda::std::iter_value_t< _Iter >` | +| `difference_type` | `::cuda::std::iter_difference_t< _Iter >` | +| `reference` | `::cuda::std::iter_reference_t< _Iter >` | +| `pointer` | `void` | diff --git a/fern/cudapages/cuda/cuda/cuda/tabulate_output_iterator.mdx b/fern/cudapages/cuda/cuda/cuda/tabulate_output_iterator.mdx new file mode 100644 index 0000000..a422fac --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/tabulate_output_iterator.mdx @@ -0,0 +1,295 @@ +--- +title: "cuda::tabulate_output_iterator" +description: "[`tabulate_output_iterator`](/libcudacxx/api/cuda::tabulate_output_iterator) is a special kind of output iterator which, whenever a value is assigned to a dereferenced iterator, calls the given callable with the index that corresponds to the offset of the dereferenced iterator and the assigned value." +--- + +`tabulate_output_iterator` is a special kind of output iterator which, whenever a value is assigned to a dereferenced iterator, calls the given callable with the index that corresponds to the offset of the dereferenced iterator and the assigned value. + +The following code snippet demonstrates how to create a `tabulate_output_iterator` which prints the index and the assigned value. + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include + +struct print_op +{ + __host__ __device__ void operator()(int index, float value) const + { + printf("%d: %f\n", index, value); + } +}; + +int main() +{ + auto tabulate_it = cuda::make_tabulate_output_iterator(print_op{}); + + tabulate_it[0] = 1.0f; // prints: 0: 1.0 + tabulate_it[1] = 3.0f; // prints: 1: 3.0 + tabulate_it[9] = 5.0f; // prints: 9: 5.0 +} +``` + + + + + + + + + + + + + +--- + +## Constructors + +### tabulate_output_iterator inline constexpr noexcept + + + + + +```cpp showLineNumbers={false} +template +cuda::tabulate_output_iterator<_Fn, _Index>::tabulate_output_iterator() noexcept(::cuda::std::is_nothrow_default_constructible_v< _Fn2 >) +``` + + + + + +Constructs a `tabulate_output_iterator` with a given functor and an optional index. + + +```cpp showLineNumbers={false} +cuda::tabulate_output_iterator<_Fn, _Index>::tabulate_output_iterator( + _Fn __func, + _Index __index = 0 +) noexcept(::cuda::std::is_nothrow_move_constructible_v< _Fn >) +``` + + +**Parameters** + + +The output function + + + +The position in the output sequence + + + + + +--- + +## Methods + +### index inline constexpr const noexcept nodiscard + +Returns the stored index. + + +```cpp showLineNumbers={false} +difference_type cuda::tabulate_output_iterator<_Fn, _Index>::index() const noexcept +``` + + +### operator* inline constexpr noexcept nodiscard + + + + +Dereferences the `tabulate_output_iterator`. + + +```cpp showLineNumbers={false} +auto cuda::tabulate_output_iterator<_Fn, _Index>::operator*() noexcept +``` + + +**Returns:** A proxy that applies the stored function and index on assignment + + + + +const + +Dereferences the `tabulate_output_iterator`. + + +```cpp showLineNumbers={false} +auto cuda::tabulate_output_iterator<_Fn, _Index>::operator*() const noexcept +``` + + +**Returns:** A proxy that applies the stored function and index on assignment + + + + +### operator[] inline constexpr noexcept nodiscard + + + + +Subscripts the `tabulate_output_iterator` with a given offset. + + +```cpp showLineNumbers={false} +auto cuda::tabulate_output_iterator<_Fn, _Index>::operator[]( + difference_type __n +) noexcept +``` + + +**Returns:** A proxy that applies the stored function and index on assignment + +**Parameters** + + +The additional offset to advance the stored index + + + + + +const + +Subscripts the `tabulate_output_iterator` with a given offset. + + +```cpp showLineNumbers={false} +auto cuda::tabulate_output_iterator<_Fn, _Index>::operator[]( + difference_type __n +) const noexcept +``` + + +**Returns:** A proxy that applies the stored function and index on assignment + +**Parameters** + + +The additional offset to advance the stored index + + + + + +### operator++ inline constexpr noexcept + + + + +Increments the `tabulate_output_iterator` by incrementing the stored index. + + +```cpp showLineNumbers={false} +tabulate_output_iterator & cuda::tabulate_output_iterator<_Fn, _Index>::operator++() noexcept +``` + + + + + +Increments the `tabulate_output_iterator` by incrementing the stored index. + + +```cpp showLineNumbers={false} +tabulate_output_iterator cuda::tabulate_output_iterator<_Fn, _Index>::operator++( + int +) noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Fn >) +``` + + + + + +### operator-- inline constexpr noexcept + + + + +Decrements the `tabulate_output_iterator` by decrementing the stored index. + + +```cpp showLineNumbers={false} +tabulate_output_iterator & cuda::tabulate_output_iterator<_Fn, _Index>::operator--() noexcept +``` + + + + + +Decrements the `tabulate_output_iterator` by decrementing the stored index. + + +```cpp showLineNumbers={false} +tabulate_output_iterator cuda::tabulate_output_iterator<_Fn, _Index>::operator--( + int +) noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Fn >) +``` + + + + + +### operator+= inline constexpr noexcept + +Advances the `tabulate_output_iterator` by a given number of elements. + + +```cpp showLineNumbers={false} +tabulate_output_iterator & cuda::tabulate_output_iterator<_Fn, _Index>::operator+=( + difference_type __n +) noexcept +``` + + +**Parameters** + + +The number of elements to advance + + +### operator-= inline constexpr noexcept + +Decrements the `tabulate_output_iterator` by a given number of elements. + + +```cpp showLineNumbers={false} +tabulate_output_iterator & cuda::tabulate_output_iterator<_Fn, _Index>::operator-=( + difference_type __n +) noexcept +``` + + +**Parameters** + + +The number of elements to decrement + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `iterator_concept` | `::cuda::std::random_access_iterator_tag` | +| `iterator_category` | `::cuda::std::random_access_iterator_tag` | +| `difference_type` | `_Index` | +| `value_type` | `void` | +| `pointer` | `void` | +| `reference` | `void` | diff --git a/fern/cudapages/cuda/cuda/cuda/timed_event.mdx b/fern/cudapages/cuda/cuda/cuda/timed_event.mdx new file mode 100644 index 0000000..0ad5cf0 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/timed_event.mdx @@ -0,0 +1,286 @@ +--- +title: "cuda::timed_event" +description: "An owning wrapper for a `cudaEvent_t` with timing enabled." +--- + +An owning wrapper for a `cudaEvent_t` with timing enabled. + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `cuda::event` (public) + +--- + +## Constructors + +### timed_event + + + + +inline explicit + +Construct a new `timed_event` object with the specified flags and record the event on the specified stream. + + +```cpp showLineNumbers={false} +cuda::timed_event::timed_event( + stream_ref __stream, + event_flags __flags = event_flags::none +) +``` + + +**Throws:** `cuda_error` if the event creation fails. + + + + +inline explicit + +Construct a new `timed_event` object with the specified flags. + +The event can only be recorded on streams from the specified device. + + +```cpp showLineNumbers={false} +cuda::timed_event::timed_event( + device_ref __device, + event_flags __flags = event_flags::none +) +``` + + +**Throws:** `cuda_error` if the event creation fails. + + + + +inline constexpr explicit noexcept + +Construct a new `timed_event` object into the moved-from state. + + +```cpp showLineNumbers={false} +cuda::timed_event::timed_event( + no_init_t +) noexcept +``` + + + +[`get()`](/libcudacxx/api/cuda::event_ref::get()) returns `cudaEvent_t()`. + + + + + +noexcept + + +```cpp showLineNumbers={false} +cuda::timed_event::timed_event( + timed_event && +) noexcept = default +``` + + + + + +inline constexpr explicit noexcept + + +```cpp showLineNumbers={false} +cuda::timed_event::timed_event( + ::cudaEvent_t __evnt +) noexcept +``` + + + + + + +```cpp showLineNumbers={false} +cuda::timed_event::timed_event( + const timed_event & +) = delete +``` + + + + + +--- + +## Assignment operators + +### operator= noexcept + + + + + +```cpp showLineNumbers={false} +timed_event & cuda::timed_event::operator=( + timed_event && +) noexcept = default +``` + + + + + + +```cpp showLineNumbers={false} +timed_event & cuda::timed_event::operator=( + const timed_event & +) = delete +``` + + + + + +--- + +## Methods + +### release inline noexcept nodiscard + +Retrieve the native `cudaEvent_t` handle and give up ownership. + + +```cpp showLineNumbers={false} +::cudaEvent_t cuda::timed_event::release() noexcept +``` + + + +The event object is in a moved-from state. + + +**Returns:** cudaEvent_t The native handle being held by the [`event`](/libcudacxx/api/cuda::event) object. + +### record inline const + +Records an event on the specified stream. + + +```cpp showLineNumbers={false} +void cuda::event_ref::record( + stream_ref __stream +) const +``` + + +**Throws:** `cuda_error` if the event record fails + +### sync inline const + +Synchronizes the event. + + +```cpp showLineNumbers={false} +void cuda::event_ref::sync() const +``` + + +**Throws:** `cuda_error` if waiting for the event fails + +### is_done inline const nodiscard + +Checks if all the work in the stream prior to the record of the event has completed. + +If is_done returns true, calling [sync()](/libcudacxx/api/cuda::event_ref::sync()) on this event will return immediately + + +```cpp showLineNumbers={false} +bool cuda::event_ref::is_done() const +``` + + +**Throws:** `cuda_error` if the event query fails + +### get inline const noexcept nodiscard + +Retrieve the native `cudaEvent_t` handle. + + +```cpp showLineNumbers={false} +::cudaEvent_t cuda::event_ref::get() const noexcept +``` + + +**Returns:** cudaEvent_t The native handle being held by the [event_ref](/libcudacxx/api/cuda::event_ref) object. + +### operator bool inline constexpr explicit const noexcept nodiscard + +Checks if the [`event_ref`](/libcudacxx/api/cuda::event_ref) is valid. + + +```cpp showLineNumbers={false} +cuda::event_ref::operator bool() const noexcept +``` + + +**Returns:** true if the [`event_ref`](/libcudacxx/api/cuda::event_ref) is valid, false otherwise. + +--- + +## Static methods + +### from_native_handle inline static noexcept nodiscard + + + + +Construct a `timed_event` object from a native `cudaEvent_t` handle. + + +```cpp showLineNumbers={false} +static timed_event cuda::timed_event::from_native_handle( + ::cudaEvent_t __evnt +) noexcept +``` + + + +The constructed `timed_event` object takes ownership of the native handle. + + +**Returns:** `timed_event` The constructed `timed_event` object + +**Parameters** + + +The native handle + + + + + +The following overloads are deleted to prevent misuse: + + +```cpp showLineNumbers={false} +static timed_event cuda::timed_event::from_native_handle(int) = delete; +static timed_event cuda::timed_event::from_native_handle(::cuda::std::nullptr_t) = delete; +``` + + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `value_type` | `::cudaEvent_t` | diff --git a/fern/cudapages/cuda/cuda/cuda/transform_input_output_iterator.mdx b/fern/cudapages/cuda/cuda/cuda/transform_input_output_iterator.mdx new file mode 100644 index 0000000..5cfdf48 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/transform_input_output_iterator.mdx @@ -0,0 +1,345 @@ +--- +title: "cuda::transform_input_output_iterator" +description: "[`transform_input_output_iterator`](/libcudacxx/api/cuda::transform_input_output_iterator) is a special kind of iterator which applies transform functions when reading from or writing to dereferenced values." +--- + +`transform_input_output_iterator` is a special kind of iterator which applies transform functions when reading from or writing to dereferenced values. + +This iterator is useful for algorithms that operate on a type that needs to be serialized/deserialized from values in another iterator, avoiding the need to materialize intermediate results in memory. This also enables the transform functions to be fused with the operations that read and write to the `transform_input_output_iterator`. + +The following code snippet demonstrates how to create a `transform_input_output_iterator` which performs different transformations when reading from and writing to the iterator. + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include +#include + + int main() + { + const size_t size = 4; + thrust::device_vector v(size); + + // Write 1.0f, 2.0f, 3.0f, 4.0f to vector + thrust::sequence(v.begin(), v.end(), 1); + + // Iterator that negates read values and writes squared values + auto iter = cuda::make_transform_input_output_iterator(v.begin(), + ::cuda::std::negate{}, thrust::square{}); + + // Iterator negates values when reading + std::cout << iter[0] << " "; // -1.0f; + std::cout << iter[1] << " "; // -2.0f; + std::cout << iter[2] << " "; // -3.0f; + std::cout << iter[3] << "\n"; // -4.0f; + + // Write 1.0f, 2.0f, 3.0f, 4.0f to iterator + thrust::sequence(iter, iter + size, 1); + + // Values were squared before writing to vector + std::cout << v[0] << " "; // 1.0f; + std::cout << v[1] << " "; // 4.0f; + std::cout << v[2] << " "; // 9.0f; + std::cout << v[3] << "\n"; // 16.0f; + + } +``` + + + + + + + + + + + + + + + + +--- + +## Constructors + +### transform_input_output_iterator inline constexpr noexcept + + + + +Default constructs a `transform_input_output_iterator` with a value initialized iterator and functors. + + +```cpp showLineNumbers={false} +template +cuda::transform_input_output_iterator<_InputFn, _OutputFn, _Iter>::transform_input_output_iterator() noexcept(::cuda::std::is_nothrow_default_constructible_v< _Iter2 > &&::cuda::std::is_nothrow_default_constructible_v< _InputFn2 > &&::cuda::std::is_nothrow_default_constructible_v< _OutputFn2 >) +``` + + + + + +Constructs a `transform_input_output_iterator` with base iterator, input functor and output functor. + + +```cpp showLineNumbers={false} +cuda::transform_input_output_iterator<_InputFn, _OutputFn, _Iter>::transform_input_output_iterator( + _Iter __iter, + _InputFn __input_func, + _OutputFn __output_func +) noexcept(::cuda::std::is_nothrow_move_constructible_v< _Iter > &&::cuda::std::is_nothrow_move_constructible_v< _InputFn > &&::cuda::std::is_nothrow_move_constructible_v< _OutputFn >) +``` + + +**Parameters** + + +The iterator to transform + + + +The input functor to apply to the iterator when reading + + + +The output functor to apply to the iterator when writing + + + + + +--- + +## Methods + +### base inline constexpr noexcept nodiscard + + + + +Extracts the stored base iterator. + + +```cpp showLineNumbers={false} +_Iter cuda::transform_input_output_iterator<_InputFn, _OutputFn, _Iter>::base() && noexcept(::cuda::std::is_nothrow_move_constructible_v< _Iter >) +``` + + + + + +const + +Returns a const reference to the base iterator stored. + + +```cpp showLineNumbers={false} +const _Iter & cuda::transform_input_output_iterator<_InputFn, _OutputFn, _Iter>::base() const & noexcept +``` + + + + + +### operator* inline constexpr noexcept nodiscard + + + + +Dereferences the `transform_input_output_iterator`. + +Returns a proxy that transforms values read from the stored iterator via the stored input functor and transforms assigned values via the output functor + + +```cpp showLineNumbers={false} +reference cuda::transform_input_output_iterator<_InputFn, _OutputFn, _Iter>::operator*() noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Iter >) +``` + + + + + +const + +Dereferences the `transform_input_output_iterator`. + +Returns a proxy that transforms values read from the stored iterator via the stored input functor and transforms assigned values via the output functor + + +```cpp showLineNumbers={false} +reference cuda::transform_input_output_iterator<_InputFn, _OutputFn, _Iter>::operator*() const noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Iter >) +``` + + + + + +### operator[] inline constexpr noexcept nodiscard + + + + +Subscripts the `transform_input_output_iterator`. + +Returns a proxy that transforms values read from the stored iterator adbanvd by a given number of elements via the stored input functor and transforms assigned values via the output functor + + +```cpp showLineNumbers={false} +template +reference cuda::transform_input_output_iterator<_InputFn, _OutputFn, _Iter>::operator[]( + difference_type __n +) noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Iter2 > &&noexcept(::cuda::std::declval< const _Iter2 & >()+__n)) +``` + + +**Parameters** + + +The number of elements to advance + + + + + +const + +Subscripts the `transform_input_output_iterator`. + +Returns a proxy that transforms values read from the stored iterator adbanvd by a given number of elements via the stored input functor and transforms assigned values via the output functor + + +```cpp showLineNumbers={false} +template +reference cuda::transform_input_output_iterator<_InputFn, _OutputFn, _Iter>::operator[]( + difference_type __n +) const noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Iter2 > &&noexcept(::cuda::std::declval< const _Iter2 & >()+__n)) +``` + + +**Parameters** + + +The number of elements to advance + + + + + +### operator++ inline constexpr noexcept + + + + +Increments the stored iterator. + + +```cpp showLineNumbers={false} +transform_input_output_iterator & cuda::transform_input_output_iterator<_InputFn, _OutputFn, _Iter>::operator++() noexcept(++::cuda::std::declval< _Iter & >()) +``` + + + + + +Increments the stored iterator. + + +```cpp showLineNumbers={false} +transform_input_output_iterator cuda::transform_input_output_iterator<_InputFn, _OutputFn, _Iter>::operator++( + int +) noexcept(noexcept(++::cuda::std::declval< _Iter & >()) &&::cuda::std::is_nothrow_copy_constructible_v< _Iter > &&::cuda::std::is_nothrow_copy_constructible_v< _InputFn > &&::cuda::std::is_nothrow_copy_constructible_v< _OutputFn >) +``` + + + + + +### operator-- inline constexpr noexcept + + + + +Decrements the stored iterator. + + +```cpp showLineNumbers={false} +template +transform_input_output_iterator & cuda::transform_input_output_iterator<_InputFn, _OutputFn, _Iter>::operator--() noexcept(--::cuda::std::declval< _Iter2 & >()) +``` + + + + + +Decrements the stored iterator. + + +```cpp showLineNumbers={false} +template +transform_input_output_iterator cuda::transform_input_output_iterator<_InputFn, _OutputFn, _Iter>::operator--( + int +) noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Iter > &&noexcept(--::cuda::std::declval< _Iter2 & >())) +``` + + + + + +### operator+= inline constexpr noexcept + +Advances the `transform_input_output_iterator` by a given number of elements. + + +```cpp showLineNumbers={false} +template +transform_input_output_iterator & cuda::transform_input_output_iterator<_InputFn, _OutputFn, _Iter>::operator+=( + difference_type __n +) noexcept(::cuda::std::declval< _Iter2 & >()+=__n) +``` + + +**Parameters** + + +The number of elements to advance + + +### operator-= inline constexpr noexcept + +Decrements the `transform_input_output_iterator` by a given number of elements. + + +```cpp showLineNumbers={false} +template +transform_input_output_iterator & cuda::transform_input_output_iterator<_InputFn, _OutputFn, _Iter>::operator-=( + difference_type __n +) noexcept(::cuda::std::declval< _Iter2 & >() -=__n) +``` + + +**Parameters** + + +The number of elements to decrement + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `iterator_concept` | `::cuda::std::conditional_t< ::cuda::std::__has_random_access_traversal< _Iter >, ::cuda::std::random_access_iterator_tag, ::cuda::std::conditional_t<::cuda::std::__has_bidirectional_traversal< _Iter >, ::cuda::std::bidirectional_iterator_tag, ::cuda::std::conditional_t<::cuda::std::__has_forward_traversal< _Iter >, ::cuda::std::forward_iterator_tag, ::cuda::std::output_iterator_tag > > >` | +| `iterator_category` | `::cuda::std::output_iterator_tag` | +| `difference_type` | `::cuda::std::iter_difference_t< _Iter >` | +| `value_type` | `::cuda::std::invoke_result_t< _InputFn &, ::cuda::std::iter_reference_t< _Iter > >` | +| `pointer` | `void` | +| `reference` | `__transform_input_output_proxy< _InputFn, _OutputFn, _Iter >` | diff --git a/fern/cudapages/cuda/cuda/cuda/transform_iterator.mdx b/fern/cudapages/cuda/cuda/cuda/transform_iterator.mdx new file mode 100644 index 0000000..972a6bb --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/transform_iterator.mdx @@ -0,0 +1,358 @@ +--- +title: "cuda::transform_iterator" +description: "[`transform_iterator`](/libcudacxx/api/cuda::transform_iterator) is an iterator which represents a pointer into a range of values after transformation by a functor." +--- + +`transform_iterator` is an iterator which represents a pointer into a range of values after transformation by a functor. + +This iterator is useful for creating a range filled with the result of applying an operation to another range without either explicitly storing it in memory, or explicitly executing the transformation. Using `transform_iterator` facilitates kernel fusion by deferring the execution of a transformation until the value is needed while saving both memory capacity and bandwidth. + +The following code snippet demonstrates how to create a `transform_iterator` which represents the result of `sqrtf` applied to the contents of a `thrust::device_vector`. + +This next example demonstrates how to use a `transform_iterator` with the `thrust::reduce` functor to compute the sum of squares of a sequence. We will create temporary `transform_iterators` utilising class template argument deduction avoid explicitly specifying their type: + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include +#include + +struct square_root +{ + __host__ __device__ + float operator()(float x) const + { + return sqrtf(x); + } +}; + +int main() +{ + thrust::device_vector v{1.0f, 4.0f, 9.0f, 16.0f}; + + using FloatIterator = thrust::device_vector::iterator; + + cuda::transform_iterator iter(v.begin(), square_root{}); + + *iter; // returns 1.0f + iter[0]; // returns 1.0f; + iter[1]; // returns 2.0f; + iter[2]; // returns 3.0f; + iter[3]; // returns 4.0f; + + // iter[4] is an out-of-bounds error +} +``` + +```cpp showLineNumbers={false} +#include +#include +#include +#include + +struct square +{ + __host__ __device__ + float operator()(float x) const + { + return x * x; + } +}; + +int main() +{ + // initialize a device array + thrust::device_vector v(4); + v[0] = 1.0f; + v[1] = 2.0f; + v[2] = 3.0f; + thrust::device_vector v{1.0f, 2.0f, 3.0f, 4.0f}; + thrust::reduce(cuda::transform_iterator{v.begin(), square{}}, + cuda::transform_iterator{v.end(), square{}}); + + std::cout << "sum of squares: " << sum_of_squares << std::endl; + return 0; +} +``` + + + + + + + + + + + + + +**Inherits from:** `cuda::__transform_iterator_category_base< _Fn, _Iter >` (public) + +--- + +## Constructors + +### transform_iterator inline constexpr noexcept + + + + +Default constructs a `transform_iterator` with a value initialized iterator and functor. + + +```cpp showLineNumbers={false} +template +cuda::transform_iterator<_Fn, _Iter>::transform_iterator() noexcept(::cuda::std::is_nothrow_default_constructible_v< _Iter2 > &&::cuda::std::is_nothrow_default_constructible_v< _Fn2 >) +``` + + + + + +Constructs a `transform_iterator` with a given iterator and functor. + + +```cpp showLineNumbers={false} +cuda::transform_iterator<_Fn, _Iter>::transform_iterator( + _Iter __iter, + _Fn __func +) noexcept(::cuda::std::is_nothrow_move_constructible_v< _Iter > &&::cuda::std::is_nothrow_move_constructible_v< _Fn >) +``` + + +**Parameters** + + +The iterator to transform + + + +The functor to apply to the iterator + + + + + +--- + +## Methods + +### base inline constexpr noexcept nodiscard + + + + +Extracts the stored iterator. + + +```cpp showLineNumbers={false} +_Iter cuda::transform_iterator<_Fn, _Iter>::base() && noexcept(::cuda::std::is_nothrow_move_constructible_v< _Iter >) +``` + + + + + +const + +Returns a const reference to the stored iterator. + + +```cpp showLineNumbers={false} +const _Iter & cuda::transform_iterator<_Fn, _Iter>::base() const & noexcept +``` + + + + + +### operator* inline constexpr noexcept nodiscard + + + + +Dereferences the stored iterator and applies the stored functor to the result. + + +```cpp showLineNumbers={false} +reference cuda::transform_iterator<_Fn, _Iter>::operator*() noexcept(::cuda::std::invoke(::cuda::std::declval< _Fn & >(), *::cuda::std::declval< _Iter & >())) +``` + + + + + +const + +Dereferences the stored iterator and applies the stored functor to the result. + + +```cpp showLineNumbers={false} +reference cuda::transform_iterator<_Fn, _Iter>::operator*() const noexcept(::cuda::std::invoke(::cuda::std::declval< const _Fn & >(), *::cuda::std::declval< const _Iter2 & >())) +``` + + + + + +### operator[] inline constexpr noexcept nodiscard + + + + +Subscripts the stored iterator by a number of elements and applies the stored functor to the result. + + +```cpp showLineNumbers={false} +reference cuda::transform_iterator<_Fn, _Iter>::operator[]( + difference_type __n +) noexcept(__transform_iterator_nothrow_subscript< _Fn, _Iter2 >) +``` + + +**Parameters** + + +The number of elements to advance by + + + + + +const + +Subscripts the stored iterator by a number of elements and applies the stored functor to the result. + + +```cpp showLineNumbers={false} +reference cuda::transform_iterator<_Fn, _Iter>::operator[]( + difference_type __n +) const noexcept(__transform_iterator_nothrow_subscript< const _Fn, _Iter2 >) +``` + + +**Parameters** + + +The number of elements to advance by + + + + + +### operator++ inline constexpr noexcept + + + + +Increments the stored iterator. + + +```cpp showLineNumbers={false} +transform_iterator & cuda::transform_iterator<_Fn, _Iter>::operator++() noexcept(++::cuda::std::declval< _Iter & >()) +``` + + + + + +Increments the stored iterator. + + +```cpp showLineNumbers={false} +auto cuda::transform_iterator<_Fn, _Iter>::operator++( + int +) noexcept(++::cuda::std::declval< _Iter & >()) +``` + + + + + +### operator-- inline constexpr noexcept + + + + +Decrements the stored iterator. + + +```cpp showLineNumbers={false} +template +transform_iterator & cuda::transform_iterator<_Fn, _Iter>::operator--() noexcept(--::cuda::std::declval< _Iter2 & >()) +``` + + + + + +Decrements the stored iterator. + + +```cpp showLineNumbers={false} +template +transform_iterator cuda::transform_iterator<_Fn, _Iter>::operator--( + int +) noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Iter > &&noexcept(--::cuda::std::declval< _Iter2 & >())) +``` + + + + + +### operator+= inline constexpr noexcept + +Increments the `transform_iterator` by a given number of elements. + + +```cpp showLineNumbers={false} +template +transform_iterator & cuda::transform_iterator<_Fn, _Iter>::operator+=( + difference_type __n +) noexcept(::cuda::std::declval< _Iter2 & >()+=__n) +``` + + +**Parameters** + + +The number of elements to increment + + +### operator-= inline constexpr noexcept + +Decrements the `transform_iterator` by a given number of elements. + + +```cpp showLineNumbers={false} +template +transform_iterator & cuda::transform_iterator<_Fn, _Iter>::operator-=( + difference_type __n +) noexcept(::cuda::std::declval< _Iter2 & >() -=__n) +``` + + +**Parameters** + + +The number of elements to decrement + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `iterator_concept` | `::cuda::std::conditional_t< ::cuda::std::__has_random_access_traversal< _Iter >, ::cuda::std::random_access_iterator_tag, ::cuda::std::conditional_t<::cuda::std::__has_bidirectional_traversal< _Iter >, ::cuda::std::bidirectional_iterator_tag, ::cuda::std::conditional_t<::cuda::std::__has_forward_traversal< _Iter >, ::cuda::std::forward_iterator_tag, ::cuda::std::input_iterator_tag > > >` | +| `value_type` | `::cuda::std::remove_cvref_t<::cuda::std::invoke_result_t< _Fn &, ::cuda::std::iter_reference_t< _Iter > > >` | +| `difference_type` | `::cuda::std::iter_difference_t< _Iter >` | +| `reference` | `::cuda::std::invoke_result_t< _Fn &, ::cuda::std::iter_reference_t< _Iter > >` | +| `pointer` | `void` | diff --git a/fern/cudapages/cuda/cuda/cuda/transform_output_iterator.mdx b/fern/cudapages/cuda/cuda/cuda/transform_output_iterator.mdx new file mode 100644 index 0000000..5c7b846 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/transform_output_iterator.mdx @@ -0,0 +1,331 @@ +--- +title: "cuda::transform_output_iterator" +description: "[`transform_output_iterator`](/libcudacxx/api/cuda::transform_output_iterator) is a special kind of output iterator which transforms a value written upon dereference." +--- + +`transform_output_iterator` is a special kind of output iterator which transforms a value written upon dereference. + +This iterator is useful for transforming an output from algorithms without explicitly storing the intermediate result in the memory and applying subsequent transformation, thereby avoiding wasting memory capacity and bandwidth. Using `transform_output_iterator` facilitates kernel fusion by deferring execution of transformation until the value is written while saving both memory capacity and bandwidth. + +The following code snippet demonstrated how to create a `transform_output_iterator` which applies `sqrtf` to the assigning value. + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include +#include + +struct square_root +{ + __host__ __device__ + float operator()(float x) const + { + return cuda::std::sqrtf(x); + } +}; + +int main() +{ + thrust::device_vector v(4); + cuda::transform_output_iterator iter(v.begin(), square_root()); + + iter[0] = 1.0f; // stores sqrtf( 1.0f) + iter[1] = 4.0f; // stores sqrtf( 4.0f) + iter[2] = 9.0f; // stores sqrtf( 9.0f) + iter[3] = 16.0f; // stores sqrtf(16.0f) + // iter[4] is an out-of-bounds error + + v[0]; // returns 1.0f; + v[1]; // returns 2.0f; + v[2]; // returns 3.0f; + v[3]; // returns 4.0f; + +} +``` + + + + + + + + + + + + + +--- + +## Constructors + +### transform_output_iterator inline constexpr noexcept + + + + +Default constructs a `transform_output_iterator` with a value initialized iterator and functor. + + +```cpp showLineNumbers={false} +template +cuda::transform_output_iterator<_Fn, _Iter>::transform_output_iterator() noexcept(::cuda::std::is_nothrow_default_constructible_v< _Iter2 > &&::cuda::std::is_nothrow_default_constructible_v< _Fn2 >) +``` + + + + + +Constructs a `transform_output_iterator` with a given iterator and output functor. + + +```cpp showLineNumbers={false} +cuda::transform_output_iterator<_Fn, _Iter>::transform_output_iterator( + _Iter __iter, + _Fn __func +) noexcept(::cuda::std::is_nothrow_move_constructible_v< _Iter > &&::cuda::std::is_nothrow_move_constructible_v< _Fn >) +``` + + +**Parameters** + + +The iterator to transform + + + +The output function to apply to the iterator on assignment + + + + + +--- + +## Methods + +### base inline constexpr noexcept nodiscard + + + + +Extracts the stored iterator. + + +```cpp showLineNumbers={false} +_Iter cuda::transform_output_iterator<_Fn, _Iter>::base() && noexcept(::cuda::std::is_nothrow_move_constructible_v< _Iter >) +``` + + + + + +const + +Returns a const reference to the stored iterator. + + +```cpp showLineNumbers={false} +const _Iter & cuda::transform_output_iterator<_Fn, _Iter>::base() const & noexcept +``` + + + + + +### operator* inline constexpr noexcept nodiscard + + + + +Returns a proxy that transforms the input upon assignment. + + +```cpp showLineNumbers={false} +auto cuda::transform_output_iterator<_Fn, _Iter>::operator*() noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Iter >) +``` + + + + + +const + +Returns a proxy that transforms the input upon assignment. + + +```cpp showLineNumbers={false} +auto cuda::transform_output_iterator<_Fn, _Iter>::operator*() const noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Iter >) +``` + + + + + +### operator[] inline constexpr noexcept nodiscard + + + + +Subscripts the `transform_output_iterator`. + + +```cpp showLineNumbers={false} +template +auto cuda::transform_output_iterator<_Fn, _Iter>::operator[]( + difference_type __n +) noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Iter2 > &&noexcept(::cuda::std::declval< _Iter2 & >()+__n)) +``` + + +**Returns:** A proxy that transforms the input upon assignment storing the current iterator advanced by a given + +**Parameters** + + +The number of elements to advance by + + + + + +const + +Subscripts the `transform_output_iterator`. + + +```cpp showLineNumbers={false} +template +auto cuda::transform_output_iterator<_Fn, _Iter>::operator[]( + difference_type __n +) const noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Iter2 > &&noexcept(::cuda::std::declval< const _Iter2 & >()+__n)) +``` + + +**Returns:** A proxy that transforms the input upon assignment storing the current iterator advanced by a given + +**Parameters** + + +The number of elements to advance by + + + + + +### operator++ inline constexpr noexcept + + + + +Increments the stored iterator. + + +```cpp showLineNumbers={false} +transform_output_iterator & cuda::transform_output_iterator<_Fn, _Iter>::operator++() noexcept(++::cuda::std::declval< _Iter & >()) +``` + + + + + +Increments the stored iterator. + + +```cpp showLineNumbers={false} +auto cuda::transform_output_iterator<_Fn, _Iter>::operator++( + int +) noexcept(++::cuda::std::declval< _Iter & >()) +``` + + + + + +### operator-- inline constexpr noexcept + + + + +Decrements the stored iterator. + + +```cpp showLineNumbers={false} +template +transform_output_iterator & cuda::transform_output_iterator<_Fn, _Iter>::operator--() noexcept(--::cuda::std::declval< _Iter2 & >()) +``` + + + + + +Decrements the stored iterator. + + +```cpp showLineNumbers={false} +template +transform_output_iterator cuda::transform_output_iterator<_Fn, _Iter>::operator--( + int +) noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Iter > &&noexcept(--::cuda::std::declval< _Iter2 & >())) +``` + + + + + +### operator+= inline constexpr noexcept + +Increments the `transform_output_iterator` by a given number of elements. + + +```cpp showLineNumbers={false} +template +transform_output_iterator & cuda::transform_output_iterator<_Fn, _Iter>::operator+=( + difference_type __n +) noexcept(::cuda::std::declval< _Iter2 & >()+=__n) +``` + + +**Parameters** + + +The number of elements to increment + + +### operator-= inline constexpr noexcept + +Decrements the `transform_output_iterator` by a given number of elements. + + +```cpp showLineNumbers={false} +template +transform_output_iterator & cuda::transform_output_iterator<_Fn, _Iter>::operator-=( + difference_type __n +) noexcept(::cuda::std::declval< _Iter2 & >() -=__n) +``` + + +**Parameters** + + +The number of elements to decrement + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `iterator_concept` | `::cuda::std::conditional_t< ::cuda::std::__has_random_access_traversal< _Iter >, ::cuda::std::random_access_iterator_tag, ::cuda::std::conditional_t<::cuda::std::__has_bidirectional_traversal< _Iter >, ::cuda::std::bidirectional_iterator_tag, ::cuda::std::conditional_t<::cuda::std::__has_forward_traversal< _Iter >, ::cuda::std::forward_iterator_tag, ::cuda::std::output_iterator_tag > > >` | +| `iterator_category` | `::cuda::std::output_iterator_tag` | +| `difference_type` | `::cuda::std::iter_difference_t< _Iter >` | +| `value_type` | `void` | +| `pointer` | `void` | +| `reference` | `void` | diff --git a/fern/cudapages/cuda/cuda/cuda/zip_function.mdx b/fern/cudapages/cuda/cuda/cuda/zip_function.mdx new file mode 100644 index 0000000..ab5139d --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/zip_function.mdx @@ -0,0 +1,67 @@ +--- +title: "cuda::zip_function" +description: "Adaptor that transforms a functor taking arguments of types `Ts`... into one accepting a `tuple`." +--- + +Adaptor that transforms a functor taking arguments of types `Ts`... into one accepting a `tuple`. + +```cpp showLineNumbers={false} +#include +``` + + + + + +The functor to wrap + + + + + +--- + +## Constructors + +### zip_function inline constexpr noexcept + + + + +default construct a `zip_function` + + +```cpp showLineNumbers={false} +template +cuda::zip_function<_Fn>::zip_function() noexcept(::cuda::std::is_nothrow_default_constructible_v< _Fn2 >) +``` + + + + + +construct a `zip_function` from a functor + + +```cpp showLineNumbers={false} +cuda::zip_function<_Fn>::zip_function( + const _Fn &__fun +) noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Fn >) +``` + + + + + +construct a `zip_function` from a functor + + +```cpp showLineNumbers={false} +cuda::zip_function<_Fn>::zip_function( + _Fn &&__fun +) noexcept(::cuda::std::is_nothrow_move_constructible_v< _Fn >) +``` + + + + diff --git a/fern/cudapages/cuda/cuda/cuda/zip_iterator.mdx b/fern/cudapages/cuda/cuda/cuda/zip_iterator.mdx new file mode 100644 index 0000000..8d0cbf0 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/zip_iterator.mdx @@ -0,0 +1,332 @@ +--- +title: "cuda::zip_iterator" +description: "[`zip_iterator`](/libcudacxx/api/cuda::zip_iterator) is an iterator which represents a `tuple` of iterators." +--- + +`zip_iterator` is an iterator which represents a `tuple` of iterators. + +This iterator is useful for creating a virtual array of structures while achieving the same performance and bandwidth as the structure of arrays idiom. `zip_iterator` also facilitates kernel fusion by providing a convenient means of amortizing the execution of the same operation over multiple ranges. + +The following code snippet demonstrates how to create a `zip_iterator` which represents the result of "zipping" multiple ranges together. + +This example shows how to use `zip_iterator` to copy multiple ranges with a single call to `thrust::copy`. + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include +#include + +thrust::device_vector int_v{0, 1, 2}; +thrust::device_vector float_v{0.0f, 1.0f, 2.0f}; +thrust::device_vector char_v{'a', 'b', 'c'}; + +cuda::zip_iterator iter{int_v.begin(), float_v.begin(), char_v.begin()}; + +*iter; // returns (0, 0.0f, 'a') +iter[0]; // returns (0, 0.0f, 'a') +iter[1]; // returns (1, 1.0f, 'b') +iter[2]; // returns (2, 2.0f, 'c') + +cuda::std::get<0>(iter[2]); // returns 2 +cuda::std::get<1>(iter[0]); // returns 0.0f +cuda::std::get<2>(iter[1]); // returns 'b' + +// iter[3] is an out-of-bounds error +``` + +```cpp showLineNumbers={false} +#include +#include + +int main() +{ + thrust::device_vector int_in{0, 1, 2}, int_out(3); + thrust::device_vector float_in{0.0f, 10.0f, 20.0f}, float_out(3); + + thrust::copy(cuda::zip_iterator{int_in.begin(), float_in.begin()}, + cuda::zip_iterator{int_in.end(), float_in.end()}, + cuda::zip_iterator{int_out.begin(),float_out.begin()}); + + // int_out is now [0, 1, 2] + // float_out is now [0.0f, 10.0f, 20.0f] + + return 0; +} +``` + + + + + + + + + + +**Inherits from:** `__zv_iter_category_base< _Iterators... >` (public) + +--- + +## Constructors + +### zip_iterator + + + + +Default-constructs a `zip_iterator` by defaulting all stored iterators. + + +```cpp showLineNumbers={false} +cuda::zip_iterator<_Iterators>::zip_iterator() = default +``` + + + + + +inline constexpr explicit + +Constructs a `zip_iterator` from a tuple of iterators. + + +```cpp showLineNumbers={false} +cuda::zip_iterator<_Iterators>::zip_iterator( + ::cuda::std::tuple<_Iterators...> __iters +) +``` + + +**Parameters** + + +A tuple of iterators + + + + + +inline constexpr explicit + +Constructs a `zip_iterator` from a tuple of iterators. + + +```cpp showLineNumbers={false} +template +cuda::zip_iterator<_Iterators>::zip_iterator( + ::cuda::std::tuple<_Iterators...> __iters +) +``` + + +**Parameters** + + +A tuple of iterators + + + + + +inline constexpr explicit + +Constructs a `zip_iterator` from variadic set of iterators. + + +```cpp showLineNumbers={false} +cuda::zip_iterator<_Iterators>::zip_iterator( + _Iterators... __iters +) +``` + + +**Parameters** + + +The input iterators + + + + + +inline constexpr + +Converts a different `zip_iterator`. + + +```cpp showLineNumbers={false} +template +cuda::zip_iterator<_Iterators>::zip_iterator( + zip_iterator<_OtherIters...> __iter +) +``` + + +**Parameters** + + +The other `zip_iterator` + + + + + +--- + +## Methods + +### operator* inline constexpr const noexcept nodiscard + +Dereferences the `zip_iterator`. + + +```cpp showLineNumbers={false} +auto cuda::zip_iterator<_Iterators>::operator*() const noexcept(::cuda::std::apply(__zip_op_star{}, __current_)) +``` + + +**Returns:** A tuple of references obtained by referencing every stored iterator + +### operator[] inline constexpr const noexcept + +Subscripts the `zip_iterator` with an offset. + + +```cpp showLineNumbers={false} +template +auto cuda::zip_iterator<_Iterators>::operator[]( + difference_type __n +) const noexcept(::cuda::std::apply(__zip_op_index{__n}, __current_)) +``` + + +**Returns:** A tuple of references obtained by subscripting every stored iterator + +**Parameters** + + +The additional offset + + +### operator++ inline constexpr + + + + +noexcept + +Increments all stored iterators. + + +```cpp showLineNumbers={false} +zip_iterator & cuda::zip_iterator<_Iterators>::operator++() noexcept(::cuda::std::apply(__zip_op_increment{}, __current_)) +``` + + + + + +Increments all stored iterators. + + +```cpp showLineNumbers={false} +auto cuda::zip_iterator<_Iterators>::operator++( + int +) +``` + + +**Returns:** A copy of the original `zip_iterator` if possible + + + + +### operator-- inline constexpr + + + + +noexcept + +Decrements all stored iterators. + + +```cpp showLineNumbers={false} +template +zip_iterator & cuda::zip_iterator<_Iterators>::operator--() noexcept(::cuda::std::apply(__zip_op_decrement{}, __current_)) +``` + + + + + +Decrements all stored iterators. + + +```cpp showLineNumbers={false} +template +zip_iterator cuda::zip_iterator<_Iterators>::operator--( + int +) +``` + + + + + +### operator+= inline constexpr noexcept + +Increments all stored iterators by a given number of elements. + + +```cpp showLineNumbers={false} +template +zip_iterator & cuda::zip_iterator<_Iterators>::operator+=( + difference_type __n +) noexcept(::cuda::std::apply(__zip_op_pe{__n}, __current_)) +``` + + +**Parameters** + + +The number of elements to increment + + +### operator-= inline constexpr noexcept + +Decrements all stored iterators by a given number of elements. + + +```cpp showLineNumbers={false} +template +zip_iterator & cuda::zip_iterator<_Iterators>::operator-=( + difference_type __n +) noexcept(::cuda::std::apply(__zip_op_me{__n}, __current_)) +``` + + +**Parameters** + + +The number of elements to decrement + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `iterator_concept` | `decltype(__get_zip_iterator_concept< _Iterators... >())` | +| `value_type` | `::cuda::std::tuple< __zip_maybe_proxy_value_type_t< _Iterators >... >` | +| `reference` | `::cuda::std::tuple< __zip_maybe_proxy_reference_t< _Iterators >... >` | +| `difference_type` | `::cuda::std::common_type_t<::cuda::std::iter_difference_t< _Iterators >... >` | +| `pointer` | `void` | diff --git a/fern/cudapages/cuda/cuda/cuda/zip_transform_iterator.mdx b/fern/cudapages/cuda/cuda/cuda/zip_transform_iterator.mdx new file mode 100644 index 0000000..0345da2 --- /dev/null +++ b/fern/cudapages/cuda/cuda/cuda/zip_transform_iterator.mdx @@ -0,0 +1,307 @@ +--- +title: "cuda::zip_transform_iterator" +description: "[`zip_transform_iterator`](/libcudacxx/api/cuda::zip_transform_iterator) is an iterator which represents the result of a transformation of a set of sequences with a given function." +--- + +`zip_transform_iterator` is an iterator which represents the result of a transformation of a set of sequences with a given function. + +This iterator is useful for creating a range filled with the result of applying an operation to another range without either explicitly storing it in memory, or explicitly executing the transformation. Using `zip_transform_iterator` facilitates kernel fusion by deferring the execution of a transformation until the value is needed while saving both memory capacity and bandwidth. + +`zip_transform_iterator` is morally equivalent to a combination of [transform_iterator](/libcudacxx/api/cuda::transform_iterator) and [zip_iterator](/libcudacxx/api/cuda::zip_iterator) + +`zip_transform_iterator` has the additional benefit that it does not require an artificial [`zip_function`](/libcudacxx/api/cuda::zip_function) to work and more importantly does not need to materialize the result of dereferencing the stored iterators when passing them to the stored function. + +The following code snippet demonstrates how to create a `zip_transform_iterator` which represents the result of "zipping" multiple ranges together. + +This example shows how to use `zip_transform_iterator` to copy multiple ranges with a single call to `thrust::copy`. + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +template +using zip_transform_iterator = cuda::transform_iterator, cuda::zip_function>; +``` + +```cpp showLineNumbers={false} +#include +#include + +struct SumArgs { + __host__ __device__ float operator()(float a, float b, float c) const noexcept { + return a + b + c; + } +}; + +thrust::device_vector A{0.f, 1.f, 2.f}; +thrust::device_vector B{1.f, 2.f, 3.f}; +thrust::device_vector C{2.f, 3.f, 4.f}; + +cuda::zip_transform_iterator iter{SumArgs{}, A.begin(), B.begin(), C.begin()}; + +*iter; // returns (3.f) +iter[0]; // returns (3.f) +iter[1]; // returns (6.f) +iter[2]; // returns (9.f) +// iter[3] is an out-of-bounds error +``` + +```cpp showLineNumbers={false} +#include +#include + +int main() +{ + struct SumArgs { + __host__ __device__ float operator()(float a, float b, float c) const noexcept { + return a + b + c; + } + }; + + thrust::device_vector A{0.f, 1.f, 2.f}; + thrust::device_vector B{1.f, 2.f, 3.f}; + thrust::device_vector C{2.f, 3.f, 4.f}; + thrust::device_vector out(3); + + cuda::zip_transform_iterator iter{SumArgs{}, A.begin(), B.begin(), C.begin()} + thrust::copy(iter, iter + 3, out.begin()); + + // out is now [3.0f, 6.0f, 9.0f] + + return 0; +} +``` + + + + + + + + + + + + + +--- + +## Constructors + +### zip_transform_iterator inline constexpr + + + + +noexcept + +Default-constructs a `zip_transform_iterator` by value-initializing the functor and all stored iterators. + + +```cpp showLineNumbers={false} +template +cuda::zip_transform_iterator<_Fn, _Iterators>::zip_transform_iterator() noexcept(::cuda::std::is_nothrow_default_constructible_v< _Fn2 > &&__zip_iter_constraints< _Iterators... >::__all_nothrow_default_constructible) +``` + + + + + +explicit + +Constructs a `zip_transform_iterator` from a tuple of iterators. + + +```cpp showLineNumbers={false} +cuda::zip_transform_iterator<_Fn, _Iterators>::zip_transform_iterator( + _Fn __fun, + ::cuda::std::tuple<_Iterators...> __iters +) +``` + + +**Parameters** + + +A tuple or pair of iterators + + + + + +explicit + +Constructs a `zip_transform_iterator` from variadic set of iterators. + + +```cpp showLineNumbers={false} +cuda::zip_transform_iterator<_Fn, _Iterators>::zip_transform_iterator( + _Fn __fun, + _Iterators... __iters +) +``` + + +**Parameters** + + +The input iterators + + + + + +--- + +## Methods + +### operator* inline constexpr const noexcept nodiscard + +Invokes the stored function with the result of dereferencing the stored iterators. + + +```cpp showLineNumbers={false} +reference cuda::zip_transform_iterator<_Fn, _Iterators>::operator*() const noexcept(::cuda::std::is_nothrow_invocable_v< _Fn &, ::cuda::std::iter_reference_t< const _Iterators >... >) +``` + + +### operator[] inline constexpr const noexcept + +Invokes the stored function with the result of dereferencing the stored iterators advanced by an offset. + + +```cpp showLineNumbers={false} +template +reference cuda::zip_transform_iterator<_Fn, _Iterators>::operator[]( + difference_type __n +) const noexcept(::cuda::std::apply(__zip_transform_op_subscript{__n, ::cuda::std::declval< _Fn & >()}, ::cuda::std::declval< const ::cuda::std::tuple< _Iterators... > & >())) +``` + + +**Parameters** + + +The additional offset + + +### operator++ inline constexpr + + + + +noexcept + +Increments all stored iterators. + + +```cpp showLineNumbers={false} +zip_transform_iterator & cuda::zip_transform_iterator<_Fn, _Iterators>::operator++() noexcept(::cuda::std::apply(__zip_op_increment{}, ::cuda::std::declval<::cuda::std::tuple< _Iterators... > & >())) +``` + + + + + +Increments all stored iterators. + + +```cpp showLineNumbers={false} +auto cuda::zip_transform_iterator<_Fn, _Iterators>::operator++( + int +) +``` + + +**Returns:** A copy of the original `zip_transform_iterator` if possible + + + + +### operator-- inline constexpr + + + + +noexcept + +Decrements all stored iterators. + + +```cpp showLineNumbers={false} +template +zip_transform_iterator & cuda::zip_transform_iterator<_Fn, _Iterators>::operator--() noexcept(::cuda::std::apply(__zip_op_decrement{}, ::cuda::std::declval<::cuda::std::tuple< _Iterators... > & >())) +``` + + + + + +Decrements all stored iterators. + + +```cpp showLineNumbers={false} +template +zip_transform_iterator cuda::zip_transform_iterator<_Fn, _Iterators>::operator--( + int +) +``` + + + + + +### operator+= inline constexpr noexcept + +Increments all stored iterators by a given number of elements. + + +```cpp showLineNumbers={false} +template +zip_transform_iterator & cuda::zip_transform_iterator<_Fn, _Iterators>::operator+=( + difference_type __n +) noexcept(::cuda::std::apply(__zip_op_pe{__n}, ::cuda::std::declval<::cuda::std::tuple< _Iterators... > & >())) +``` + + +**Parameters** + + +The number of elements to increment + + +### operator-= inline constexpr noexcept + +Decrements all stored iterators by a given number of elements. + + +```cpp showLineNumbers={false} +template +zip_transform_iterator & cuda::zip_transform_iterator<_Fn, _Iterators>::operator-=( + difference_type __n +) noexcept(::cuda::std::apply(__zip_op_me{__n}, ::cuda::std::declval<::cuda::std::tuple< _Iterators... > & >())) +``` + + +**Parameters** + + +The number of elements to decrement + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `iterator_concept` | `decltype(::cuda::__get_zip_iterator_concept< _Iterators... >())` | +| `iterator_category` | `decltype(::cuda::__get_zip_transform_iterator_category< _Fn, _Iterators... >())` | +| `difference_type` | `::cuda::std::common_type_t<::cuda::std::iter_difference_t< _Iterators >... >` | +| `value_type` | `::cuda::std::remove_cvref_t<::cuda::std::invoke_result_t< _Fn &, ::cuda::std::iter_reference_t< _Iterators >... > >` | +| `reference` | `::cuda::std::invoke_result_t< _Fn &, ::cuda::std::iter_reference_t< _Iterators >... >` | +| `pointer` | `void` | diff --git a/fern/cudapages/thrust/thrust/thrust/allocator_delete.mdx b/fern/cudapages/thrust/thrust/thrust/allocator_delete.mdx new file mode 100644 index 0000000..38ce376 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/allocator_delete.mdx @@ -0,0 +1,171 @@ +--- +title: thrust::allocator_delete +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + + + + +This class is marked final. + +--- + +## Constructors + +### allocator_delete inline noexcept + + + + + +```cpp showLineNumbers={false} +template +thrust::allocator_delete::allocator_delete( + UAllocator &&other +) noexcept +``` + + + + + + +```cpp showLineNumbers={false} +template +thrust::allocator_delete::allocator_delete( + allocator_delete const &other +) noexcept +``` + + + + + + +```cpp showLineNumbers={false} +template +thrust::allocator_delete::allocator_delete( + allocator_delete &&other +) noexcept +``` + + + + + +--- + +## Assignment operators + +### operator= inline noexcept + + + + + +```cpp showLineNumbers={false} +template +allocator_delete & thrust::allocator_delete::operator=( + allocator_delete const &other +) noexcept +``` + + + + + + +```cpp showLineNumbers={false} +template +allocator_delete & thrust::allocator_delete::operator=( + allocator_delete &&other +) noexcept +``` + + + + + +--- + +## Methods + +### operator() inline + + +```cpp showLineNumbers={false} +void thrust::allocator_delete::operator()( + pointer p +) +``` + + +### get_allocator inline noexcept + + + + + +```cpp showLineNumbers={false} +allocator_type & thrust::allocator_delete::get_allocator() noexcept +``` + + + + + +const + + +```cpp showLineNumbers={false} +allocator_type const & thrust::allocator_delete::get_allocator() const noexcept +``` + + + + + +### swap inline noexcept + + +```cpp showLineNumbers={false} +void thrust::allocator_delete::swap( + allocator_delete &other +) noexcept +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `allocator_type` | `typename std::remove_cv< typename std::remove_reference< Allocator >::type >::type::template rebind< T >::other` | +| `pointer` | `typename ::cuda::std::allocator_traits< allocator_type >::pointer` | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `alloc_` | `allocator_type` | | diff --git a/fern/cudapages/thrust/thrust/thrust/array_allocator_delete.mdx b/fern/cudapages/thrust/thrust/thrust/array_allocator_delete.mdx new file mode 100644 index 0000000..32797b8 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/array_allocator_delete.mdx @@ -0,0 +1,173 @@ +--- +title: thrust::array_allocator_delete +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + + + + +This class is marked final. + +--- + +## Constructors + +### array_allocator_delete inline noexcept + + + + + +```cpp showLineNumbers={false} +template +thrust::array_allocator_delete::array_allocator_delete( + UAllocator &&other, + std::size_t n +) noexcept +``` + + + + + + +```cpp showLineNumbers={false} +template +thrust::array_allocator_delete::array_allocator_delete( + array_allocator_delete const &other +) noexcept +``` + + + + + + +```cpp showLineNumbers={false} +template +thrust::array_allocator_delete::array_allocator_delete( + array_allocator_delete &&other +) noexcept +``` + + + + + +--- + +## Assignment operators + +### operator= inline noexcept + + + + + +```cpp showLineNumbers={false} +template +array_allocator_delete & thrust::array_allocator_delete::operator=( + array_allocator_delete const &other +) noexcept +``` + + + + + + +```cpp showLineNumbers={false} +template +array_allocator_delete & thrust::array_allocator_delete::operator=( + array_allocator_delete &&other +) noexcept +``` + + + + + +--- + +## Methods + +### operator() inline + + +```cpp showLineNumbers={false} +void thrust::array_allocator_delete::operator()( + pointer p +) +``` + + +### get_allocator inline noexcept + + + + + +```cpp showLineNumbers={false} +allocator_type & thrust::array_allocator_delete::get_allocator() noexcept +``` + + + + + +const + + +```cpp showLineNumbers={false} +allocator_type const & thrust::array_allocator_delete::get_allocator() const noexcept +``` + + + + + +### swap inline noexcept + + +```cpp showLineNumbers={false} +void thrust::array_allocator_delete::swap( + array_allocator_delete &other +) noexcept +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `allocator_type` | `typename std::remove_cv< typename std::remove_reference< Allocator >::type >::type::template rebind< T >::other` | +| `pointer` | `typename ::cuda::std::allocator_traits< allocator_type >::pointer` | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `alloc_` | `allocator_type` | | +| `count_` | `std::size_t` | | diff --git a/fern/cudapages/thrust/thrust/thrust/bidirectional_device_iterator_tag.mdx b/fern/cudapages/thrust/thrust/thrust/bidirectional_device_iterator_tag.mdx new file mode 100644 index 0000000..89ecab2 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/bidirectional_device_iterator_tag.mdx @@ -0,0 +1,17 @@ +--- +title: thrust::bidirectional_device_iterator_tag +description: "[`bidirectional_device_iterator_tag`](/library/api/thrust::bidirectional_device_iterator_tag) is an empty class: it has no member functions, member variables, or nested types." +--- + +`bidirectional_device_iterator_tag` is an empty class: it has no member functions, member variables, or nested types. + +It is used solely as a "tag": a representation of the Bidirectional Device Iterator concept within the C++ type system. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +[https://en.cppreference.com/w/cpp/iterator/iterator_tags](https://en.cppreference.com/w/cpp/iterator/iterator_tags) iterator_traits, [input_device_iterator_tag](/library/api/thrust::input_device_iterator_tag), [output_device_iterator_tag](/library/api/thrust::output_device_iterator_tag), [forward_device_iterator_tag](/library/api/thrust::forward_device_iterator_tag), [random_access_device_iterator_tag](/library/api/thrust::random_access_device_iterator_tag), input_host_iterator_tag, output_host_iterator_tag, forward_host_iterator_tag, bidirectional_host_iterator_tag, random_access_host_iterator_tag + +**Inherits from:** `detail::iterator_category_with_system_and_traversal<::cuda::std::bidirectional_iterator_tag, device_system_tag, bidirectional_traversal_tag >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/bidirectional_traversal_tag.mdx b/fern/cudapages/thrust/thrust/thrust/bidirectional_traversal_tag.mdx new file mode 100644 index 0000000..2cd4b2b --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/bidirectional_traversal_tag.mdx @@ -0,0 +1,12 @@ +--- +title: thrust::bidirectional_traversal_tag +description: "Tag type for iterators allowing bidirectional traversal." +--- + +Tag type for iterators allowing bidirectional traversal. + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `thrust::forward_traversal_tag` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/compile_time_value.mdx b/fern/cudapages/thrust/thrust/thrust/compile_time_value.mdx new file mode 100644 index 0000000..755c66f --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/compile_time_value.mdx @@ -0,0 +1,27 @@ +--- +title: thrust::compile_time_value +description: "Holds a compile-time value." +--- + +Holds a compile-time value. + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `value` static constexpr | `decltype(Value)` | | diff --git a/fern/cudapages/thrust/thrust/thrust/complex.mdx b/fern/cudapages/thrust/thrust/thrust/complex.mdx new file mode 100644 index 0000000..7ccc279 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/complex.mdx @@ -0,0 +1,710 @@ +--- +title: thrust::complex +description: "`complex` is the Thrust equivalent to `std::complex`." +--- + +`complex` is the Thrust equivalent to `std::complex`. + +It is functionally identical to it, but can also be used in device code which `std::complex` currently cannot. + +```cpp showLineNumbers={false} +#include +``` + + + + + +The type used to hold the real and imaginary parts. Should be `float` or `double`. Others types are not supported. embed:rst:leading-asterisk +* .. versionadded:: 2.2.0 +* + + + + + +--- + +## Constructors + +### complex + + + + +Construct a complex number with an imaginary part of 0. + + +```cpp showLineNumbers={false} +thrust::complex::complex( + const T &re +) +``` + + +**Parameters** + + +The real part of the number. + + + + + +Construct a complex number from its real and imaginary parts. + + +```cpp showLineNumbers={false} +thrust::complex::complex( + const T &re, + const T &im +) +``` + + +**Parameters** + + +The real part of the number. + + + +The imaginary part of the number. + + + + + +Default construct a complex number. + + +```cpp showLineNumbers={false} +thrust::complex::complex() = default +``` + + + + + +This copy constructor copies from a `complex` with a type that is convertible to this `complex's` [`value_type`](/library/api/thrust::complex::value_type). + + +```cpp showLineNumbers={false} +thrust::complex::complex( + const complex &z +) = default +``` + + +**Parameters** + + +The `complex` to copy from. + + + + + +This converting copy constructor copies from a `complex` with a type that is convertible to this `complex's` [`value_type`](/library/api/thrust::complex::value_type). + + +```cpp showLineNumbers={false} +template +thrust::complex::complex( + const complex &z +) +``` + + +**Template parameters** + + +Is convertible to [`value_type`](/library/api/thrust::complex::value_type). + + +**Parameters** + + +The `complex` to copy from. + + + + + +This converting copy constructor copies from a `std::complex` with a type that is convertible to this `complex's` [`value_type`](/library/api/thrust::complex::value_type). + + +```cpp showLineNumbers={false} +thrust::complex::complex( + const ::std::complex &z +) +``` + + +**Parameters** + + +The `complex` to copy from. + + + + + +This converting copy constructor copies from a `std::complex` with a type that is convertible to this `complex's` [`value_type`](/library/api/thrust::complex::value_type). + + +```cpp showLineNumbers={false} +template +thrust::complex::complex( + const ::std::complex &z +) +``` + + +**Template parameters** + + +Is convertible to [`value_type`](/library/api/thrust::complex::value_type). + + +**Parameters** + + +The `complex` to copy from. + + + + + +### operator::std::complex inline const + +Casts this `complex` to a `std::complex` of the same type. + + +```cpp showLineNumbers={false} +thrust::complex::operator::std::complex() const +``` + + +--- + +## Assignment operators + +### operator= + + + + +Assign `re` to the real part of this `complex` and set the imaginary part to 0. + + +```cpp showLineNumbers={false} +complex & thrust::complex::operator=( + const T &re +) +``` + + +**Parameters** + + +The real part of the number. + + + + + +Assign `z.real()` and `z.imag()` to the real and imaginary parts of this `complex` respectively. + + +```cpp showLineNumbers={false} +complex & thrust::complex::operator=( + const complex &z +) = default +``` + + +**Parameters** + + +The `complex` to copy from. + + + + + +Assign `z.real()` and `z.imag()` to the real and imaginary parts of this `complex` respectively. + + +```cpp showLineNumbers={false} +template +complex & thrust::complex::operator=( + const complex &z +) +``` + + +**Template parameters** + + +Is convertible to [`value_type`](/library/api/thrust::complex::value_type). + + +**Parameters** + + +The `complex` to copy from. + + + + + +Assign `z.real()` and `z.imag()` to the real and imaginary parts of this `complex` respectively. + + +```cpp showLineNumbers={false} +complex & thrust::complex::operator=( + const ::std::complex &z +) +``` + + +**Parameters** + + +The `complex` to copy from. + + + + + +Assign `z.real()` and `z.imag()` to the real and imaginary parts of this `complex` respectively. + + +```cpp showLineNumbers={false} +template +complex & thrust::complex::operator=( + const ::std::complex &z +) +``` + + +**Template parameters** + + +Is convertible to [`value_type`](/library/api/thrust::complex::value_type). + + +**Parameters** + + +The `complex` to copy from. + + + + + +--- + +## Methods + +### operator+= + + + + +Adds a `complex` to this `complex` and assigns the result to this `complex`. + + +```cpp showLineNumbers={false} +template +complex & thrust::complex::operator+=( + const complex &z +) +``` + + +**Template parameters** + + +Is convertible to [`value_type`](/library/api/thrust::complex::value_type). + + +**Parameters** + + +The `complex` to be added. + + + + + +Adds a scalar to this `complex` and assigns the result to this `complex`. + + +```cpp showLineNumbers={false} +template +complex & thrust::complex::operator+=( + const U &z +) +``` + + +**Template parameters** + + +Is convertible to [`value_type`](/library/api/thrust::complex::value_type). + + +**Parameters** + + +The `complex` to be added. + + + + + +### operator-= + + + + +Subtracts a `complex` from this `complex` and assigns the result to this `complex`. + + +```cpp showLineNumbers={false} +template +complex & thrust::complex::operator-=( + const complex &z +) +``` + + +**Template parameters** + + +Is convertible to [`value_type`](/library/api/thrust::complex::value_type). + + +**Parameters** + + +The `complex` to be subtracted. + + + + + +Subtracts a scalar from this `complex` and assigns the result to this `complex`. + + +```cpp showLineNumbers={false} +template +complex & thrust::complex::operator-=( + const U &z +) +``` + + +**Template parameters** + + +Is convertible to [`value_type`](/library/api/thrust::complex::value_type). + + +**Parameters** + + +The scalar to be subtracted. + + + + + +### operator*= + + + + +Multiplies this `complex` by another `complex` and assigns the result to this `complex`. + + +```cpp showLineNumbers={false} +template +complex & thrust::complex::operator*=( + const complex &z +) +``` + + +**Template parameters** + + +Is convertible to [`value_type`](/library/api/thrust::complex::value_type). + + +**Parameters** + + +The `complex` to be multiplied. + + + + + +Multiplies this `complex` by a scalar and assigns the result to this `complex`. + + +```cpp showLineNumbers={false} +template +complex & thrust::complex::operator*=( + const U &z +) +``` + + +**Template parameters** + + +Is convertible to [`value_type`](/library/api/thrust::complex::value_type). + + +**Parameters** + + +The scalar to be multiplied. + + + + + +### operator/= + + + + +Divides this `complex` by another `complex` and assigns the result to this `complex`. + + +```cpp showLineNumbers={false} +template +complex & thrust::complex::operator/=( + const complex &z +) +``` + + +**Template parameters** + + +Is convertible to [`value_type`](/library/api/thrust::complex::value_type). + + +**Parameters** + + +The `complex` to be divided. + + + + + +Divides this `complex` by a scalar and assigns the result to this `complex`. + + +```cpp showLineNumbers={false} +template +complex & thrust::complex::operator/=( + const U &z +) +``` + + +**Template parameters** + + +Is convertible to [`value_type`](/library/api/thrust::complex::value_type). + + +**Parameters** + + +The scalar to be divided. + + + + + +### real inline + + + + +const + +Returns the real part of this `complex`. + + +```cpp showLineNumbers={false} +T thrust::complex::real() const volatile +``` + + + + + +const + +Returns the real part of this `complex`. + + +```cpp showLineNumbers={false} +T thrust::complex::real() const +``` + + + + + +Sets the real part of this `complex`. + + +```cpp showLineNumbers={false} +void thrust::complex::real( + T re +) volatile +``` + + +**Parameters** + + +The new real part of this `complex`. + + + + + +Sets the real part of this `complex`. + + +```cpp showLineNumbers={false} +void thrust::complex::real( + T re +) +``` + + +**Parameters** + + +The new real part of this `complex`. + + + + + +### imag inline + + + + +const + +Returns the imaginary part of this `complex`. + + +```cpp showLineNumbers={false} +T thrust::complex::imag() const volatile +``` + + + + + +const + +Returns the imaginary part of this `complex`. + + +```cpp showLineNumbers={false} +T thrust::complex::imag() const +``` + + + + + +Sets the imaginary part of this `complex`. + + +```cpp showLineNumbers={false} +void thrust::complex::imag( + T im +) volatile +``` + + +**Parameters** + + +The new imaginary part of this `complex.e` + + + + + +Sets the imaginary part of this `complex`. + + +```cpp showLineNumbers={false} +void thrust::complex::imag( + T im +) +``` + + +**Parameters** + + +The new imaginary part of this `complex`. + + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `value_type` | `T` | [`value_type`](/library/api/thrust::complex::value_type) is the type of `complex's` real and imaginary parts. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `data` | `storage` | | + +--- + +## Inner classes + +### storage + + +```cpp showLineNumbers={false} +struct thrust::complex::storage +``` + + +| Name | Type | Description | +|---|---|---| +| `x` | `T` | | +| `y` | `T` | | diff --git a/fern/cudapages/thrust/thrust/thrust/constant_iterator.mdx b/fern/cudapages/thrust/thrust/thrust/constant_iterator.mdx new file mode 100644 index 0000000..f4833ef --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/constant_iterator.mdx @@ -0,0 +1,217 @@ +--- +title: thrust::constant_iterator +description: "[`constant_iterator`](/library/api/thrust::constant_iterator) is an iterator which represents a pointer into a range of constant values." +--- + +`constant_iterator` is an iterator which represents a pointer into a range of constant values. + +This iterator is useful for creating a range filled with the same value without explicitly storing it in memory. Using `constant_iterator` saves both memory capacity and bandwidth. + +The following code snippet demonstrates how to create a `constant_iterator` whose `value_type` is `int` and whose value is `10`. + +This next example demonstrates how to use a `constant_iterator` with the `thrust::transform` function to increment all elements of a sequence by the same value. We will create a temporary `constant_iterator` with the function `make_constant_iterator` function in order to avoid explicitly specifying its type: + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +make_constant_iterator + +## Example + +```cpp showLineNumbers={false} +#include + +thrust::constant_iterator iter(10); + +*iter; // returns 10 +iter[0]; // returns 10 +iter[1]; // returns 10 +iter[13]; // returns 10 + +// and so on... +``` + +```cpp showLineNumbers={false} +#include +#include +#include +#include + +int main() +{ + thrust::device_vector data{3, 7, 2, 5}; + + // add 10 to all values in data + thrust::transform(data.begin(), data.end(), + thrust::make_constant_iterator(10), + data.begin(), + ::cuda::std::plus()); + + // data is now [13, 17, 12, 15] + + return 0; +} +``` + + + + + + + + + + + + + + + + +**Inherits from:** `thrust::iterator_adaptor< constant_iterator< Value, use_default, use_default >, base_iterator, Value, iterator_system_t< base_iterator >, iterator_traversal_t< base_iterator >, Value >` (public) + +--- + +## Constructors + +### constant_iterator + + + + +Default constructor initializes this `constant_iterator``'s` constant using its default constructor. + + +```cpp showLineNumbers={false} +thrust::constant_iterator::constant_iterator() = default +``` + + + + + +inline + +Copy constructor copies the value of another `constant_iterator` with related System type. + + +```cpp showLineNumbers={false} +template +thrust::constant_iterator::constant_iterator( + constant_iterator const &rhs +) +``` + + +**Parameters** + + +The `constant_iterator` to copy. + + + + + +inline + +This constructor receives a value to use as the constant value of this `constant_iterator` and an index specifying the location of this `constant_iterator` in a sequence. + +`v` The value of this `constant_iterator``'s` constant value. `i` The index of this `constant_iterator` in a sequence. Defaults to the value returned by `Incrementable's` null constructor. For example, when `Incrementable == int`, `0`. + + +```cpp showLineNumbers={false} +thrust::constant_iterator::constant_iterator( + value_type const &v, + incrementable const &i = incrementable() +) +``` + + + + + +inline + +This constructor is templated to allow construction from a value type and incrementable type related this this `constant_iterator``'s` respective types. + +`v` The value of this `constant_iterator``'s` constant value. `i` The index of this `constant_iterator` in a sequence. Defaults to the value returned by `Incrementable's` null constructor. For example, when `Incrementable == int`, `0`. + + +```cpp showLineNumbers={false} +template +thrust::constant_iterator::constant_iterator( + OtherValue const &v, + OtherIncrementable const &i = incrementable() +) +``` + + + + + +--- + +## Methods + +### value inline const + +This method returns the value of this `constant_iterator``'s` constant value. + + +```cpp showLineNumbers={false} +Value const & thrust::constant_iterator::value() const +``` + + +**Returns:** A `const` reference to this `constant_iterator``'s` constant value. + +### base inline const + + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor::base() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + +### base_reference inline + + + + + +```cpp showLineNumbers={false} +Base & thrust::iterator_adaptor::base_reference() +``` + + +**Returns:** A mutable reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +const + + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor::base_reference() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `base_type` | `Base` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)`'s` `adapts`. | diff --git a/fern/cudapages/thrust/thrust/thrust/counting_iterator.mdx b/fern/cudapages/thrust/thrust/thrust/counting_iterator.mdx new file mode 100644 index 0000000..e405289 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/counting_iterator.mdx @@ -0,0 +1,223 @@ +--- +title: thrust::counting_iterator +description: "[`counting_iterator`](/library/api/thrust::counting_iterator) is an iterator which represents a pointer into a range of sequentially changing values." +--- + +`counting_iterator` is an iterator which represents a pointer into a range of sequentially changing values. + +This iterator is useful for creating a range filled with a sequence without explicitly storing it in memory. Using `counting_iterator` saves memory capacity and bandwidth. + +The following code snippet demonstrates how to create a `counting_iterator` whose `value_type` is `int` and which sequentially increments by `1`. + +This next example demonstrates how to use a `counting_iterator` with the `thrust::copy_if` function to compute the indices of the non-zero elements of a [`device_vector`](/library/api/thrust::device_vector). In this example, we use the `make_counting_iterator` function to avoid specifying the type of the `counting_iterator`. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +make_counting_iterator + +## Example + +```cpp showLineNumbers={false} +#include +... +// create iterators +thrust::counting_iterator first(10); +thrust::counting_iterator last = first + 3; + +first[0] // returns 10 +first[1] // returns 11 +first[100] // returns 110 + +// sum of [first, last) +thrust::reduce(first, last); // returns 33 (i.e. 10 + 11 + 12) + +// initialize vector to [0,1,2,..] +thrust::counting_iterator iter(0); +thrust::device_vector vec(500); +thrust::copy(iter, iter + vec.size(), vec.begin()); +``` + +```cpp showLineNumbers={false} +#include +#include +#include +#include + +int main() +{ + // this example computes indices for all the nonzero values in a sequence + + // sequence of zero and nonzero values + thrust::device_vector stencil{0, 1, 1, 0, 0, 1, 0, 1}; + + // storage for the nonzero indices + thrust::device_vector indices(8); + + // compute indices of nonzero elements + using IndexIterator = thrust::device_vector::iterator; + + // use make_counting_iterator to define the sequence [0, 8) + IndexIterator indices_end = thrust::copy_if(thrust::make_counting_iterator(0), + thrust::make_counting_iterator(8), + stencil.begin(), + indices.begin(), + ::cuda::std::identity{}); + // indices now contains [1,2,5,7] + + return 0; +} +``` + + + + + + + + + + + + + + + + + + + + + + +**Inherits from:** `thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >` (public), `thrust::__compile_time_value< 1 >` (private) + +--- + +## Constructors + +### counting_iterator inline + + + + +Default constructor initializes this `counting_iterator``'s` counter to `Incrementable{}`. + + +```cpp showLineNumbers={false} +thrust::counting_iterator::counting_iterator() +``` + + + + + +Copy constructor copies the value of another `counting_iterator` with related System type. + + +```cpp showLineNumbers={false} +template +thrust::counting_iterator::counting_iterator( + counting_iterator const &rhs +) +``` + + +**Parameters** + + +The `counting_iterator` to copy. + + + + + +explicit + +This `explicit` constructor copies the value of an `Incrementable` into a new `counting_iterator``'s` `Incrementable` counter. + + +```cpp showLineNumbers={false} +thrust::counting_iterator::counting_iterator( + Incrementable x +) +``` + + +**Parameters** + + +The initial value of the new `counting_iterator``'s` `Incrementable` counter. + + + + + +explicit + + +```cpp showLineNumbers={false} +thrust::counting_iterator::counting_iterator( + Incrementable x, + StrideHolder stride +) +``` + + + + + +--- + +## Methods + +### base inline const + + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor::base() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + +### base_reference inline + + + + + +```cpp showLineNumbers={false} +Base & thrust::iterator_adaptor::base_reference() +``` + + +**Returns:** A mutable reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +const + + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor::base_reference() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `base_type` | `Base` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)`'s` `adapts`. | diff --git a/fern/cudapages/thrust/thrust/thrust/device_allocator.mdx b/fern/cudapages/thrust/thrust/thrust/device_allocator.mdx new file mode 100644 index 0000000..e7ab185 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/device_allocator.mdx @@ -0,0 +1,211 @@ +--- +title: thrust::device_allocator +description: "An allocator which creates new elements in memory accessible by devices." +--- + +An allocator which creates new elements in memory accessible by devices. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +[https://en.cppreference.com/w/cpp/named_req/Allocator](https://en.cppreference.com/w/cpp/named_req/Allocator) + + + + + + + + + + +**Inherits from:** `thrust::mr::stateless_resource_allocator< T, device_ptr_memory_resource< device_memory_resource > >` (public) + +--- + +## Constructors + +### device_allocator inline + + + + +Default constructor has no effect. + + +```cpp showLineNumbers={false} +thrust::device_allocator::device_allocator() +``` + + + + + +Copy constructor has no effect. + + +```cpp showLineNumbers={false} +thrust::device_allocator::device_allocator( + const device_allocator &other +) +``` + + + + + +Constructor from other `device_allocator` has no effect. + + +```cpp showLineNumbers={false} +template +thrust::device_allocator::device_allocator( + const device_allocator &other +) +``` + + + + + +### Destructor + +### ~device_allocator inline + +Destructor has no effect. + + +```cpp showLineNumbers={false} +thrust::device_allocator::~device_allocator() +``` + + +--- + +## Assignment operators + +### operator= + + +```cpp showLineNumbers={false} +device_allocator & thrust::device_allocator::operator=( + const device_allocator & +) = default +``` + + +--- + +## Methods + +### max_size inline const + +Calculates the maximum number of elements allocated by this allocator. + + +```cpp showLineNumbers={false} +size_type thrust::mr::allocator::max_size() const +``` + + +**Returns:** the maximum value of `std::size_t`, divided by the size of `T`. + +### allocate inline nodiscard + +Allocates objects of type `T`. + + +```cpp showLineNumbers={false} +pointer thrust::mr::allocator::allocate( + size_type n +) +``` + + +**Returns:** a pointer to the newly allocated storage. + +**Parameters** + + +Number of elements to allocate + + +### deallocate inline noexcept + +Deallocates objects of type `T`. + + +```cpp showLineNumbers={false} +void thrust::mr::allocator::deallocate( + pointer p, + size_type n +) noexcept +``` + + +**Parameters** + + +Pointer returned by a previous call to `allocate` + + + +Number of elements, passed as an argument to the `allocate` call that produced `p` + + +### resource inline const + +Extracts the memory resource used by this allocator. + + +```cpp showLineNumbers={false} +MR * thrust::mr::allocator::resource() const +``` + + +**Returns:** the memory resource used by this allocator. + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `base` | `thrust::mr::stateless_resource_allocator< T, device_ptr_memory_resource< device_memory_resource > >` | | +| `void_pointer` | `typename MR::pointer` | The pointer to void type of this allocator. | +| `value_type` | `T` | The value type allocated by this allocator. | +| `pointer` | `typename thrust::detail::pointer_traits< void_pointer >::template rebind< T >::other` | The pointer type allocated by this allocator. | +| `const_pointer` | `typename thrust::detail::pointer_traits< void_pointer >::template rebind< const T >::other` | The pointer to const type. | +| `reference` | `typename thrust::detail::pointer_traits< pointer >::reference` | The reference to the type allocated by this allocator. | +| `const_reference` | `typename thrust::detail::pointer_traits< const_pointer >::reference` | The const reference to the type allocated by this allocator. | +| `size_type` | `std::size_t` | The size type of this allocator. | +| `difference_type` | `typename thrust::detail::pointer_traits< pointer >::difference_type` | The difference type between pointers allocated by this allocator. | +| `propagate_on_container_copy_assignment` | `detail::true_type` | Specifies that the allocator shall be propagated on container copy assignment. | +| `propagate_on_container_move_assignment` | `detail::true_type` | Specifies that the allocator shall be propagated on container move assignment. | +| `propagate_on_container_swap` | `detail::true_type` | Specifies that the allocator shall be propagated on container swap. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `mem_res` | `MR *` | | + +--- + +## Inner classes + +### rebind + + +```cpp showLineNumbers={false} +struct thrust::device_allocator::rebind +``` + + +The `rebind` metafunction provides the type of a `device_allocator` instantiated with another type. diff --git a/fern/cudapages/thrust/thrust/thrust/device_execution_policy.mdx b/fern/cudapages/thrust/thrust/thrust/device_execution_policy.mdx new file mode 100644 index 0000000..52863da --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/device_execution_policy.mdx @@ -0,0 +1,72 @@ +--- +title: thrust::device_execution_policy +description: "[`device_execution_policy`](/library/api/thrust::device_execution_policy) is the base class for all Thrust parallel execution policies which are derived from Thrust's default device backend system configured with the `THRUST_DEVICE_SYSTEM` macro." +--- + +`device_execution_policy` is the base class for all Thrust parallel execution policies which are derived from Thrust's default device backend system configured with the `THRUST_DEVICE_SYSTEM` macro. + +Custom user-defined backends which wish to inherit the functionality of Thrust's device backend system should derive a policy from this type in order to interoperate with Thrust algorithm dispatch. + +The following code snippet demonstrates how to derive a standalone custom execution policy from `thrust::device_execution_policy` to implement a backend which specializes `for_each` while inheriting the behavior of every other algorithm from the device system: + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +execution_policy, +[host_execution_policy](/library/api/thrust::host_execution_policy) + +## Example + +```cpp showLineNumbers={false} +#include +#include + +// define a type derived from thrust::device_execution_policy to distinguish our custom execution policy: +struct my_policy : thrust::device_execution_policy {}; + +// overload for_each on my_policy +template +Iterator for_each(my_policy, Iterator first, Iterator last, Function f) +{ + std::cout << "Hello, world from for_each(my_policy)!" << std::endl; + + for(; first < last; ++first) + { + f(*first); + } + + return first; +} + +struct ignore_argument +{ + void operator()(int) {} +}; + +int main() +{ + int data[4]; + + // dispatch thrust::for_each using our custom policy: + my_policy exec; + thrust::for_each(exec, data, data + 4, ignore_argument()); + + // dispatch thrust::transform whose behavior our policy inherits + thrust::transform(exec, data, data, + 4, data, ::cuda::std::identity{}); + + return 0; +} +``` + + + + + + + + + + +**Inherits from:** `thrust::system::__THRUST_DEVICE_SYSTEM_NAMESPACE::execution_policy< DerivedPolicy >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/device_malloc_allocator.mdx b/fern/cudapages/thrust/thrust/thrust/device_malloc_allocator.mdx new file mode 100644 index 0000000..699075b --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/device_malloc_allocator.mdx @@ -0,0 +1,253 @@ +--- +title: thrust::device_malloc_allocator +description: "[`device_malloc_allocator`](/library/api/thrust::device_malloc_allocator) is a device memory allocator that employs the `device_malloc` function for allocation." +--- + +`device_malloc_allocator` is a device memory allocator that employs the `device_malloc` function for allocation. + +`device_malloc_allocator` is deprecated in favor of [`thrust::mr`](/library/api/thrust::mr) memory resource-based allocators. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +device_malloc, +[device_ptr](/library/api/thrust::device_ptr), +[device_allocator](/library/api/thrust::device_allocator), +[https://en.cppreference.com/w/cpp/memory/allocator](https://en.cppreference.com/w/cpp/memory/allocator) + + + + + + + + + + +--- + +## Constructors + +### device_malloc_allocator inline + + + + +No-argument constructor has no effect. + + +```cpp showLineNumbers={false} +thrust::device_malloc_allocator::device_malloc_allocator() +``` + + + + + +Copy constructor has no effect. + + +```cpp showLineNumbers={false} +thrust::device_malloc_allocator::device_malloc_allocator( + device_malloc_allocator const & +) +``` + + + + + +Constructor from other `device_malloc_allocator` has no effect. + + +```cpp showLineNumbers={false} +template +thrust::device_malloc_allocator::device_malloc_allocator( + device_malloc_allocator const & +) +``` + + + + + +### Destructor + +### ~device_malloc_allocator inline + +No-argument destructor has no effect. + + +```cpp showLineNumbers={false} +thrust::device_malloc_allocator::~device_malloc_allocator() +``` + + +--- + +## Assignment operators + +### operator= + + +```cpp showLineNumbers={false} +device_malloc_allocator & thrust::device_malloc_allocator::operator=( + const device_malloc_allocator & +) = default +``` + + +--- + +## Methods + +### address inline + + + + +Returns the address of an allocated object. + + +```cpp showLineNumbers={false} +pointer thrust::device_malloc_allocator::address( + reference r +) +``` + + +**Returns:** `&r`. + + + + +Returns the address an allocated object. + + +```cpp showLineNumbers={false} +const_pointer thrust::device_malloc_allocator::address( + const_reference r +) +``` + + +**Returns:** `&r`. + + + + +### allocate inline + +Allocates storage for `cnt` objects. + + +```cpp showLineNumbers={false} +pointer thrust::device_malloc_allocator::allocate( + size_type cnt, + const_pointer = const_pointer(static_cast(0)) +) +``` + + +**Returns:** A `pointer` to uninitialized storage for `cnt` objects. + +**Parameters** + + +The number of objects to allocate. + + +### deallocate inline noexcept + +Deallocates storage for objects allocated with `allocate`. + + +```cpp showLineNumbers={false} +void thrust::device_malloc_allocator::deallocate( + pointer p, + size_type cnt +) noexcept +``` + + +**Parameters** + + +A `pointer` to the storage to deallocate. + + + +The size of the previous allocation. + + +### max_size inline const + +Returns the largest value `n` for which `allocate(n)` might succeed. + + +```cpp showLineNumbers={false} +size_type thrust::device_malloc_allocator::max_size() const +``` + + +**Returns:** The largest value `n` for which `allocate(n)` might succeed. + +### operator== inline const + +Compares against another `device_malloc_allocator` for equality. + + +```cpp showLineNumbers={false} +bool thrust::device_malloc_allocator::operator==( + device_malloc_allocator const & +) const +``` + + +**Returns:** `true` + +### operator!= inline const + +Compares against another `device_malloc_allocator` for inequality. + + +```cpp showLineNumbers={false} +bool thrust::device_malloc_allocator::operator!=( + device_malloc_allocator const &a +) const +``` + + +**Returns:** `false` + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `value_type` | `T` | Type of element allocated, `T`. | +| `pointer` | `device_ptr< T >` | Pointer to allocation, [`device_ptr`](/library/api/thrust::device_ptr). | +| `const_pointer` | `device_ptr< const T >` | `const` pointer to allocation, [`device_ptr`](/library/api/thrust::device_ptr). | +| `reference` | `device_reference< T >` | Reference to allocated element, [`device_reference`](/library/api/thrust::device_reference). | +| `const_reference` | `device_reference< const T >` | `const` reference to allocated element, [`device_reference`](/library/api/thrust::device_reference). | +| `size_type` | `std::size_t` | Type of allocation size, `std::size_t`. | +| `difference_type` | `typename pointer::difference_type` | Type of allocation difference, `pointer::difference_type`. | + +--- + +## Inner classes + +### rebind + + +```cpp showLineNumbers={false} +struct thrust::device_malloc_allocator::rebind +``` + + +The `rebind` metafunction provides the type of a `device_malloc_allocator` instantiated with another type. diff --git a/fern/cudapages/thrust/thrust/thrust/device_new_allocator.mdx b/fern/cudapages/thrust/thrust/thrust/device_new_allocator.mdx new file mode 100644 index 0000000..44ace9b --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/device_new_allocator.mdx @@ -0,0 +1,236 @@ +--- +title: thrust::device_new_allocator +description: "[`device_new_allocator`](/library/api/thrust::device_new_allocator) is a device memory allocator that employs the `device_new` function for allocation." +--- + +`device_new_allocator` is a device memory allocator that employs the `device_new` function for allocation. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +device_new, +[device_ptr](/library/api/thrust::device_ptr), +[https://en.cppreference.com/w/cpp/memory/allocator](https://en.cppreference.com/w/cpp/memory/allocator) + + + + + + + + + + +--- + +## Constructors + +### device_new_allocator inline + + + + +No-argument constructor has no effect. + + +```cpp showLineNumbers={false} +thrust::device_new_allocator::device_new_allocator() +``` + + + + + +Copy constructor has no effect. + + +```cpp showLineNumbers={false} +thrust::device_new_allocator::device_new_allocator( + device_new_allocator const & +) +``` + + + + + +Constructor from other [`device_malloc_allocator`](/library/api/thrust::device_malloc_allocator) has no effect. + + +```cpp showLineNumbers={false} +template +thrust::device_new_allocator::device_new_allocator( + device_new_allocator const & +) +``` + + + + + +### Destructor + +### ~device_new_allocator inline + +No-argument destructor has no effect. + + +```cpp showLineNumbers={false} +thrust::device_new_allocator::~device_new_allocator() +``` + + +--- + +## Methods + +### address inline + + + + +Returns the address of an allocated object. + + +```cpp showLineNumbers={false} +pointer thrust::device_new_allocator::address( + reference r +) +``` + + +**Returns:** `&r`. + + + + +Returns the address an allocated object. + + +```cpp showLineNumbers={false} +const_pointer thrust::device_new_allocator::address( + const_reference r +) +``` + + +**Returns:** `&r`. + + + + +### allocate inline + +Allocates storage for `cnt` objects. + + +```cpp showLineNumbers={false} +pointer thrust::device_new_allocator::allocate( + size_type cnt, + const_pointer = const_pointer(static_cast(0)) +) +``` + + +**Returns:** A `pointer` to uninitialized storage for `cnt` objects. + +**Parameters** + + +The number of objects to allocate. + + +### deallocate inline noexcept + +Deallocates storage for objects allocated with `allocate`. + + +```cpp showLineNumbers={false} +void thrust::device_new_allocator::deallocate( + pointer p, + size_type cnt +) noexcept +``` + + +**Parameters** + + +A `pointer` to the storage to deallocate. + + + +The size of the previous allocation. + + +### max_size inline const + +Returns the largest value `n` for which `allocate(n)` might succeed. + + +```cpp showLineNumbers={false} +size_type thrust::device_new_allocator::max_size() const +``` + + +**Returns:** The largest value `n` for which `allocate(n)` might succeed. + +### operator== inline + +Compares against another [`device_malloc_allocator`](/library/api/thrust::device_malloc_allocator) for equality. + + +```cpp showLineNumbers={false} +bool thrust::device_new_allocator::operator==( + device_new_allocator const & +) +``` + + +**Returns:** `true` + +### operator!= inline + +Compares against another [`device_malloc_allocator`](/library/api/thrust::device_malloc_allocator) for inequality. + + +```cpp showLineNumbers={false} +bool thrust::device_new_allocator::operator!=( + device_new_allocator const &a +) +``` + + +**Returns:** `false` + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `value_type` | `T` | Type of element allocated, `T`. | +| `pointer` | `device_ptr< T >` | Pointer to allocation, [`device_ptr`](/library/api/thrust::device_ptr). | +| `const_pointer` | `device_ptr< const T >` | `const` pointer to allocation, [`device_ptr`](/library/api/thrust::device_ptr). | +| `reference` | `device_reference< T >` | Reference to allocated element, [`device_reference`](/library/api/thrust::device_reference). | +| `const_reference` | `device_reference< const T >` | `const` reference to allocated element, [`device_reference`](/library/api/thrust::device_reference). | +| `size_type` | `::cuda::std::size_t` | Type of allocation size, `size_t`. | +| `difference_type` | `typename pointer::difference_type` | Type of allocation difference, `pointer::difference_type`. | + +--- + +## Inner classes + +### rebind + + +```cpp showLineNumbers={false} +struct thrust::device_new_allocator::rebind +``` + + +The `rebind` metafunction provides the type of a `device_new_allocator` instantiated with another type. diff --git a/fern/cudapages/thrust/thrust/thrust/device_ptr.mdx b/fern/cudapages/thrust/thrust/thrust/device_ptr.mdx new file mode 100644 index 0000000..f5b34eb --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/device_ptr.mdx @@ -0,0 +1,324 @@ +--- +title: thrust::device_ptr +description: "[`device_ptr`](/library/api/thrust::device_ptr) is a pointer-like object which points to an object that resides in memory associated with the device system." +--- + +`device_ptr` is a pointer-like object which points to an object that resides in memory associated with the device system. + +`device_ptr` has pointer semantics: it may be dereferenced safely from anywhere, including the host, and may be manipulated with pointer arithmetic. + +`device_ptr` can be created with device_new, device_malloc, [device_malloc_allocator](/library/api/thrust::device_malloc_allocator), [device_allocator](/library/api/thrust::device_allocator), or device_pointer_cast, or by explicitly calling its constructor with a raw pointer. + +The raw pointer contained in a `device_ptr` may be obtained via `get` member function or the raw_pointer_cast free function. + +[Algorithms](/library/api/Algorithms) operating on `device_ptr` types will automatically be dispatched to the device system. + +```cpp showLineNumbers={false} +#include +``` + + +`device_ptr` is not a smart pointer; it is the programmer's responsibility to deallocate memory pointed to by `device_ptr`. + + +**See also:** +device_new, +device_malloc, +[device_malloc_allocator](/library/api/thrust::device_malloc_allocator), +[device_allocator](/library/api/thrust::device_allocator), +device_pointer_cast, +raw_pointer_cast + + + + + + + + + + +**Inherits from:** `thrust::pointer< T, thrust::device_system_tag, thrust::device_reference< T >, thrust::device_ptr< T > >` (public) + +--- + +## Constructors + +### device_ptr + + + + + +```cpp showLineNumbers={false} +thrust::device_ptr::device_ptr() = default +``` + + + + + +inline + +Construct a null `device_ptr`. + + +```cpp showLineNumbers={false} +thrust::device_ptr::device_ptr( + std::nullptr_t +) +``` + + + +[`get()`](/library/api/thrust::pointer::get())` == nullptr`. + + + + + +inline explicit + +Construct a `device_ptr` from a raw pointer which is convertible to `T*`. + + +```cpp showLineNumbers={false} +template +thrust::device_ptr::device_ptr( + U *ptr +) +``` + + + +[`get()`](/library/api/thrust::pointer::get())` == nullptr`. + + + +`std::is_convertible_v == true`. + + + +`ptr` points to a location in device memory. + + +**Template parameters** + + +A type whose pointer is convertible to `T*`. + + +**Parameters** + + +A raw pointer to a `U` in device memory to construct from. + + + + + +inline + +Copy construct a `device_ptr` from another `device_ptr` whose pointer type is convertible to `T*`. + + +```cpp showLineNumbers={false} +template +thrust::device_ptr::device_ptr( + device_ptr const &other +) +``` + + + +[`get()`](/library/api/thrust::pointer::get())` == other.get()`. + + + +`std::is_convertible_v == true`. + + +**Template parameters** + + +A type whose pointer is convertible to `T*`. + + +**Parameters** + + +A `device_ptr` to a `U` to construct from. + + + + + +--- + +## Assignment operators + +### operator= inline + + + + +Set this `device_ptr` to point to the same object as another `device_ptr` whose pointer type is convertible to `T*`. + + +```cpp showLineNumbers={false} +template +device_ptr & thrust::device_ptr::operator=( + device_ptr const &other +) +``` + + + +[`get()`](/library/api/thrust::pointer::get())` == other.get()`. + + + +`std::is_convertible_v == true`. + + +**Returns:** `*this`. + +**Template parameters** + + +A type whose pointer is convertible to `T*`. + + +**Parameters** + + +A `device_ptr` to a `U` to assign from. + + + + + +Set this `device_ptr` to null. + + +```cpp showLineNumbers={false} +device_ptr & thrust::device_ptr::operator=( + std::nullptr_t +) +``` + + + +[`get()`](/library/api/thrust::pointer::get())` == nullptr`. + + +**Returns:** `*this`. + + + + +--- + +## Methods + +### get inline const + +`get` returns this `pointer's` encapsulated raw pointer. + + +```cpp showLineNumbers={false} +T * thrust::pointer, thrust::device_ptr>::get() const +``` + + +**Returns:** This `pointer's` raw pointer. + +### operator-> inline const + + +```cpp showLineNumbers={false} +T * thrust::pointer, thrust::device_ptr>::operator->() const +``` + + +### operator bool inline explicit const + + +```cpp showLineNumbers={false} +thrust::pointer, thrust::device_ptr>::operator bool() const +``` + + +### base inline const + + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor::base() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + +### dereference inline const + + +```cpp showLineNumbers={false} +SuperRef thrust::pointer, thrust::device_ptr>::dereference() const +``` + + +### base_reference inline + + + + + +```cpp showLineNumbers={false} +Base & thrust::iterator_adaptor::base_reference() +``` + + +**Returns:** A mutable reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +const + + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor::base_reference() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +--- + +## Static methods + +### pointer_to inline static + + +```cpp showLineNumbers={false} +static derived_type thrust::pointer, thrust::device_ptr>::pointer_to( + typename detail::pointer_traits_detail::pointer_to_param::type r +) +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `super_t` | `thrust::pointer< T, thrust::device_system_tag, thrust::device_reference< T >, thrust::device_ptr< T > >` | | +| `derived_type` | `typename detail::pointer_base< T, thrust::device_system_tag, thrust::device_reference< T >, thrust::device_ptr< T > >::derived_type` | | +| `raw_pointer` | `typename super_t::base_type` | The type of the raw pointer. | +| `base_type` | `Base` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)`'s` `adapts`. | diff --git a/fern/cudapages/thrust/thrust/thrust/device_ptr_memory_resource.mdx b/fern/cudapages/thrust/thrust/thrust/device_ptr_memory_resource.mdx new file mode 100644 index 0000000..ef21fb0 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/device_ptr_memory_resource.mdx @@ -0,0 +1,239 @@ +--- +title: thrust::device_ptr_memory_resource +description: "Memory resource adaptor that turns any memory resource that returns a fancy with the same tag as [`device_ptr`](/library/api/thrust::device_ptr), and adapts it to a resource that returns a [`device_ptr`](/library/api/thrust::device_ptr)." +--- + +Memory resource adaptor that turns any memory resource that returns a fancy with the same tag as [`device_ptr`](/library/api/thrust::device_ptr), and adapts it to a resource that returns a [`device_ptr`](/library/api/thrust::device_ptr). + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +**Inherits from:** `thrust::mr::memory_resource< device_ptr< void > >` (public) + +This class is marked final. + +--- + +## Constructors + +### device_ptr_memory_resource inline + + + + +Initialize the adaptor with the global instance of the upstream resource. + + +```cpp showLineNumbers={false} +thrust::device_ptr_memory_resource::device_ptr_memory_resource() +``` + + + + + +Initialize the adaptor with an upstream resource. + + +```cpp showLineNumbers={false} +thrust::device_ptr_memory_resource::device_ptr_memory_resource( + Upstream *upstream +) +``` + + +**Parameters** + + +The upstream memory resource to adapt. + + + + + +--- + +## Methods + +### do_allocate inline nodiscard virtual + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +virtual pointer thrust::device_ptr_memory_resource::do_allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) override +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + +**Parameters** + + +Size, in bytes, that is requested from this allocation + + + +Alignment that is requested from this allocation + + +### do_deallocate inline virtual + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +virtual void thrust::device_ptr_memory_resource::do_deallocate( + pointer p, + std::size_t bytes, + std::size_t alignment +) override +``` + + +**Parameters** + + +Pointer to be deallocated + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The size of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +### allocate inline nodiscard + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +pointer thrust::mr::memory_resource>::allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + +**Parameters** + + +Size, in bytes, that is requested from this allocation + + + +Alignment that is requested from this allocation + + +### deallocate inline noexcept + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +void thrust::mr::memory_resource>::deallocate( + pointer p, + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) noexcept +``` + + +**Parameters** + + +Pointer to be deallocated + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The alignment of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +### is_equal inline const noexcept + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +bool thrust::mr::memory_resource>::is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to + + +### do_is_equal inline const noexcept virtual + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +virtual bool thrust::mr::memory_resource>::do_is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `upstream_ptr` | `typename Upstream::pointer` | | +| `pointer` | `device_ptr< void >` | Alias for the template parameter. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `m_upstream` | `Upstream *` | | diff --git a/fern/cudapages/thrust/thrust/thrust/device_reference.mdx b/fern/cudapages/thrust/thrust/thrust/device_reference.mdx new file mode 100644 index 0000000..4317a96 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/device_reference.mdx @@ -0,0 +1,1189 @@ +--- +title: thrust::device_reference +description: "[`device_reference`](/library/api/thrust::device_reference) acts as a reference-like object to an object stored in device memory." +--- + +`device_reference` acts as a reference-like object to an object stored in device memory. + +`device_reference` is not intended to be used directly; rather, this type is the result of dereferencing a [`device_ptr`](/library/api/thrust::device_ptr). Similarly, taking the address of a `device_reference` yields a [`device_ptr`](/library/api/thrust::device_ptr). + +`device_reference` may often be used from host code in place of operations defined on its associated [`value_type`](/library/api/thrust::device_reference::value_type). For example, when `device_reference` refers to an arithmetic type, arithmetic operations on it are legal: + +Similarly, we can print the value of `ref_to_thirteen` in the above code by using an `iostream:` + +Of course, we needn't explicitly create a `device_reference` in the previous example, because one is returned by [`device_vector`](/library/api/thrust::device_vector)`'s` bracket operator. A more natural way to print the value of a [`device_vector`](/library/api/thrust::device_vector) element might be: + +These kinds of operations should be used sparingly in performance-critical code, because they imply a potentially expensive copy between host and device space. + +Some operations which are possible with regular objects are impossible with their corresponding `device_reference` objects due to the requirements of the C++ language. For example, because the member access operator cannot be overloaded, member variables and functions of a referent object cannot be directly accessed through its `device_reference`. + +The following code, which generates a compiler error, illustrates: + +Instead, a host space copy must be created to access `foo's` `x` member: + +Another common case where a `device_reference` cannot directly be used in place of its referent object occurs when passing them as parameters to functions like `printf` which have varargs parameters. Because varargs parameters must be Plain Old Data, a `device_reference` to a POD type requires a cast when passed to `printf:` + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +[device_ptr](/library/api/thrust::device_ptr), +[device_vector](/library/api/thrust::device_vector) + +## Example + +```cpp showLineNumbers={false} +#include + +int main() +{ + thrust::device_vector vec(1, 13); + + thrust::device_reference ref_to_thirteen = vec[0]; + + int x = ref_to_thirteen + 1; + + // x is 14 + + return 0; +} +``` + +```cpp showLineNumbers={false} +#include +#include + +int main() +{ + thrust::device_vector vec(1, 13); + + thrust::device_reference ref_to_thirteen = vec[0]; + + std::cout << ref_to_thirteen << std::endl; + + // 13 is printed + + return 0; +} +``` + +```cpp showLineNumbers={false} +#include +#include + +int main() +{ + thrust::device_vector vec(1, 13); + + std::cout << vec[0] << std::endl; + + // 13 is printed + + return 0; +} +``` + +```cpp showLineNumbers={false} +#include + +struct foo +{ + int x; +}; + +int main() +{ + thrust::device_vector foo_vec(1); + + thrust::device_reference foo_ref = foo_vec[0]; + + foo_ref.x = 13; // ERROR: x cannot be accessed through foo_ref + + return 0; +} +``` + +```cpp showLineNumbers={false} +#include + +struct foo +{ + int x; +}; + +int main() +{ + thrust::device_vector foo_vec(1); + + // create a local host-side foo object + foo host_foo; + host_foo.x = 13; + + thrust::device_reference foo_ref = foo_vec[0]; + + foo_ref = host_foo; + + // foo_ref's x member is 13 + + return 0; +} +``` + +```cpp showLineNumbers={false} +#include +#include + +int main() +{ + thrust::device_vector vec(1,13); + + // vec[0] must be cast to int when passing to printf + printf("%d\n", (int) vec[0]); + + return 0; +} +``` + + + + + + + + + + +**Inherits from:** `thrust::reference< T, thrust::device_ptr< T >, thrust::device_reference< T > >` (public) + +--- + +## Constructors + +### device_reference + + + + + +```cpp showLineNumbers={false} +thrust::device_reference::device_reference( + const device_reference &other +) = default +``` + + + + + +inline + +This copy constructor accepts a const reference to another `device_reference`. + + +```cpp showLineNumbers={false} +template +thrust::device_reference::device_reference( + const device_reference &other, + thrust::detail::enable_if_convertible_t::pointer, pointer> * = 0 +) +``` + + +**Parameters** + + +A `device_reference` to copy from. + + +**Example** + +```cpp showLineNumbers={false} +#include +#include +... +thrust::device_vector v(1,0); +thrust::device_reference ref = v[0]; + +// ref equals the object at v[0] +assert(ref == v[0]); + +// the address of ref equals the address of v[0] +assert(&ref == &v[0]); + +// modifying v[0] modifies ref +v[0] = 13; +assert(ref == 13); +``` + + + + +inline explicit + +This copy constructor initializes this `device_reference` to refer to an object pointed to by the given [`device_ptr`](/library/api/thrust::device_ptr). + + +```cpp showLineNumbers={false} +thrust::device_reference::device_reference( + const pointer &ptr +) +``` + + +**Parameters** + + +A [`device_ptr`](/library/api/thrust::device_ptr) to copy from. + + +**Example** + +```cpp showLineNumbers={false} +#include +#include +... +thrust::device_vector v(1,0); +thrust::device_ptr ptr = &v[0]; +thrust::device_reference ref(ptr); + +// ref equals the object pointed to by ptr +assert(ref == *ptr); + +// the address of ref equals ptr +assert(&ref == ptr); + +// modifying *ptr modifies ref +*ptr = 13; +assert(ref == 13); +``` + + + + +--- + +## Assignment operators + +### operator= inline const + + + + + +```cpp showLineNumbers={false} +const device_reference & thrust::device_reference::operator=( + const device_reference &other +) const +``` + + + + + +This assignment operator assigns the value of the object referenced by the given `device_reference` to the object referenced by this `device_reference`. + + +```cpp showLineNumbers={false} +template +const device_reference & thrust::device_reference::operator=( + const device_reference &other +) const +``` + + +**Returns:** `*this` + +**Parameters** + + +The `device_reference` to assign from. + + + + + +Assignment operator assigns the value of the given value to the value referenced by this `device_reference`. + + +```cpp showLineNumbers={false} +const device_reference & thrust::device_reference::operator=( + const value_type &x +) const +``` + + +**Returns:** `*this` + +**Parameters** + + +The value to assign from. + + + + + +--- + +## Methods + +### operator& const + +Address-of operator returns a [`device_ptr`](/library/api/thrust::device_ptr) pointing to the object referenced by this `device_reference`. + + +```cpp showLineNumbers={false} +pointer thrust::device_reference::operator&( + void +) const +``` + + +**Returns:** A [`device_ptr`](/library/api/thrust::device_ptr) pointing to the object this `device_reference` references. + +### operator value_type const + +Conversion operator converts this `device_reference` to T by returning a copy of the object referenced by this `device_reference`. + + +```cpp showLineNumbers={false} +thrust::device_reference::operator value_type( + void +) const +``` + + +**Returns:** A copy of the object referenced by this `device_reference`. + +### swap + +swaps the value this `device_reference` references with another. + + +```cpp showLineNumbers={false} +void thrust::device_reference::swap( + device_reference other +) +``` + + +### operator++ + + + + +Prefix increment operator increments the object referenced by this `device_reference`. + +The following code snippet demonstrates the semantics of `device_reference``'s` prefix increment operator. + + +```cpp showLineNumbers={false} +device_reference & thrust::device_reference::operator++( + void +) +``` + + + +The increment executes as if it were executed on the host. This may change in a later version. + + +**Returns:** `*this` + +**Example** + +```cpp showLineNumbers={false} +#include +#include +... +thrust::device_vector v(1,0); +thrust::device_ptr ptr = &v[0]; +thrust::device_reference ref(ptr); + +// ref equals 0 +assert(ref == 0); + +// the object pointed to by ptr equals 1 +assert(*ptr == 1); + +// v[0] equals 1 +assert(v[0] == 1); + +// increment ref +++ref; + +// ref equals 1 +assert(ref == 1); + +// the object pointed to by ptr equals 1 +assert(*ptr == 1); + +// v[0] equals 1 +assert(v[0] == 1); +``` + + + + +Postfix increment operator copies the object referenced by this `device_reference`, increments the object referenced by this `device_reference`, and returns the copy. + +The following code snippet demonstrates the semantics of `device_reference``'s` postfix increment operator. + + +```cpp showLineNumbers={false} +value_type thrust::device_reference::operator++( + int +) +``` + + + +The increment executes as if it were executed on the host. This may change in a later version. + + +**Returns:** A copy of the object referenced by this `device_reference` before being incremented. + +**Example** + +```cpp showLineNumbers={false} +#include +#include +... +thrust::device_vector v(1,0); +thrust::device_ptr ptr = &v[0]; +thrust::device_reference ref(ptr); + +// ref equals 0 +assert(ref == 0); + +// the object pointed to by ptr equals 0 +assert(*ptr == 0); + +// v[0] equals 0 +assert(v[0] == 0); + +// increment ref +int x = ref++; + +// x equals 0 +assert(x == 0) + +// ref equals 1 +assert(ref == 1); + +// the object pointed to by ptr equals 1 +assert(*ptr == 1); + +// v[0] equals 1 +assert(v[0] == 1); +``` + + + + +### operator+= + +Addition assignment operator add-assigns the object referenced by this `device_reference` and returns this `device_reference`. + +The following code snippet demonstrates the semantics of `device_reference``'s` addition assignment operator. + + +```cpp showLineNumbers={false} +device_reference & thrust::device_reference::operator+=( + const T &rhs +) +``` + + + +The add-assignment executes as as if it were executed on the host. This may change in a later version. + + +**Returns:** `*this`. + +**Parameters** + + +The right hand side of the add-assignment. + + +**Example** + +```cpp showLineNumbers={false} +#include +#include +... +thrust::device_vector v(1,0); +thrust::device_ptr ptr = &v[0]; +thrust::device_reference ref(ptr); + +// ref equals 0 +assert(ref == 0); + +// the object pointed to by ptr equals 0 +assert(*ptr == 0); + +// v[0] equals 0 +assert(v[0] == 0); + +// add-assign ref +ref += 5; + +// ref equals 5 +assert(ref == 5); + +// the object pointed to by ptr equals 5 +assert(*ptr == 5); + +// v[0] equals 5 +assert(v[0] == 5); +``` + +### operator-- + + + + +Prefix decrement operator decrements the object referenced by this `device_reference`. + +The following code snippet demonstrates the semantics of `device_reference``'s` prefix decrement operator. + + +```cpp showLineNumbers={false} +device_reference & thrust::device_reference::operator--( + void +) +``` + + + +The decrement executes as if it were executed on the host. This may change in a later version. + + +**Returns:** `*this` + +**Example** + +```cpp showLineNumbers={false} +#include +#include +... +thrust::device_vector v(1,0); +thrust::device_ptr ptr = &v[0]; +thrust::device_reference ref(ptr); + +// ref equals 0 +assert(ref == 0); + +// the object pointed to by ptr equals 0 +assert(*ptr == 0); + +// v[0] equals 0 +assert(v[0] == 0); + +// decrement ref +--ref; + +// ref equals -1 +assert(ref == -1); + +// the object pointed to by ptr equals -1 +assert(*ptr == -1); + +// v[0] equals -1 +assert(v[0] == -1); +``` + + + + +Postfix decrement operator copies the object referenced by this `device_reference`, decrements the object referenced by this `device_reference`, and returns the copy. + +The following code snippet demonstrates the semantics of `device_reference``'s` postfix decrement operator. + + +```cpp showLineNumbers={false} +value_type thrust::device_reference::operator--( + int +) +``` + + + +The decrement executes as if it were executed on the host. This may change in a later version. + + +**Returns:** A copy of the object referenced by this `device_reference` before being decremented. + +**Example** + +```cpp showLineNumbers={false} +#include +#include +... +thrust::device_vector v(1,0); +thrust::device_ptr ptr = &v[0]; +thrust::device_reference ref(ptr); + +// ref equals 0 +assert(ref == 0); + +// the object pointed to by ptr equals 0 +assert(*ptr == 0); + +// v[0] equals 0 +assert(v[0] == 0); + +// decrement ref +int x = ref--; + +// x equals 0 +assert(x == 0) + +// ref equals -1 +assert(ref == -1); + +// the object pointed to by ptr equals -1 +assert(*ptr == -1); + +// v[0] equals -1 +assert(v[0] == -1); +``` + + + + +### operator-= + +Subtraction assignment operator subtract-assigns the object referenced by this `device_reference` and returns this `device_reference`. + +The following code snippet demonstrates the semantics of `device_reference``'s` addition assignment operator. + + +```cpp showLineNumbers={false} +device_reference & thrust::device_reference::operator-=( + const T &rhs +) +``` + + + +The subtract-assignment executes as as if it were executed on the host. This may change in a later version. + + +**Returns:** `*this`. + +**Parameters** + + +The right hand side of the subtraction-assignment. + + +**Example** + +```cpp showLineNumbers={false} +#include +#include +... +thrust::device_vector v(1,0); +thrust::device_ptr ptr = &v[0]; +thrust::device_reference ref(ptr); + +// ref equals 0 +assert(ref == 0); + +// the object pointed to by ptr equals 0 +assert(*ptr == 0); + +// v[0] equals 0 +assert(v[0] == 0); + +// subtract-assign ref +ref -= 5; + +// ref equals -5 +assert(ref == -5); + +// the object pointed to by ptr equals -5 +assert(*ptr == -5); + +// v[0] equals -5 +assert(v[0] == -5); +``` + +### operator*= + +Multiplication assignment operator multiply-assigns the object referenced by this `device_reference` and returns this `device_reference`. + +The following code snippet demonstrates the semantics of `device_reference``'s` multiply assignment operator. + + +```cpp showLineNumbers={false} +device_reference & thrust::device_reference::operator*=( + const T &rhs +) +``` + + + +The multiply-assignment executes as as if it were executed on the host. This may change in a later version. + + +**Returns:** `*this`. + +**Parameters** + + +The right hand side of the multiply-assignment. + + +**Example** + +```cpp showLineNumbers={false} +#include +#include +... +thrust::device_vector v(1,1); +thrust::device_ptr ptr = &v[0]; +thrust::device_reference ref(ptr); + +// ref equals 1 +assert(ref == 1); + +// the object pointed to by ptr equals 1 +assert(*ptr == 1); + +// v[0] equals 1 +assert(v[0] == 1); + +// multiply-assign ref +ref *= 5; + +// ref equals 5 +assert(ref == 5); + +// the object pointed to by ptr equals 5 +assert(*ptr == 5); + +// v[0] equals 5 +assert(v[0] == 5); +``` + +### operator/= + +Division assignment operator divide-assigns the object referenced by this `device_reference` and returns this `device_reference`. + +The following code snippet demonstrates the semantics of `device_reference``'s` divide assignment operator. + + +```cpp showLineNumbers={false} +device_reference & thrust::device_reference::operator/=( + const T &rhs +) +``` + + + +The divide-assignment executes as as if it were executed on the host. This may change in a later version. + + +**Returns:** `*this`. + +**Parameters** + + +The right hand side of the divide-assignment. + + +**Example** + +```cpp showLineNumbers={false} +#include +#include +... +thrust::device_vector v(1,5); +thrust::device_ptr ptr = &v[0]; +thrust::device_reference ref(ptr); + +// ref equals 5 +assert(ref == 5); + +// the object pointed to by ptr equals 5 +assert(*ptr == 5); + +// v[0] equals 5 +assert(v[0] == 5); + +// divide-assign ref +ref /= 5; + +// ref equals 1 +assert(ref == 1); + +// the object pointed to by ptr equals 1 +assert(*ptr == 1); + +// v[0] equals 1 +assert(v[0] == 1); +``` + +### operator%= + +Modulation assignment operator modulus-assigns the object referenced by this `device_reference` and returns this `device_reference`. + +The following code snippet demonstrates the semantics of `device_reference``'s` divide assignment operator. + + +```cpp showLineNumbers={false} +device_reference & thrust::device_reference::operator%=( + const T &rhs +) +``` + + + +The modulus-assignment executes as as if it were executed on the host. This may change in a later version. + + +**Returns:** `*this`. + +**Parameters** + + +The right hand side of the divide-assignment. + + +**Example** + +```cpp showLineNumbers={false} +#include +#include +... +thrust::device_vector v(1,5); +thrust::device_ptr ptr = &v[0]; +thrust::device_reference ref(ptr); + +// ref equals 5 +assert(ref == 5); + +// the object pointed to by ptr equals 5 +assert(*ptr == 5); + +// v[0] equals 5 +assert(v[0] == 5); + +// modulus-assign ref +ref %= 5; + +// ref equals 0 +assert(ref == 0); + +// the object pointed to by ptr equals 0 +assert(*ptr == 0); + +// v[0] equals 0 +assert(v[0] == 0); +``` + +### operator<<= + +Bitwise left shift assignment operator left shift-assigns the object referenced by this `device_reference` and returns this `device_reference`. + +The following code snippet demonstrates the semantics of `device_reference``'s` left shift assignment operator. + + +```cpp showLineNumbers={false} +device_reference & thrust::device_reference::operator<<=( + const T &rhs +) +``` + + + +The left shift-assignment executes as as if it were executed on the host. This may change in a later version. + + +**Returns:** `*this`. + +**Parameters** + + +The right hand side of the left shift-assignment. + + +**Example** + +```cpp showLineNumbers={false} +#include +#include +... +thrust::device_vector v(1,1); +thrust::device_ptr ptr = &v[0]; +thrust::device_reference ref(ptr); + +// ref equals 1 +assert(ref == 1); + +// the object pointed to by ptr equals 1 +assert(*ptr == 1); + +// v[0] equals 1 +assert(v[0] == 1); + +// left shift-assign ref +ref <<= 1; + +// ref equals 2 +assert(ref == 2); + +// the object pointed to by ptr equals 2 +assert(*ptr == 2); + +// v[0] equals 2 +assert(v[0] == 2); +``` + +### operator>>= + +Bitwise right shift assignment operator right shift-assigns the object referenced by this `device_reference` and returns this `device_reference`. + +The following code snippet demonstrates the semantics of `device_reference``'s` right shift assignment operator. + + +```cpp showLineNumbers={false} +device_reference & thrust::device_reference::operator>>=( + const T &rhs +) +``` + + + +The right shift-assignment executes as as if it were executed on the host. This may change in a later version. + + +**Returns:** `*this`. + +**Parameters** + + +The right hand side of the right shift-assignment. + + +**Example** + +```cpp showLineNumbers={false} +#include +#include +... +thrust::device_vector v(1,2); +thrust::device_ptr ptr = &v[0]; +thrust::device_reference ref(ptr); + +// ref equals 2 +assert(ref == 2); + +// the object pointed to by ptr equals 2 +assert(*ptr == 2); + +// v[0] equals 2 +assert(v[0] == 2); + +// right shift-assign ref +ref >>= 1; + +// ref equals 1 +assert(ref == 1); + +// the object pointed to by ptr equals 1 +assert(*ptr == 1); + +// v[0] equals 1 +assert(v[0] == 1); +``` + +### operator&= + +Bitwise AND assignment operator AND-assigns the object referenced by this `device_reference` and returns this `device_reference`. + +The following code snippet demonstrates the semantics of `device_reference``'s` AND assignment operator. + + +```cpp showLineNumbers={false} +device_reference & thrust::device_reference::operator&=( + const T &rhs +) +``` + + + +The AND-assignment executes as as if it were executed on the host. This may change in a later version. + + +**Returns:** `*this`. + +**Parameters** + + +The right hand side of the AND-assignment. + + +**Example** + +```cpp showLineNumbers={false} +#include +#include +... +thrust::device_vector v(1,1); +thrust::device_ptr ptr = &v[0]; +thrust::device_reference ref(ptr); + +// ref equals 1 +assert(ref == 1); + +// the object pointed to by ptr equals 1 +assert(*ptr == 1); + +// v[0] equals 1 +assert(v[0] == 1); + +// right AND-assign ref +ref &= 0; + +// ref equals 0 +assert(ref == 0); + +// the object pointed to by ptr equals 0 +assert(*ptr == 0); + +// v[0] equals 0 +assert(v[0] == 0); +``` + +### operator|= + +Bitwise OR assignment operator OR-assigns the object referenced by this `device_reference` and returns this `device_reference`. + +The following code snippet demonstrates the semantics of `device_reference``'s` OR assignment operator. + + +```cpp showLineNumbers={false} +device_reference & thrust::device_reference::operator|=( + const T &rhs +) +``` + + + +The OR-assignment executes as as if it were executed on the host. This may change in a later version. + + +**Returns:** `*this`. + +**Parameters** + + +The right hand side of the OR-assignment. + + +**Example** + +```cpp showLineNumbers={false} +#include +#include +... +thrust::device_vector v(1,0); +thrust::device_ptr ptr = &v[0]; +thrust::device_reference ref(ptr); + +// ref equals 0 +assert(ref == 0); + +// the object pointed to by ptr equals 0 +assert(*ptr == 0); + +// v[0] equals 0 +assert(v[0] == 0); + +// right OR-assign ref +ref |= 1; + +// ref equals 1 +assert(ref == 1); + +// the object pointed to by ptr equals 1 +assert(*ptr == 1); + +// v[0] equals 1 +assert(v[0] == 1); +``` + +### operator^= + +Bitwise XOR assignment operator XOR-assigns the object referenced by this `device_reference` and returns this `device_reference`. + +The following code snippet demonstrates the semantics of `device_reference``'s` XOR assignment operator. + + +```cpp showLineNumbers={false} +device_reference & thrust::device_reference::operator^=( + const T &rhs +) +``` + + + +The XOR-assignment executes as as if it were executed on the host. This may change in a later version. + + +**Returns:** `*this`. + +**Parameters** + + +The right hand side of the XOR-assignment. + + +**Example** + +```cpp showLineNumbers={false} +#include +#include +... +thrust::device_vector v(1,1); +thrust::device_ptr ptr = &v[0]; +thrust::device_reference ref(ptr); + +// ref equals 1 +assert(ref == 1); + +// the object pointed to by ptr equals 1 +assert(*ptr == 1); + +// v[0] equals 1 +assert(v[0] == 1); + +// right XOR-assign ref +ref ^= 1; + +// ref equals 0 +assert(ref == 0); + +// the object pointed to by ptr equals 0 +assert(*ptr == 0); + +// v[0] equals 0 +assert(v[0] == 0); +``` + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `super_t` | `thrust::reference< T, thrust::device_ptr< T >, thrust::device_reference< T > >` | | +| `value_type` | `typename super_t::value_type` | The type of the value referenced by this type of `device_reference`. | +| `pointer` | `typename super_t::pointer` | The type of the expression `&ref`, where `ref` is a `device_reference`. | diff --git a/fern/cudapages/thrust/thrust/thrust/device_vector.mdx b/fern/cudapages/thrust/thrust/thrust/device_vector.mdx new file mode 100644 index 0000000..de2a4db --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/device_vector.mdx @@ -0,0 +1,1368 @@ +--- +title: thrust::device_vector +description: "A [`device_vector`](/library/api/thrust::device_vector) is a container that supports random access to elements, constant time removal of elements at the end, and linear time insertion and removal of elements at the beginning or in the middle." +--- + +A `device_vector` is a container that supports random access to elements, constant time removal of elements at the end, and linear time insertion and removal of elements at the beginning or in the middle. + +The number of elements in a `device_vector` may vary dynamically; memory management is automatic. The memory associated with a `device_vector` resides in the memory accessible to devices. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +[https://en.cppreference.com/w/cpp/container/vector](https://en.cppreference.com/w/cpp/container/vector), +[device_allocator](/library/api/thrust::device_allocator), +[host_vector](/library/api/thrust::host_vector), +universal_vector + + + + + + + + + + + + + +**Inherits from:** `detail::vector_base< T, thrust::device_allocator< T > >` (public) + +--- + +## Constructors + +### device_vector inline + + + + +This constructor creates an empty `device_vector`. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector() +``` + + + + + +This constructor creates an empty `device_vector`. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + const Alloc &alloc +) +``` + + +**Parameters** + + +The allocator to use by this `device_vector`. + + + + + +explicit + +This constructor creates a `device_vector` with the given size. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + size_type n +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + + + +This constructor creates a `device_vector` with the given size, performing only default-initialization instead of value-initialization. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + size_type n, + default_init_t +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + + + +This constructor creates a `device_vector` with the given size, without initializing elements. + +It mandates that the element type is trivially default-constructible. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + size_type n, + no_init_t +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + + + +explicit + +This constructor creates a `device_vector` with the given size. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + size_type n, + const Alloc &alloc +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + +The allocator to use by this `device_vector`. + + + + + +explicit + +This constructor creates a `device_vector` with copies of an exemplar element. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + size_type n, + const value_type &value +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + +An element to copy. + + + + + +explicit + +This constructor creates a `device_vector` with copies of an exemplar element. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + size_type n, + const value_type &value, + const Alloc &alloc +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + +An element to copy. + + + +The allocator to use by this `device_vector`. + + + + + +Copy constructor copies from an exemplar `device_vector`. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + const device_vector &v +) +``` + + +**Parameters** + + +The `device_vector` to copy. + + + + + +Copy constructor copies from an exemplar `device_vector`. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + const device_vector &v, + const Alloc &alloc +) +``` + + +**Parameters** + + +The `device_vector` to copy. + + + +The allocator to use by this `device_vector`. + + + + + +explicit + +Copy constructor copies from an exemplar `device_vector` with different type. + + +```cpp showLineNumbers={false} +template +thrust::device_vector::device_vector( + const device_vector &v +) +``` + + +**Parameters** + + +The `device_vector` to copy. + + + + + +Copy constructor copies from an exemplar `std::vector`. + + +```cpp showLineNumbers={false} +template +thrust::device_vector::device_vector( + const std::vector &v +) +``` + + +**Parameters** + + +The `std::vector` to copy. + + + + + +Copy construct from a `vector_base` whose element type is convertible to `T`. + + +```cpp showLineNumbers={false} +template +thrust::device_vector::device_vector( + const detail::vector_base &v +) +``` + + +**Parameters** + + +The `vector_base` to copy. + + + + + +Move constructor moves from another `device_vector`. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + device_vector &&v +) +``` + + +**Parameters** + + +The `device_vector` to move. + + + + + +Move constructor moves from another `device_vector`. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + device_vector &&v, + const Alloc &alloc +) +``` + + +**Parameters** + + +The `device_vector` to move. + + + +The allocator to use by this `device_vector`. + + + + + +This constructor builds a `device_vector` from an intializer_list. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + ::cuda::std::initializer_list il +) +``` + + +**Parameters** + + +The intializer_list. + + + + + +This constructor builds a `device_vector` from an intializer_list. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + ::cuda::std::initializer_list il, + const Alloc &alloc +) +``` + + +**Parameters** + + +The intializer_list. + + + +The allocator to use by this `device_vector`. + + + + + +This constructor builds a `device_vector` from a range. + + +```cpp showLineNumbers={false} +template +thrust::device_vector::device_vector( + InputIterator first, + InputIterator last +) +``` + + +**Parameters** + + +The beginning of the range. + + + +The end of the range. + + + + + +This constructor builds a `device_vector` from a range. + + +```cpp showLineNumbers={false} +template +thrust::device_vector::device_vector( + InputIterator first, + InputIterator last, + const Alloc &alloc +) +``` + + +**Parameters** + + +The beginning of the range. + + + +The end of the range. + + + +The allocator to use by this `device_vector`. + + + + + +### Destructor + +### ~device_vector inline + +The destructor erases the elements. + + +```cpp showLineNumbers={false} +thrust::device_vector::~device_vector() +``` + + +--- + +## Assignment operators + +### operator= inline + + + + +Copy assign operator copies another `device_vector` with the same type. + + +```cpp showLineNumbers={false} +device_vector & thrust::device_vector::operator=( + const device_vector &v +) +``` + + +**Parameters** + + +The `device_vector` to copy. + + + + + +Move assign operator moves from another `device_vector`. + + +```cpp showLineNumbers={false} +device_vector & thrust::device_vector::operator=( + device_vector &&v +) +``` + + +**Parameters** + + +The `device_vector` to move. + + + + + +Assign operator copies from an exemplar `device_vector` with different type. + + +```cpp showLineNumbers={false} +template +device_vector & thrust::device_vector::operator=( + const device_vector &v +) +``` + + +**Parameters** + + +The `device_vector` to copy. + + + + + +Assign operator copies from an exemplar `std::vector`. + + +```cpp showLineNumbers={false} +template +device_vector & thrust::device_vector::operator=( + const std::vector &v +) +``` + + +**Parameters** + + +The `std::vector` to copy. + + + + + +Assign a `vector_base` whose element type is convertible to `T`. + + +```cpp showLineNumbers={false} +template +device_vector & thrust::device_vector::operator=( + const detail::vector_base &v +) +``` + + +**Parameters** + + +The `vector_base` to copy. + + + + + +Assign an `intializer_list` with a matching element type. + + +```cpp showLineNumbers={false} +device_vector & thrust::device_vector::operator=( + ::cuda::std::initializer_list il +) +``` + + +**Parameters** + + +The intializer_list. + + + + + +--- + +## Methods + +### resize + + + + +Resizes this vector to the specified number of elements. + + +```cpp showLineNumbers={false} +void thrust::device_vector::resize( + size_type new_size, + const value_type &x = value_type() +) +``` + + +**Throws:** `std::length_error` If n exceeds [max_size()](/library/api/thrust::device_vector::max_size()). + +**Parameters** + + +Number of elements this vector should contain. + + + +Data with which new elements should be populated. + + + + + +Resizes this vector to the specified number of elements, performing default-initialization instead of value-initialization. + + +```cpp showLineNumbers={false} +void thrust::device_vector::resize( + size_type new_size, + default_init_t +) +``` + + +**Throws:** `std::length_error` If n exceeds [max_size()](/library/api/thrust::device_vector::max_size()). + +**Parameters** + + +Number of elements this vector should contain. + + + + + +Resizes this vector_base to the specified number of elements, without initializing elements. + +It mandates that the element type is trivially default-constructible. + + +```cpp showLineNumbers={false} +void thrust::device_vector::resize( + size_type new_size, + no_init_t +) +``` + + +**Throws:** `std::length_error` If n exceeds [max_size()](/library/api/thrust::device_vector::max_size()). + +**Parameters** + + +Number of elements this vector_base should contain. + + + + + +### size const + +Returns the number of elements in this vector. + + +```cpp showLineNumbers={false} +size_type thrust::device_vector::size() const +``` + + +### max_size const + +Returns the [size()](/library/api/thrust::device_vector::size()) of the largest possible vector. + + +```cpp showLineNumbers={false} +size_type thrust::device_vector::max_size() const +``` + + +**Returns:** The largest possible return value of [size()](/library/api/thrust::device_vector::size()). + +### reserve + +If n is less than or equal to [capacity()](/library/api/thrust::device_vector::capacity()), this call has no effect. + + +```cpp showLineNumbers={false} +void thrust::device_vector::reserve( + size_type n +) +``` + + +**Throws:** `std::length_error` If n exceeds [max_size()](/library/api/thrust::device_vector::max_size()). + +### capacity const + +Returns the number of elements which have been reserved in this vector. + + +```cpp showLineNumbers={false} +size_type thrust::device_vector::capacity() const +``` + + +### shrink_to_fit + +This method shrinks the capacity of this vector to exactly fit its elements. + + +```cpp showLineNumbers={false} +void thrust::device_vector::shrink_to_fit() +``` + + +### operator[] + + + + +Subscript access to the data contained in this vector_dev. + +This operator allows for easy, array-style, data access. Note that data access with this operator is unchecked and out_of_range lookups are not defined. + + +```cpp showLineNumbers={false} +reference thrust::device_vector::operator[]( + size_type n +) +``` + + +**Returns:** Read/write reference to data. + +**Parameters** + + +The index of the element for which data should be accessed. + + + + + +const + +Subscript read access to the data contained in this vector_dev. + +This operator allows for easy, array-style, data access. Note that data access with this operator is unchecked and out_of_range lookups are not defined. + + +```cpp showLineNumbers={false} +const_reference thrust::device_vector::operator[]( + size_type n +) const +``` + + +**Returns:** Read reference to data. + +**Parameters** + + +The index of the element for which data should be accessed. + + + + + +### begin + + + + +This method returns an iterator pointing to the beginning of this vector. + + +```cpp showLineNumbers={false} +iterator thrust::device_vector::begin() +``` + + +**Returns:** mStart + + + + +const + +This method returns a const_iterator pointing to the beginning of this vector. + + +```cpp showLineNumbers={false} +const_iterator thrust::device_vector::begin() const +``` + + +**Returns:** mStart + + + + +### cbegin const + +This method returns a const_iterator pointing to the beginning of this vector. + + +```cpp showLineNumbers={false} +const_iterator thrust::device_vector::cbegin() const +``` + + +**Returns:** mStart + +### rbegin + + + + +This method returns a reverse_iterator pointing to the beginning of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +reverse_iterator thrust::device_vector::rbegin() +``` + + +**Returns:** A reverse_iterator pointing to the beginning of this vector's reversed sequence. + + + + +const + +This method returns a const_reverse_iterator pointing to the beginning of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +const_reverse_iterator thrust::device_vector::rbegin() const +``` + + +**Returns:** A const_reverse_iterator pointing to the beginning of this vector's reversed sequence. + + + + +### crbegin const + +This method returns a const_reverse_iterator pointing to the beginning of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +const_reverse_iterator thrust::device_vector::crbegin() const +``` + + +**Returns:** A const_reverse_iterator pointing to the beginning of this vector's reversed sequence. + +### end + + + + +This method returns an iterator pointing to one element past the last of this vector. + + +```cpp showLineNumbers={false} +iterator thrust::device_vector::end() +``` + + +**Returns:** [begin()](/library/api/thrust::device_vector::begin()) + [size()](/library/api/thrust::device_vector::size()). + + + + +const + +This method returns a const_iterator pointing to one element past the last of this vector. + + +```cpp showLineNumbers={false} +const_iterator thrust::device_vector::end() const +``` + + +**Returns:** [begin()](/library/api/thrust::device_vector::begin()) + [size()](/library/api/thrust::device_vector::size()). + + + + +### cend const + +This method returns a const_iterator pointing to one element past the last of this vector. + + +```cpp showLineNumbers={false} +const_iterator thrust::device_vector::cend() const +``` + + +**Returns:** [begin()](/library/api/thrust::device_vector::begin()) + [size()](/library/api/thrust::device_vector::size()). + +### rend + + + + +This method returns a reverse_iterator pointing to one element past the last of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +reverse_iterator thrust::device_vector::rend() +``` + + +**Returns:** [rbegin()](/library/api/thrust::device_vector::rbegin()) + [size()](/library/api/thrust::device_vector::size()). + + + + +const + +This method returns a const_reverse_iterator pointing to one element past the last of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +const_reverse_iterator thrust::device_vector::rend() const +``` + + +**Returns:** [rbegin()](/library/api/thrust::device_vector::rbegin()) + [size()](/library/api/thrust::device_vector::size()). + + + + +### crend const + +This method returns a const_reverse_iterator pointing to one element past the last of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +const_reverse_iterator thrust::device_vector::crend() const +``` + + +**Returns:** [rbegin()](/library/api/thrust::device_vector::rbegin()) + [size()](/library/api/thrust::device_vector::size()). + +### front + + + + +This method returns a reference pointing to the first element of this vector. + + +```cpp showLineNumbers={false} +reference thrust::device_vector::front() +``` + + +**Returns:** The first element of this vector. + + + + +const + +This method returns a const_reference referring to the first element of this vector. + + +```cpp showLineNumbers={false} +const_reference thrust::device_vector::front() const +``` + + +**Returns:** The first element of this vector. + + + + +### back + + + + +This method returns a reference referring to the last element of this vector_dev. + + +```cpp showLineNumbers={false} +reference thrust::device_vector::back() +``` + + +**Returns:** The last element of this vector. + + + + +const + +This method returns a const reference pointing to the last element of this vector. + + +```cpp showLineNumbers={false} +const_reference thrust::device_vector::back() const +``` + + +**Returns:** The last element of this vector. + + + + +### data + + + + +This method returns a pointer to this vector's first element. + + +```cpp showLineNumbers={false} +pointer thrust::device_vector::data() +``` + + +**Returns:** A pointer to the first element of this vector. + + + + +const + +This method returns a const_pointer to this vector's first element. + + +```cpp showLineNumbers={false} +const_pointer thrust::device_vector::data() const +``` + + +**Returns:** a const_pointer to the first element of this vector. + + + + +### clear + +This method resizes this vector to 0. + + +```cpp showLineNumbers={false} +void thrust::device_vector::clear() +``` + + +### empty const + +This method returns true iff [size()](/library/api/thrust::device_vector::size()) == 0. + + +```cpp showLineNumbers={false} +bool thrust::device_vector::empty() const +``` + + +**Returns:** true if [size()](/library/api/thrust::device_vector::size()) == 0; false, otherwise. + +### push_back + +This method appends the given element to the end of this vector. + + +```cpp showLineNumbers={false} +void thrust::device_vector::push_back( + const value_type &x +) +``` + + +**Parameters** + + +The element to append. + + +### pop_back + +This method erases the last element of this vector, invalidating all iterators and references to it. + + +```cpp showLineNumbers={false} +void thrust::device_vector::pop_back() +``` + + +### swap + +This method swaps the contents of this `device_vector` with another vector. + + +```cpp showLineNumbers={false} +void thrust::device_vector::swap( + device_vector &v +) +``` + + +**Parameters** + + +The vector with which to swap. + + +### erase + + + + +This method removes the element at position pos. + + +```cpp showLineNumbers={false} +iterator thrust::device_vector::erase( + iterator pos +) +``` + + +**Returns:** An iterator pointing to the new location of the element that followed the element at position pos. + +**Parameters** + + +The position of the element of interest. + + + + + +This method removes the range of elements [first,last) from this vector. + + +```cpp showLineNumbers={false} +iterator thrust::device_vector::erase( + iterator first, + iterator last +) +``` + + +**Returns:** An iterator pointing to the new location of the element that followed the last element in the sequence [first,last). + +**Parameters** + + +The beginning of the range of elements to remove. + + + +The end of the range of elements to remove. + + + + + +### insert + + + + +This method inserts a single copy of a given exemplar value at the specified position in this vector. + + +```cpp showLineNumbers={false} +iterator thrust::device_vector::insert( + iterator position, + const T &x +) +``` + + +**Returns:** An iterator pointing to the newly inserted element. + +**Parameters** + + +The insertion position. + + + +The exemplar element to copy & insert. + + + + + +This method inserts a copy of an exemplar value to a range at the specified position in this vector. + + +```cpp showLineNumbers={false} +void thrust::device_vector::insert( + iterator position, + size_type n, + const T &x +) +``` + + +**Parameters** + + +The insertion position + + + +The number of insertions to perform. + + + +The value to replicate and insert. + + + + + +This method inserts a copy of an input range at the specified position in this vector. + + +```cpp showLineNumbers={false} +template +void thrust::device_vector::insert( + iterator position, + InputIterator first, + InputIterator last +) +``` + + +**Template parameters** + + +Is a model of [Input Iterator](https://en.cppreference.com/w/cpp/iterator/input_iterator), and `InputIterator's` `value_type` is a model of [Assignable.](https://en.cppreference.com/w/cpp/named_req/CopyAssignable) + + +**Parameters** + + +The insertion position. + + + +The beginning of the range to copy. + + + +The end of the range to copy. + + + + + +### assign + + + + +This version of `assign` replicates a given exemplar `n` times into this vector. + + +```cpp showLineNumbers={false} +void thrust::device_vector::assign( + size_type n, + const T &x +) +``` + + +**Parameters** + + +The number of times to copy `x`. + + + +The exemplar element to replicate. + + + + + +This version of `assign` makes this vector a copy of a given input range. + + +```cpp showLineNumbers={false} +template +void thrust::device_vector::assign( + InputIterator first, + InputIterator last +) +``` + + +**Template parameters** + + +Is a model of [Input Iterator](https://en.cppreference.com/w/cpp/named_req/InputIterator). + + +**Parameters** + + +The beginning of the range to copy. + + + +The end of the range to copy. + + + + + +### get_allocator const + +This method returns a copy of this vector's allocator. + + +```cpp showLineNumbers={false} +allocator_type thrust::device_vector::get_allocator() const +``` + + +**Returns:** A copy of the allocator used by this vector. + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `Parent` | `detail::vector_base< T, Alloc >` | diff --git a/fern/cudapages/thrust/thrust/thrust/discard_block_engine.mdx b/fern/cudapages/thrust/thrust/thrust/discard_block_engine.mdx new file mode 100644 index 0000000..8f4f427 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/discard_block_engine.mdx @@ -0,0 +1,223 @@ +--- +title: thrust::discard_block_engine +description: "A [`discard_block_engine`](/library/api/thrust::discard_block_engine) adapts an existing base random number engine and produces random values by discarding some of the values returned by its base engine." +--- + +A `discard_block_engine` adapts an existing base random number engine and produces random values by discarding some of the values returned by its base engine. + +Each cycle of the compound engine begins by returning `r` values successively produced by the base engine and ends by discarding `p-r` such values. The engine's state is the state of its base engine followed by the number of calls to `operator()` that have occurred since the beginning of the current cycle. + +The following code snippet shows an example of using a `discard_block_engine` instance: + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include +#include +#include + +int main() +{ + // create a discard_block_engine from minstd_rand, with a cycle length of 13 + // keep every first 10 values, and discard the next 3 + thrust::discard_block_engine rng; + + // print a random number to standard output + std::cout << rng() << std::endl; + + return 0; +} +``` + + + + + +The type of the base random number engine to adapt. + + + +The discard cycle length. + + + +The number of values to return of the base engine. Because `p-r` will be discarded, `r <= p`. + + + + + +--- + +## Constructors + +### discard_block_engine + + + + +This constructor constructs a new `discard_block_engine` and constructs its [`base_type`](/library/api/thrust::random::discard_block_engine::base_type) engine using its null constructor. + + +```cpp showLineNumbers={false} +thrust::random::discard_block_engine::discard_block_engine() +``` + + + + + +explicit + +This constructor constructs a new `discard_block_engine` using a given [`base_type`](/library/api/thrust::random::discard_block_engine::base_type) engine to initialize its adapted base engine. + + +```cpp showLineNumbers={false} +thrust::random::discard_block_engine::discard_block_engine( + const base_type &urng +) +``` + + +**Parameters** + + +A [`base_type`](/library/api/thrust::random::discard_block_engine::base_type) to use to initialize this `discard_block_engine``'s` adapted base engine. + + + + + +explicit + +This constructor initializes a new `discard_block_engine` with a given seed. + + +```cpp showLineNumbers={false} +thrust::random::discard_block_engine::discard_block_engine( + result_type s +) +``` + + +**Parameters** + + +The seed used to initialize this `discard_block_engine``'s` adapted base engine. + + + + + +--- + +## Methods + +### seed + + + + +This method initializes the state of this `discard_block_engine``'s` adapted base engine by using its `default_seed` value. + + +```cpp showLineNumbers={false} +void thrust::random::discard_block_engine::seed() +``` + + + + + +This method initializes the state of this `discard_block_engine``'s` adapted base engine by using the given seed. + + +```cpp showLineNumbers={false} +void thrust::random::discard_block_engine::seed( + result_type s +) +``` + + +**Parameters** + + +The seed with which to initialize this `discard_block_engine``'s` adapted base engine. + + + + + +### operator() + +This member function produces a new random value and updates this `discard_block_engine``'s` state. + + +```cpp showLineNumbers={false} +result_type thrust::random::discard_block_engine::operator()( + void +) +``` + + +**Returns:** A new random number. + +### discard + +This member function advances this `discard_block_engine``'s` state a given number of times and discards the results. + + +```cpp showLineNumbers={false} +void thrust::random::discard_block_engine::discard( + unsigned long long z +) +``` + + + +This function is provided because an implementation may be able to accelerate it. + + +**Parameters** + + +The number of random values to discard. + + +### base const + +This member function returns a const reference to this `discard_block_engine``'s` adapted base engine. + + +```cpp showLineNumbers={false} +const base_type & thrust::random::discard_block_engine::base() const +``` + + +**Returns:** A const reference to the base engine this `discard_block_engine` adapts. + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `base_type` | `Engine` | The type of the adapted base random number engine. | +| `result_type` | `typename base_type::result_type` | The type of the unsigned integer produced by this [`linear_congruential_engine`](/library/api/thrust::linear_congruential_engine). | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `block_size` static | `const size_t` | The length of the production cycle. | +| `used_block` static | `const size_t` | The number of used numbers per production cycle. | +| `min` static | `const result_type` | The smallest value this `discard_block_engine` may potentially produce. | +| `max` static | `const result_type` | The largest value this `discard_block_engine` may potentially produce. | diff --git a/fern/cudapages/thrust/thrust/thrust/discard_iterator.mdx b/fern/cudapages/thrust/thrust/thrust/discard_iterator.mdx new file mode 100644 index 0000000..4a65536 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/discard_iterator.mdx @@ -0,0 +1,126 @@ +--- +title: thrust::discard_iterator +description: "[`discard_iterator`](/library/api/thrust::discard_iterator) is an iterator which represents a special kind of pointer that ignores values written to it upon dereference." +--- + +`discard_iterator` is an iterator which represents a special kind of pointer that ignores values written to it upon dereference. + +This iterator is useful for ignoring the output of certain algorithms without wasting memory capacity or bandwidth. `discard_iterator` may also be used to count the size of an algorithm's output which may not be known a priori. + +The following code snippet demonstrates how to use `discard_iterator` to ignore one of the output ranges of reduce_by_key + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +make_discard_iterator + +## Example + +```cpp showLineNumbers={false} +#include +#include +#include + +int main() +{ + thrust::device_vector keys{1, 3, 3, 3, 2, 2, 1}; + thrust::device_vector values{9, 8, 7, 6, 5, 4, 3}; + + thrust::device_vector result(4); + + // we are only interested in the reduced values + // use discard_iterator to ignore the output keys + thrust::reduce_by_key(keys.begin(), keys.end(), + values.begin(), + thrust::make_discard_iterator(), + result.begin()); + + // result is now [9, 21, 9, 3] + + return 0; +} +``` + + + + + + + + + + +**Inherits from:** `thrust::iterator_adaptor< discard_iterator< use_default >, base_iterator, value_type, iterator_system_t< base_iterator >, iterator_traversal_t< base_iterator >, reference >` (public) + +--- + +## Constructors + +### discard_iterator inline + +This constructor receives an optional index specifying the position of this `discard_iterator` in a range. + +`i` The index of this `discard_iterator` in a range. Defaults to the value returned by `Incrementable's` null constructor. For example, when `Incrementable == int`, `0`. + + +```cpp showLineNumbers={false} +thrust::discard_iterator::discard_iterator( + incrementable const &i = incrementable() +) +``` + + +--- + +## Methods + +### base inline const + + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor::base() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + +### base_reference inline + + + + + +```cpp showLineNumbers={false} +Base & thrust::iterator_adaptor::base_reference() +``` + + +**Returns:** A mutable reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +const + + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor::base_reference() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `base_type` | `Base` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)`'s` `adapts`. | diff --git a/fern/cudapages/thrust/thrust/thrust/error_category.mdx b/fern/cudapages/thrust/thrust/thrust/error_category.mdx new file mode 100644 index 0000000..fea4568 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/error_category.mdx @@ -0,0 +1,139 @@ +--- +title: thrust::error_category +description: "The class [`error_category`](/library/api/thrust::error_category) serves as a base class for types used to identify the source and encoding of a particular category of error code." +--- + +The class `error_category` serves as a base class for types used to identify the source and encoding of a particular category of error code. + +Classes may be derived from `error_category` to support categories of errors in addition to those defined in the C++ International Standard. + +```cpp showLineNumbers={false} +#include +``` + +--- + +## Constructors + +### Destructor + +### ~error_category inline virtual + +Destructor does nothing. + + +```cpp showLineNumbers={false} +virtual thrust::system::error_category::~error_category() +``` + + +--- + +## Methods + +### name inline const virtual + + +```cpp showLineNumbers={false} +virtual const char * thrust::system::error_category::name() const +``` + + +**Returns:** A string naming the error category. + +### default_error_condition inline const virtual + + +```cpp showLineNumbers={false} +virtual error_condition thrust::system::error_category::default_error_condition( + int ev +) const +``` + + +**Returns:** `error_condition(ev, *this)`. + +### equivalent inline const virtual + + + + + +```cpp showLineNumbers={false} +virtual bool thrust::system::error_category::equivalent( + int code, + const error_condition &condition +) const +``` + + +**Returns:** `default_error_condition(code) == condition` + + + + + +```cpp showLineNumbers={false} +virtual bool thrust::system::error_category::equivalent( + const error_code &code, + int condition +) const +``` + + +**Returns:** `*this == code.category() && code.value() == condition` + + + + +### message const virtual + + +```cpp showLineNumbers={false} +virtual std::string thrust::system::error_category::message( + int ev +) const +``` + + +**Returns:** A string that describes the error condition denoted by `ev`. + +### operator== inline const + + +```cpp showLineNumbers={false} +bool thrust::system::error_category::operator==( + const error_category &rhs +) const +``` + + +**Returns:** `*this == &rhs` + +### operator!= inline const + + +```cpp showLineNumbers={false} +bool thrust::system::error_category::operator!=( + const error_category &rhs +) const +``` + + +**Returns:** `!(*this == rhs)` + +### operator< inline const + + +```cpp showLineNumbers={false} +bool thrust::system::error_category::operator<( + const error_category &rhs +) const +``` + + + +`less` provides a total ordering for pointers. + + +**Returns:** `less()``(this, &rhs)` diff --git a/fern/cudapages/thrust/thrust/thrust/error_code.mdx b/fern/cudapages/thrust/thrust/thrust/error_code.mdx new file mode 100644 index 0000000..f4ea469 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/error_code.mdx @@ -0,0 +1,175 @@ +--- +title: thrust::error_code +description: "The class [`error_code`](/library/api/thrust::error_code) describes an object used to hold error code values, such as those originating from the operating system or other low-level application program interfaces." +--- + +The class `error_code` describes an object used to hold error code values, such as those originating from the operating system or other low-level application program interfaces. + +```cpp showLineNumbers={false} +#include +``` + +--- + +## Constructors + +### error_code + + + + +inline + +Effects: Constructs an object of type `error_code`. + + +```cpp showLineNumbers={false} +thrust::system::error_code::error_code() +``` + + + +[`value()`](/library/api/thrust::system::error_code::value())` == 0` and [`category()`](/library/api/thrust::system::error_code::category())` == &``system_category()`. + + + + + +inline + +Effects: Constructs an object of type `error_code`. + + +```cpp showLineNumbers={false} +thrust::system::error_code::error_code( + int val, + const error_category &cat +) +``` + + + +[`value()`](/library/api/thrust::system::error_code::value())` == val` and [`category()`](/library/api/thrust::system::error_code::category())` == &cat`. + + + + + +Effects: Constructs an object of type `error_code`. + + +```cpp showLineNumbers={false} +template +thrust::system::error_code::error_code( + ErrorCodeEnum e, + ::cuda::std::enable_if_t::value> * = 0 +) +``` + + + +`*this == make_error_code(e)`. + + + + + +--- + +## Assignment operators + +### operator= + + +```cpp showLineNumbers={false} +template +::cuda::std::enable_if_t::value, error_code> & thrust::system::error_code::operator=( + ErrorCodeEnum e +) +``` + + + +`*this == make_error_code(e)`. + + +--- + +## Methods + +### assign inline + + +```cpp showLineNumbers={false} +void thrust::system::error_code::assign( + int val, + const error_category &cat +) +``` + + + +[`value()`](/library/api/thrust::system::error_code::value())` == val` and [`category()`](/library/api/thrust::system::error_code::category())` == &cat`. + + +### clear inline + + +```cpp showLineNumbers={false} +void thrust::system::error_code::clear() +``` + + + +[`value()`](/library/api/thrust::system::error_code::value())` == 0` and [`category()`](/library/api/thrust::system::error_code::category())` == ``system_category()`. + + +### value inline const + + +```cpp showLineNumbers={false} +int thrust::system::error_code::value() const +``` + + +**Returns:** An integral value of this `error_code` object. + +### category inline const + + +```cpp showLineNumbers={false} +const error_category & thrust::system::error_code::category() const +``` + + +**Returns:** An [`error_category`](/library/api/thrust::error_category) describing the category of this `error_code` object. + +### default_error_condition inline const + + +```cpp showLineNumbers={false} +error_condition thrust::system::error_code::default_error_condition() const +``` + + +**Returns:** [`category()`](/library/api/thrust::system::error_code::category())`.`[`default_error_condition()`](/library/api/thrust::system::error_code::default_error_condition()). + +### message inline const + + +```cpp showLineNumbers={false} +std::string thrust::system::error_code::message() const +``` + + +**Returns:** [`category()`](/library/api/thrust::system::error_code::category())`.message(value())`. + +### operator bool inline const + + +```cpp showLineNumbers={false} +thrust::system::error_code::operator bool() const +``` + + +**Returns:** [`value()`](/library/api/thrust::system::error_code::value())` != 0`. diff --git a/fern/cudapages/thrust/thrust/thrust/error_condition.mdx b/fern/cudapages/thrust/thrust/thrust/error_condition.mdx new file mode 100644 index 0000000..ff33249 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/error_condition.mdx @@ -0,0 +1,211 @@ +--- +title: thrust::error_condition +description: "The class [`error_condition`](/library/api/thrust::error_condition) describes an object used to hold values identifying error conditions." +--- + +The class `error_condition` describes an object used to hold values identifying error conditions. + +```cpp showLineNumbers={false} +#include +``` + + +`error_condition` values are portable abstractions, while [`error_code`](/library/api/thrust::error_code) values are implementation specific. + + +--- + +## Constructors + +### error_condition + + + + +inline + +Constructs an object of type `error_condition`. + + +```cpp showLineNumbers={false} +thrust::system::error_condition::error_condition() +``` + + + +[`value()`](/library/api/thrust::system::error_condition::value())` == 0`. + + + +[`category()`](/library/api/thrust::system::error_condition::category())` == ``generic_category()`. + + + + + +inline + +Constructs an object of type `error_condition`. + + +```cpp showLineNumbers={false} +thrust::system::error_condition::error_condition( + int val, + const error_category &cat +) +``` + + + +[`value()`](/library/api/thrust::system::error_condition::value())` == val`. + + + +[`category()`](/library/api/thrust::system::error_condition::category())` == cat`. + + + + + +Constructs an object of type `error_condition`. + + +```cpp showLineNumbers={false} +template +thrust::system::error_condition::error_condition( + ErrorConditionEnum e, + ::cuda::std::enable_if_t::value> * = 0 +) +``` + + + +This constructor shall not participate in overload resolution unless `is_error_condition_enum::value` is `true`. + + + +`*this == make_error_condition(e)`. + + + + + +--- + +## Assignment operators + +### operator= + +Assigns to this [`error_code`](/library/api/thrust::error_code) object from an error condition enumeration. + + +```cpp showLineNumbers={false} +template +::cuda::std::enable_if_t::value, error_condition> & thrust::system::error_condition::operator=( + ErrorConditionEnum e +) +``` + + + +This operator shall not participate in overload resolution unless `is_error_condition_enum::value` is `true`. + + + +`*this == make_error_condition(e)`. + + +**Returns:** *this + +--- + +## Methods + +### assign inline + +Assigns to this [`error_code`](/library/api/thrust::error_code) object from an error value and an [`error_category`](/library/api/thrust::error_category). + + +```cpp showLineNumbers={false} +void thrust::system::error_condition::assign( + int val, + const error_category &cat +) +``` + + + +[`value()`](/library/api/thrust::system::error_condition::value())` == val`. + + + +[`category()`](/library/api/thrust::system::error_condition::category())` == cat`. + + +**Parameters** + + +The new value to return from [`value()`](/library/api/thrust::system::error_condition::value()). + + + +The new [`error_category`](/library/api/thrust::error_category) to return from [`category()`](/library/api/thrust::system::error_condition::category()). + + +### clear inline + +Clears this [`error_code`](/library/api/thrust::error_code) object. + + +```cpp showLineNumbers={false} +void thrust::system::error_condition::clear() +``` + + + +[`value`](/library/api/thrust::system::error_condition::value)` == 0` + + + +[`category()`](/library/api/thrust::system::error_condition::category())` == ``generic_category()`. + + +### value inline const + + +```cpp showLineNumbers={false} +int thrust::system::error_condition::value() const +``` + + +**Returns:** The value encoded by this `error_condition`. + +### category inline const + + +```cpp showLineNumbers={false} +const error_category & thrust::system::error_condition::category() const +``` + + +**Returns:** A `const` reference to the [`error_category`](/library/api/thrust::error_category) encoded by this `error_condition`. + +### message inline const + + +```cpp showLineNumbers={false} +std::string thrust::system::error_condition::message() const +``` + + +**Returns:** [`category()`](/library/api/thrust::system::error_condition::category())`.message(value())`. + +### operator bool inline const + + +```cpp showLineNumbers={false} +thrust::system::error_condition::operator bool() const +``` + + +**Returns:** [`value()`](/library/api/thrust::system::error_condition::value())` != 0`. diff --git a/fern/cudapages/thrust/thrust/thrust/forward_device_iterator_tag.mdx b/fern/cudapages/thrust/thrust/thrust/forward_device_iterator_tag.mdx new file mode 100644 index 0000000..f2ae0c6 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/forward_device_iterator_tag.mdx @@ -0,0 +1,17 @@ +--- +title: thrust::forward_device_iterator_tag +description: "[`forward_device_iterator_tag`](/library/api/thrust::forward_device_iterator_tag) is an empty class: it has no member functions, member variables, or nested types." +--- + +`forward_device_iterator_tag` is an empty class: it has no member functions, member variables, or nested types. + +It is used solely as a "tag": a representation of the Forward Device Iterator concept within the C++ type system. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +[https://en.cppreference.com/w/cpp/iterator/iterator_tags](https://en.cppreference.com/w/cpp/iterator/iterator_tags) iterator_traits, [input_device_iterator_tag](/library/api/thrust::input_device_iterator_tag), [output_device_iterator_tag](/library/api/thrust::output_device_iterator_tag), [bidirectional_device_iterator_tag](/library/api/thrust::bidirectional_device_iterator_tag), [random_access_device_iterator_tag](/library/api/thrust::random_access_device_iterator_tag), input_host_iterator_tag, output_host_iterator_tag, forward_host_iterator_tag, bidirectional_host_iterator_tag, random_access_host_iterator_tag + +**Inherits from:** `detail::iterator_category_with_system_and_traversal<::cuda::std::forward_iterator_tag, device_system_tag, forward_traversal_tag >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/forward_traversal_tag.mdx b/fern/cudapages/thrust/thrust/thrust/forward_traversal_tag.mdx new file mode 100644 index 0000000..f73a357 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/forward_traversal_tag.mdx @@ -0,0 +1,12 @@ +--- +title: thrust::forward_traversal_tag +description: "Tag type for iterators allowing forward traversal." +--- + +Tag type for iterators allowing forward traversal. + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `thrust::single_pass_traversal_tag` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/host_execution_policy.mdx b/fern/cudapages/thrust/thrust/thrust/host_execution_policy.mdx new file mode 100644 index 0000000..4bcc311 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/host_execution_policy.mdx @@ -0,0 +1,72 @@ +--- +title: thrust::host_execution_policy +description: "[`host_execution_policy`](/library/api/thrust::host_execution_policy) is the base class for all Thrust parallel execution policies which are derived from Thrust's default host backend system configured with the `THRUST_HOST_SYSTEM` macro." +--- + +`host_execution_policy` is the base class for all Thrust parallel execution policies which are derived from Thrust's default host backend system configured with the `THRUST_HOST_SYSTEM` macro. + +Custom user-defined backends which wish to inherit the functionality of Thrust's host backend system should derive a policy from this type in order to interoperate with Thrust algorithm dispatch. + +The following code snippet demonstrates how to derive a standalone custom execution policy from `thrust::host_execution_policy` to implement a backend which specializes `for_each` while inheriting the behavior of every other algorithm from the host system: + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +execution_policy, +[device_execution_policy](/library/api/thrust::device_execution_policy) + +## Example + +```cpp showLineNumbers={false} +#include +#include + +// define a type derived from thrust::host_execution_policy to distinguish our custom execution policy: +struct my_policy : thrust::host_execution_policy {}; + +// overload for_each on my_policy +template +Iterator for_each(my_policy, Iterator first, Iterator last, Function f) +{ + std::cout << "Hello, world from for_each(my_policy)!" << std::endl; + + for(; first < last; ++first) + { + f(*first); + } + + return first; +} + +struct ignore_argument +{ + void operator()(int) {} +}; + +int main() +{ + int data[4]; + + // dispatch thrust::for_each using our custom policy: + my_policy exec; + thrust::for_each(exec, data, data + 4, ignore_argument()); + + // dispatch thrust::transform whose behavior our policy inherits + thrust::transform(exec, data, data, + 4, data, ::cuda::std::identity{}); + + return 0; +} +``` + + + + + + + + + + +**Inherits from:** `thrust::system::__THRUST_HOST_SYSTEM_NAMESPACE::execution_policy< DerivedPolicy >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/host_vector.mdx b/fern/cudapages/thrust/thrust/thrust/host_vector.mdx new file mode 100644 index 0000000..b86faaa --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/host_vector.mdx @@ -0,0 +1,1365 @@ +--- +title: thrust::host_vector +description: "A [`host_vector`](/library/api/thrust::host_vector) is a container that supports random access to elements, constant time removal of elements at the end, and linear time insertion and removal of elements at the beginning or in the middle." +--- + +A `host_vector` is a container that supports random access to elements, constant time removal of elements at the end, and linear time insertion and removal of elements at the beginning or in the middle. + +The number of elements in a `host_vector` may vary dynamically; memory management is automatic. The memory associated with a `host_vector` resides in memory accessible to hosts. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +[https://en.cppreference.com/w/cpp/container/vector](https://en.cppreference.com/w/cpp/container/vector), +[device_vector](/library/api/thrust::device_vector), +universal_vector + + + + + + + + + + + + + +**Inherits from:** `detail::vector_base< T, std::allocator< T > >` (public) + +--- + +## Constructors + +### host_vector inline + + + + +This constructor creates an empty `host_vector`. + + +```cpp showLineNumbers={false} +thrust::host_vector::host_vector() +``` + + + + + +This constructor creates an empty `host_vector`. + + +```cpp showLineNumbers={false} +thrust::host_vector::host_vector( + const Alloc &alloc +) +``` + + +**Parameters** + + +The allocator to use by this `host_vector`. + + + + + +explicit + +This constructor creates a `host_vector` with the given size. + + +```cpp showLineNumbers={false} +thrust::host_vector::host_vector( + size_type n +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + + + +This constructor creates a `host_vector` with the given size, performing only default-initialization instead of value-initialization. + + +```cpp showLineNumbers={false} +thrust::host_vector::host_vector( + size_type n, + default_init_t +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + + + +This constructor creates a `host_vector` with the given size, without initializing elements. + +It mandates that the element type is trivially default-constructible. + + +```cpp showLineNumbers={false} +thrust::host_vector::host_vector( + size_type n, + no_init_t +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + + + +explicit + +This constructor creates a `host_vector` with the given size. + + +```cpp showLineNumbers={false} +thrust::host_vector::host_vector( + size_type n, + const Alloc &alloc +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + +The allocator to use by this `host_vector`. + + + + + +explicit + +This constructor creates a `host_vector` with copies of an exemplar element. + + +```cpp showLineNumbers={false} +thrust::host_vector::host_vector( + size_type n, + const value_type &value +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + +An element to copy. + + + + + +explicit + +This constructor creates a `host_vector` with copies of an exemplar element. + + +```cpp showLineNumbers={false} +thrust::host_vector::host_vector( + size_type n, + const value_type &value, + const Alloc &alloc +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + +An element to copy. + + + +The allocator to use by this `host_vector`. + + + + + +Copy constructor copies from an exemplar `host_vector`. + + +```cpp showLineNumbers={false} +thrust::host_vector::host_vector( + const host_vector &v +) +``` + + +**Parameters** + + +The `host_vector` to copy. + + + + + +Copy constructor copies from an exemplar `host_vector`. + + +```cpp showLineNumbers={false} +thrust::host_vector::host_vector( + const host_vector &v, + const Alloc &alloc +) +``` + + +**Parameters** + + +The `host_vector` to copy. + + + +The allocator to use by this `host_vector`. + + + + + +Copy constructor copies from an exemplar `host_vector` with different type. + + +```cpp showLineNumbers={false} +template +thrust::host_vector::host_vector( + const host_vector &v +) +``` + + +**Parameters** + + +The `host_vector` to copy. + + + + + +Copy constructor copies from an exemplar `std::vector`. + + +```cpp showLineNumbers={false} +template +thrust::host_vector::host_vector( + const std::vector &v +) +``` + + +**Parameters** + + +The `std::vector` to copy. + + + + + +Copy construct from a `vector_base` whose element type is convertible to `T`. + + +```cpp showLineNumbers={false} +template +thrust::host_vector::host_vector( + const detail::vector_base &v +) +``` + + +**Parameters** + + +The `vector_base` to copy. + + + + + +Move constructor moves from another `host_vector`. + + +```cpp showLineNumbers={false} +thrust::host_vector::host_vector( + host_vector &&v +) +``` + + +**Parameters** + + +The `host_vector` to move. + + + + + +Move constructor moves from another `host_vector`. + + +```cpp showLineNumbers={false} +thrust::host_vector::host_vector( + host_vector &&v, + const Alloc &alloc +) +``` + + +**Parameters** + + +The `host_vector` to move. + + + +The allocator to use by this `host_vector`. + + + + + +This constructor builds a `host_vector` from an intializer_list. + + +```cpp showLineNumbers={false} +thrust::host_vector::host_vector( + ::cuda::std::initializer_list il +) +``` + + +**Parameters** + + +The intializer_list. + + + + + +This constructor builds a `host_vector` from an intializer_list. + + +```cpp showLineNumbers={false} +thrust::host_vector::host_vector( + ::cuda::std::initializer_list il, + const Alloc &alloc +) +``` + + +**Parameters** + + +The intializer_list. + + + +The allocator to use by this `host_vector`. + + + + + +This constructor builds a `host_vector` from a range. + + +```cpp showLineNumbers={false} +template +thrust::host_vector::host_vector( + InputIterator first, + InputIterator last +) +``` + + +**Parameters** + + +The beginning of the range. + + + +The end of the range. + + + + + +This constructor builds a `host_vector` from a range. + + +```cpp showLineNumbers={false} +template +thrust::host_vector::host_vector( + InputIterator first, + InputIterator last, + const Alloc &alloc +) +``` + + +**Parameters** + + +The beginning of the range. + + + +The end of the range. + + + +The allocator to use by this `host_vector`. + + + + + +### Destructor + +### ~host_vector inline + +The destructor erases the elements. + + +```cpp showLineNumbers={false} +thrust::host_vector::~host_vector() +``` + + +--- + +## Assignment operators + +### operator= inline + + + + +Assign operator copies from an exemplar `host_vector`. + + +```cpp showLineNumbers={false} +host_vector & thrust::host_vector::operator=( + const host_vector &v +) +``` + + +**Parameters** + + +The `host_vector` to copy. + + + + + +Move assign operator moves from another `host_vector`. + + +```cpp showLineNumbers={false} +host_vector & thrust::host_vector::operator=( + host_vector &&v +) +``` + + +**Parameters** + + +The `host_vector` to move. + + + + + +Assign operator copies from an exemplar `host_vector` with different type. + + +```cpp showLineNumbers={false} +template +host_vector & thrust::host_vector::operator=( + const host_vector &v +) +``` + + +**Parameters** + + +The `host_vector` to copy. + + + + + +Assign operator copies from an exemplar `std::vector`. + + +```cpp showLineNumbers={false} +template +host_vector & thrust::host_vector::operator=( + const std::vector &v +) +``` + + +**Parameters** + + +The `std::vector` to copy. + + + + + +Assign a `vector_base` whose element type is convertible to `T`. + + +```cpp showLineNumbers={false} +template +host_vector & thrust::host_vector::operator=( + const detail::vector_base &v +) +``` + + +**Parameters** + + +The `vector_base` to copy. + + + + + +Assign an `intializer_list` with a matching element type. + + +```cpp showLineNumbers={false} +host_vector & thrust::host_vector::operator=( + ::cuda::std::initializer_list il +) +``` + + +**Parameters** + + +The intializer_list. + + + + + +--- + +## Methods + +### resize + + + + +Resizes this vector to the specified number of elements. + + +```cpp showLineNumbers={false} +void thrust::host_vector::resize( + size_type new_size, + const value_type &x = value_type() +) +``` + + +**Throws:** `std::length_error` If n exceeds [max_size()](/library/api/thrust::host_vector::max_size()). + +**Parameters** + + +Number of elements this vector should contain. + + + +Data with which new elements should be populated. + + + + + +Resizes this vector to the specified number of elements, performing default-initialization instead of value-initialization. + + +```cpp showLineNumbers={false} +void thrust::host_vector::resize( + size_type new_size, + default_init_t +) +``` + + +**Throws:** `std::length_error` If n exceeds [max_size()](/library/api/thrust::host_vector::max_size()). + +**Parameters** + + +Number of elements this vector should contain. + + + + + +Resizes this vector to the specified number of elements, without initializing elements. + +It mandates that the element type is trivially default-constructible. + + +```cpp showLineNumbers={false} +void thrust::host_vector::resize( + size_type new_size, + no_init_t +) +``` + + +**Throws:** `std::length_error` If n exceeds [max_size()](/library/api/thrust::host_vector::max_size()). + +**Parameters** + + +Number of elements this vector should contain. + + + + + +### size const + +Returns the number of elements in this vector. + + +```cpp showLineNumbers={false} +size_type thrust::host_vector::size() const +``` + + +### max_size const + +Returns the [size()](/library/api/thrust::host_vector::size()) of the largest possible vector. + + +```cpp showLineNumbers={false} +size_type thrust::host_vector::max_size() const +``` + + +**Returns:** The largest possible return value of [size()](/library/api/thrust::host_vector::size()). + +### reserve + +If n is less than or equal to [capacity()](/library/api/thrust::host_vector::capacity()), this call has no effect. + + +```cpp showLineNumbers={false} +void thrust::host_vector::reserve( + size_type n +) +``` + + +**Throws:** `std::length_error` If n exceeds [max_size()](/library/api/thrust::host_vector::max_size()). + +### capacity const + +Returns the number of elements which have been reserved in this vector. + + +```cpp showLineNumbers={false} +size_type thrust::host_vector::capacity() const +``` + + +### shrink_to_fit + +This method shrinks the capacity of this vector to exactly fit its elements. + + +```cpp showLineNumbers={false} +void thrust::host_vector::shrink_to_fit() +``` + + +### operator[] + + + + +Subscript access to the data contained in this vector_dev. + +This operator allows for easy, array-style, data access. Note that data access with this operator is unchecked and out_of_range lookups are not defined. + + +```cpp showLineNumbers={false} +reference thrust::host_vector::operator[]( + size_type n +) +``` + + +**Returns:** Read/write reference to data. + +**Parameters** + + +The index of the element for which data should be accessed. + + + + + +const + +Subscript read access to the data contained in this vector_dev. + +This operator allows for easy, array-style, data access. Note that data access with this operator is unchecked and out_of_range lookups are not defined. + + +```cpp showLineNumbers={false} +const_reference thrust::host_vector::operator[]( + size_type n +) const +``` + + +**Returns:** Read reference to data. + +**Parameters** + + +The index of the element for which data should be accessed. + + + + + +### begin + + + + +This method returns an iterator pointing to the beginning of this vector. + + +```cpp showLineNumbers={false} +iterator thrust::host_vector::begin() +``` + + +**Returns:** mStart + + + + +const + +This method returns a const_iterator pointing to the beginning of this vector. + + +```cpp showLineNumbers={false} +const_iterator thrust::host_vector::begin() const +``` + + +**Returns:** mStart + + + + +### cbegin const + +This method returns a const_iterator pointing to the beginning of this vector. + + +```cpp showLineNumbers={false} +const_iterator thrust::host_vector::cbegin() const +``` + + +**Returns:** mStart + +### rbegin + + + + +This method returns a reverse_iterator pointing to the beginning of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +reverse_iterator thrust::host_vector::rbegin() +``` + + +**Returns:** A reverse_iterator pointing to the beginning of this vector's reversed sequence. + + + + +const + +This method returns a const_reverse_iterator pointing to the beginning of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +const_reverse_iterator thrust::host_vector::rbegin() const +``` + + +**Returns:** A const_reverse_iterator pointing to the beginning of this vector's reversed sequence. + + + + +### crbegin const + +This method returns a const_reverse_iterator pointing to the beginning of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +const_reverse_iterator thrust::host_vector::crbegin() const +``` + + +**Returns:** A const_reverse_iterator pointing to the beginning of this vector's reversed sequence. + +### end + + + + +This method returns an iterator pointing to one element past the last of this vector. + + +```cpp showLineNumbers={false} +iterator thrust::host_vector::end() +``` + + +**Returns:** [begin()](/library/api/thrust::host_vector::begin()) + [size()](/library/api/thrust::host_vector::size()). + + + + +const + +This method returns a const_iterator pointing to one element past the last of this vector. + + +```cpp showLineNumbers={false} +const_iterator thrust::host_vector::end() const +``` + + +**Returns:** [begin()](/library/api/thrust::host_vector::begin()) + [size()](/library/api/thrust::host_vector::size()). + + + + +### cend const + +This method returns a const_iterator pointing to one element past the last of this vector. + + +```cpp showLineNumbers={false} +const_iterator thrust::host_vector::cend() const +``` + + +**Returns:** [begin()](/library/api/thrust::host_vector::begin()) + [size()](/library/api/thrust::host_vector::size()). + +### rend + + + + +This method returns a reverse_iterator pointing to one element past the last of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +reverse_iterator thrust::host_vector::rend() +``` + + +**Returns:** [rbegin()](/library/api/thrust::host_vector::rbegin()) + [size()](/library/api/thrust::host_vector::size()). + + + + +const + +This method returns a const_reverse_iterator pointing to one element past the last of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +const_reverse_iterator thrust::host_vector::rend() const +``` + + +**Returns:** [rbegin()](/library/api/thrust::host_vector::rbegin()) + [size()](/library/api/thrust::host_vector::size()). + + + + +### crend const + +This method returns a const_reverse_iterator pointing to one element past the last of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +const_reverse_iterator thrust::host_vector::crend() const +``` + + +**Returns:** [rbegin()](/library/api/thrust::host_vector::rbegin()) + [size()](/library/api/thrust::host_vector::size()). + +### front + + + + +This method returns a reference pointing to the first element of this vector. + + +```cpp showLineNumbers={false} +reference thrust::host_vector::front() +``` + + +**Returns:** The first element of this vector. + + + + +const + +This method returns a const_reference referring to the first element of this vector. + + +```cpp showLineNumbers={false} +const_reference thrust::host_vector::front() const +``` + + +**Returns:** The first element of this vector. + + + + +### back + + + + +This method returns a reference referring to the last element of this vector_dev. + + +```cpp showLineNumbers={false} +reference thrust::host_vector::back() +``` + + +**Returns:** The last element of this vector. + + + + +const + +This method returns a const reference pointing to the last element of this vector. + + +```cpp showLineNumbers={false} +const_reference thrust::host_vector::back() const +``` + + +**Returns:** The last element of this vector. + + + + +### data + + + + +This method returns a pointer to this vector's first element. + + +```cpp showLineNumbers={false} +pointer thrust::host_vector::data() +``` + + +**Returns:** A pointer to the first element of this vector. + + + + +const + +This method returns a const_pointer to this vector's first element. + + +```cpp showLineNumbers={false} +const_pointer thrust::host_vector::data() const +``` + + +**Returns:** a const_pointer to the first element of this vector. + + + + +### clear + +This method resizes this vector to 0. + + +```cpp showLineNumbers={false} +void thrust::host_vector::clear() +``` + + +### empty const + +This method returns true iff [size()](/library/api/thrust::host_vector::size()) == 0. + + +```cpp showLineNumbers={false} +bool thrust::host_vector::empty() const +``` + + +**Returns:** true if [size()](/library/api/thrust::host_vector::size()) == 0; false, otherwise. + +### push_back + +This method appends the given element to the end of this vector. + + +```cpp showLineNumbers={false} +void thrust::host_vector::push_back( + const value_type &x +) +``` + + +**Parameters** + + +The element to append. + + +### pop_back + +This method erases the last element of this vector, invalidating all iterators and references to it. + + +```cpp showLineNumbers={false} +void thrust::host_vector::pop_back() +``` + + +### swap + +This method swaps the contents of this `host_vector` with another vector. + + +```cpp showLineNumbers={false} +void thrust::host_vector::swap( + host_vector &v +) +``` + + +**Parameters** + + +The vector with which to swap. + + +### erase + + + + +This method removes the element at position pos. + + +```cpp showLineNumbers={false} +iterator thrust::host_vector::erase( + iterator pos +) +``` + + +**Returns:** An iterator pointing to the new location of the element that followed the element at position pos. + +**Parameters** + + +The position of the element of interest. + + + + + +This method removes the range of elements [first,last) from this vector. + + +```cpp showLineNumbers={false} +iterator thrust::host_vector::erase( + iterator first, + iterator last +) +``` + + +**Returns:** An iterator pointing to the new location of the element that followed the last element in the sequence [first,last). + +**Parameters** + + +The beginning of the range of elements to remove. + + + +The end of the range of elements to remove. + + + + + +### insert + + + + +This method inserts a single copy of a given exemplar value at the specified position in this vector. + + +```cpp showLineNumbers={false} +iterator thrust::host_vector::insert( + iterator position, + const T &x +) +``` + + +**Returns:** An iterator pointing to the newly inserted element. + +**Parameters** + + +The insertion position. + + + +The exemplar element to copy & insert. + + + + + +This method inserts a copy of an exemplar value to a range at the specified position in this vector. + + +```cpp showLineNumbers={false} +void thrust::host_vector::insert( + iterator position, + size_type n, + const T &x +) +``` + + +**Parameters** + + +The insertion position + + + +The number of insertions to perform. + + + +The value to replicate and insert. + + + + + +This method inserts a copy of an input range at the specified position in this vector. + + +```cpp showLineNumbers={false} +template +void thrust::host_vector::insert( + iterator position, + InputIterator first, + InputIterator last +) +``` + + +**Template parameters** + + +Is a model of [Input Iterator](https://en.cppreference.com/w/cpp/iterator/input_iterator), and `InputIterator's` `value_type` is a model of [Assignable.](https://en.cppreference.com/w/cpp/named_req/CopyAssignable) + + +**Parameters** + + +The insertion position. + + + +The beginning of the range to copy. + + + +The end of the range to copy. + + + + + +### assign + + + + +This version of `assign` replicates a given exemplar `n` times into this vector. + + +```cpp showLineNumbers={false} +void thrust::host_vector::assign( + size_type n, + const T &x +) +``` + + +**Parameters** + + +The number of times to copy `x`. + + + +The exemplar element to replicate. + + + + + +This version of `assign` makes this vector a copy of a given input range. + + +```cpp showLineNumbers={false} +template +void thrust::host_vector::assign( + InputIterator first, + InputIterator last +) +``` + + +**Template parameters** + + +Is a model of [Input Iterator](https://en.cppreference.com/w/cpp/named_req/InputIterator). + + +**Parameters** + + +The beginning of the range to copy. + + + +The end of the range to copy. + + + + + +### get_allocator const + +This method returns a copy of this vector's allocator. + + +```cpp showLineNumbers={false} +allocator_type thrust::host_vector::get_allocator() const +``` + + +**Returns:** A copy of the allocator used by this vector. + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `Parent` | `detail::vector_base< T, Alloc >` | diff --git a/fern/cudapages/thrust/thrust/thrust/incrementable_traversal_tag.mdx b/fern/cudapages/thrust/thrust/thrust/incrementable_traversal_tag.mdx new file mode 100644 index 0000000..d98cea2 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/incrementable_traversal_tag.mdx @@ -0,0 +1,12 @@ +--- +title: thrust::incrementable_traversal_tag +description: "Tag type for iterators allowing incrementable traversal." +--- + +Tag type for iterators allowing incrementable traversal. + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `thrust::no_traversal_tag` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/input_device_iterator_tag.mdx b/fern/cudapages/thrust/thrust/thrust/input_device_iterator_tag.mdx new file mode 100644 index 0000000..3dbd8a2 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/input_device_iterator_tag.mdx @@ -0,0 +1,17 @@ +--- +title: thrust::input_device_iterator_tag +description: "[`input_device_iterator_tag`](/library/api/thrust::input_device_iterator_tag) is an empty class: it has no member functions, member variables, or nested types." +--- + +`input_device_iterator_tag` is an empty class: it has no member functions, member variables, or nested types. + +It is used solely as a "tag": a representation of the Input Device Iterator concept within the C++ type system. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +[https://en.cppreference.com/w/cpp/iterator/iterator_tags](https://en.cppreference.com/w/cpp/iterator/iterator_tags) iterator_traits, [output_device_iterator_tag](/library/api/thrust::output_device_iterator_tag), [forward_device_iterator_tag](/library/api/thrust::forward_device_iterator_tag), [bidirectional_device_iterator_tag](/library/api/thrust::bidirectional_device_iterator_tag), [random_access_device_iterator_tag](/library/api/thrust::random_access_device_iterator_tag), input_host_iterator_tag, output_host_iterator_tag, forward_host_iterator_tag, bidirectional_host_iterator_tag, random_access_host_iterator_tag + +**Inherits from:** `detail::iterator_category_with_system_and_traversal<::cuda::std::input_iterator_tag, device_system_tag, single_pass_traversal_tag >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/is_error_code_enum.mdx b/fern/cudapages/thrust/thrust/thrust/is_error_code_enum.mdx new file mode 100644 index 0000000..dad9eba --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/is_error_code_enum.mdx @@ -0,0 +1,21 @@ +--- +title: thrust::is_error_code_enum +description: "A metafunction returning whether or not the parameter is an [`error_code`](/library/api/thrust::error_code) enum." +--- + +A metafunction returning whether or not the parameter is an [`error_code`](/library/api/thrust::error_code) enum. + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +**Inherits from:** `thrust::detail::false_type` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/is_error_condition_enum.mdx b/fern/cudapages/thrust/thrust/thrust/is_error_condition_enum.mdx new file mode 100644 index 0000000..45126d4 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/is_error_condition_enum.mdx @@ -0,0 +1,21 @@ +--- +title: thrust::is_error_condition_enum +description: "A metafunction returning whether or not the parameter is an [`error_condition`](/library/api/thrust::error_condition) enum." +--- + +A metafunction returning whether or not the parameter is an [`error_condition`](/library/api/thrust::error_condition) enum. + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +**Inherits from:** `thrust::detail::false_type` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_adaptor.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_adaptor.mdx new file mode 100644 index 0000000..43cc5c8 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_adaptor.mdx @@ -0,0 +1,182 @@ +--- +title: thrust::iterator_adaptor +description: "[`iterator_adaptor`](/library/api/thrust::iterator_adaptor) is an iterator which adapts an existing type of iterator to create a new type of iterator." +--- + +`iterator_adaptor` is an iterator which adapts an existing type of iterator to create a new type of iterator. + +Most of Thrust's fancy iterators are defined via inheritance from `iterator_adaptor`. While composition of these existing Thrust iterators is often sufficient for expressing the desired functionality, it is occasionally more straightforward to derive from `iterator_adaptor` directly. + +To see how to use `iterator_adaptor` to create a novel iterator type, let's examine how to use it to define `repeat_iterator`, a fancy iterator which repeats elements from another range a given number of time: + +Except for the first two, `iterator_adaptor``'s` template parameters are optional. When omitted, or when the user specifies `thrust::use_default` in its place, `iterator_adaptor` will use a default type inferred from `Base`. + +`iterator_adaptor``'s` functionality is derived from and generally equivalent to `boost::iterator_adaptor`. The exception is Thrust's addition of the template parameter `System`, which is necessary to allow Thrust to dispatch an algorithm to one of several parallel backend systems. + +`iterator_adaptor` is a powerful tool for creating custom iterators directly. However, the large set of iterator semantics which must be satisfied for algorithm compatibility can make `iterator_adaptor` difficult to use correctly. Unless you require the full expressivity of `iterator_adaptor`, consider building a custom iterator through composition of existing higher-level fancy iterators instead. + +Interested users may refer to `boost::iterator_adaptor`'s documentation for further usage examples. + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include + +// derive repeat_iterator from iterator_adaptor +template + class repeat_iterator + : public thrust::iterator_adaptor< + repeat_iterator, // the first template parameter is the name of the iterator we're creating + Iterator // the second template parameter is the name of the iterator we're adapting + // we can use the default for the additional template parameters + > +{ + public: + // shorthand for the name of the iterator_adaptor we're deriving from + using super_t = thrust::iterator_adaptor< + repeat_iterator, + Iterator + >; + + __host__ __device__ + repeat_iterator(const Iterator &x, int n) : super_t(x), begin(x), n(n) {} + + // befriend thrust::iterator_core_access to allow it access to the private interface below + friend class thrust::iterator_core_access; + + private: + // repeat each element of the adapted range n times + unsigned int n; + + // used to keep track of where we began + Iterator begin; + + // it is private because only thrust::iterator_core_access needs access to it + __host__ __device__ + typename super_t::reference dereference() const + { + return *(begin + (this->base() - begin) / n); + } +}; +``` + + + + + + + + + + + + + + + + + + + + + + + + + + + + +**Inherits from:** `detail::make_iterator_adaptor_base::type` (public) + +--- + +## Constructors + +### iterator_adaptor + + + + +`iterator_adaptor``'s` default constructor does nothing. + + +```cpp showLineNumbers={false} +thrust::iterator_adaptor::iterator_adaptor() = default +``` + + + + + +inline explicit + +This constructor copies from a given instance of the `Base` iterator. + + +```cpp showLineNumbers={false} +thrust::iterator_adaptor::iterator_adaptor( + Base const &iter +) +``` + + + + + +--- + +## Methods + +### base inline const + + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor::base() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this `iterator_adaptor` adapts. + +### base_reference inline + + + + + +```cpp showLineNumbers={false} +Base & thrust::iterator_adaptor::base_reference() +``` + + +**Returns:** A mutable reference to the `Base` iterator this `iterator_adaptor` adapts. + + + + +const + + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor::base_reference() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this `iterator_adaptor` adapts. + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `base_type` | `Base` | The type of iterator this `iterator_adaptor``'s` `adapts`. | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_core_access.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_core_access.mdx new file mode 100644 index 0000000..dc482ad --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_core_access.mdx @@ -0,0 +1,10 @@ +--- +title: thrust::iterator_core_access +description: "[`iterator_core_access`](/library/api/thrust::iterator_core_access) is the class which user iterator types derived from [`thrust::iterator_adaptor`](/library/api/thrust::iterator_adaptor) or [`thrust::iterator_facade`](/library/api/thrust::iterator_facade) must befriend to allow it to access their private interface." +--- + +`iterator_core_access` is the class which user iterator types derived from [`thrust::iterator_adaptor`](/library/api/thrust::iterator_adaptor) or [`thrust::iterator_facade`](/library/api/thrust::iterator_facade) must befriend to allow it to access their private interface. + +```cpp showLineNumbers={false} +#include +``` diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_difference.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_difference.mdx new file mode 100644 index 0000000..06ebe58 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_difference.mdx @@ -0,0 +1,29 @@ +--- +title: thrust::iterator_difference +description: "deprecated [Since 3.0]" +--- + +deprecated [Since 3.0] + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `typename iterator_traits< Iterator >::difference_type` | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_facade.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_facade.mdx new file mode 100644 index 0000000..86d8844 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_facade.mdx @@ -0,0 +1,216 @@ +--- +title: thrust::iterator_facade +description: "[`iterator_facade`](/library/api/thrust::iterator_facade) is a template which allows the programmer to define a novel iterator with a standards-conforming interface which Thrust can use to reason about algorithm acceleration opportunities." +--- + +`iterator_facade` is a template which allows the programmer to define a novel iterator with a standards-conforming interface which Thrust can use to reason about algorithm acceleration opportunities. + +Because most of a standard iterator's interface is defined in terms of a small set of core primitives, `iterator_facade` defines the non-primitive portion mechanically. In principle a novel iterator could explicitly provide the entire interface in an ad hoc fashion but doing so might be tedious and prone to subtle errors. + +Often `iterator_facade` is too primitive a tool to use for defining novel iterators. In these cases, [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) or a specific fancy iterator should be used instead. + +`iterator_facade``'s` functionality is derived from and generally equivalent to `boost::iterator_facade`. The exception is Thrust's addition of the template parameter `System`, which is necessary to allow Thrust to dispatch an algorithm to one of several parallel backend systems. An additional exception is Thrust's omission of the `operator->` member function. + +Interested users may refer to `boost::iterator_facade`'s documentation for usage examples. + +```cpp showLineNumbers={false} +#include +``` + + +`iterator_facade``'s` arithmetic operator free functions exist with the usual meanings but are omitted here for brevity. + + + + + + + + + + + + + + + + + + + + + + + + + + +--- + +## Methods + +### operator* inline const + +[`operator*()`](/library/api/thrust::iterator_facade::operator*()) dereferences this `iterator_facade`. + + +```cpp showLineNumbers={false} +reference thrust::iterator_facade::operator*() const +``` + + +**Returns:** A reference to the element pointed to by this `iterator_facade`. + +### operator[] inline const + +`operator`[] performs indexed dereference. + + +```cpp showLineNumbers={false} +reference thrust::iterator_facade::operator[]( + difference_type n +) const +``` + + +**Returns:** A reference to the element `n` distance away from this `iterator_facade`. + +### operator++ inline + + + + +`operator++` pre-increments this `iterator_facade` to refer to the element in the next position. + + +```cpp showLineNumbers={false} +Derived & thrust::iterator_facade::operator++() +``` + + +**Returns:** `*this` + + + + +`operator++` post-increments this `iterator_facade` and returns a new `iterator_facade` referring to the element in the next position. + + +```cpp showLineNumbers={false} +Derived thrust::iterator_facade::operator++( + int +) +``` + + +**Returns:** A copy of `*this` before increment. + + + + +### operator-- inline + + + + +`operator--` pre-decrements this `iterator_facade` to refer to the element in the previous position. + + +```cpp showLineNumbers={false} +Derived & thrust::iterator_facade::operator--() +``` + + +**Returns:** `*this` + + + + +`operator--` post-decrements this `iterator_facade` and returns a new `iterator_facade` referring to the element in the previous position. + + +```cpp showLineNumbers={false} +Derived thrust::iterator_facade::operator--( + int +) +``` + + +**Returns:** A copy of `*this` before decrement. + + + + +### operator+= inline + +`operator+=` increments this `iterator_facade` to refer to an element a given distance after its current position. + + +```cpp showLineNumbers={false} +Derived & thrust::iterator_facade::operator+=( + difference_type n +) +``` + + +**Returns:** `*this` + +**Parameters** + + +The quantity to increment. + + +### operator-= inline + +`operator-=` decrements this `iterator_facade` to refer to an element a given distance before its current position. + + +```cpp showLineNumbers={false} +Derived & thrust::iterator_facade::operator-=( + difference_type n +) +``` + + +**Returns:** `*this` + +**Parameters** + + +The quantity to decrement. + + +### operator- inline const + +`operator-` subtracts a given quantity from this `iterator_facade` and returns a new `iterator_facade` referring to the element at the given position before this `iterator_facade`. + + +```cpp showLineNumbers={false} +Derived thrust::iterator_facade::operator-( + difference_type n +) const +``` + + +**Returns:** An `iterator_facade` pointing `n` elements before this `iterator_facade`. + +**Parameters** + + +The quantity to decrement + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `value_type` | `::cuda::std::remove_const_t< Value >` | The type of element pointed to by `iterator_facade`. | +| `reference` | `Reference` | The return type of [`iterator_facade::operator*()`](/library/api/thrust::iterator_facade::operator*()). | +| `pointer` | `void` | The return type of `iterator_facade``'s` non-existent `operator->()` member function. | +| `difference_type` | `Difference` | The type of expressions of the form `x - y` where `x` and `y` are of type `iterator_facade`. | +| `iterator_category` | `detail::iterator_facade_category_t< System, Traversal >` | The type of iterator category of `iterator_facade`. | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_pointer.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_pointer.mdx new file mode 100644 index 0000000..9e77875 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_pointer.mdx @@ -0,0 +1,29 @@ +--- +title: thrust::iterator_pointer +description: "deprecated [Since 3.0]" +--- + +deprecated [Since 3.0] + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `typename iterator_traits< Iterator >::pointer` | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_reference.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_reference.mdx new file mode 100644 index 0000000..6232233 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_reference.mdx @@ -0,0 +1,29 @@ +--- +title: thrust::iterator_reference +description: "deprecated [Since 3.0]" +--- + +deprecated [Since 3.0] + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `typename iterator_traits< Iterator >::reference` | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_system.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_system.mdx new file mode 100644 index 0000000..f5ed6de --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_system.mdx @@ -0,0 +1,21 @@ +--- +title: thrust::iterator_system +description: "Trait obtaining the iterator system of an iterator type, usually as the systems tag type." +--- + +Trait obtaining the iterator system of an iterator type, usually as the systems tag type. + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +**Inherits from:** `thrust::detail::iterator_system_impl< Iterator >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_system_const_void_ptr.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_system_const_void_ptr.mdx new file mode 100644 index 0000000..33440ce --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_system_const_void_ptr.mdx @@ -0,0 +1,10 @@ +--- +title: "thrust::iterator_system< const void * >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `thrust::detail::iterator_system_impl< const void * >` (public), `thrust::iterator_system< const int * >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_system_cudaconstant_iterator_T_Index.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudaconstant_iterator_T_Index.mdx new file mode 100644 index 0000000..f2efc61 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudaconstant_iterator_T_Index.mdx @@ -0,0 +1,32 @@ +--- +title: "thrust::iterator_system<::cuda::constant_iterator< T, Index > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + +**Inherits from:** `thrust::detail::iterator_system_impl< ::cuda::constant_iterator< T, Index > >` (public) + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `any_system_tag` | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_system_cudacounting_iterator_Start.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudacounting_iterator_Start.mdx new file mode 100644 index 0000000..6ebe750 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudacounting_iterator_Start.mdx @@ -0,0 +1,29 @@ +--- +title: "thrust::iterator_system<::cuda::counting_iterator< Start > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +**Inherits from:** `thrust::detail::iterator_system_impl< ::cuda::counting_iterator< Start > >` (public) + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `any_system_tag` | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_system_cudadiscard_iterator.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudadiscard_iterator.mdx new file mode 100644 index 0000000..2b387ed --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudadiscard_iterator.mdx @@ -0,0 +1,20 @@ +--- +title: "thrust::iterator_system<::cuda::discard_iterator >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `thrust::detail::iterator_system_impl< ::cuda::discard_iterator >` (public) + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `any_system_tag` | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_system_cudapermutation_iterator_Iter_Offset.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudapermutation_iterator_Iter_Offset.mdx new file mode 100644 index 0000000..61875da --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudapermutation_iterator_Iter_Offset.mdx @@ -0,0 +1,32 @@ +--- +title: "thrust::iterator_system<::cuda::permutation_iterator< Iter, Offset > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + +**Inherits from:** `thrust::detail::iterator_system_impl< ::cuda::permutation_iterator< Iter, Offset > >` (public) + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `detail::minimum_system_t< iterator_system_t< Iter >, iterator_system_t< Offset > >` | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_system_cudashuffle_iterator_IndexType_Bijection.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudashuffle_iterator_IndexType_Bijection.mdx new file mode 100644 index 0000000..be126c4 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudashuffle_iterator_IndexType_Bijection.mdx @@ -0,0 +1,32 @@ +--- +title: "thrust::iterator_system<::cuda::shuffle_iterator< IndexType, Bijection > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + +**Inherits from:** `thrust::detail::iterator_system_impl< ::cuda::shuffle_iterator< IndexType, Bijection > >` (public) + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `any_system_tag` | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_system_cudastdreverse_iterator_Iter.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudastdreverse_iterator_Iter.mdx new file mode 100644 index 0000000..3acc7ed --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudastdreverse_iterator_Iter.mdx @@ -0,0 +1,19 @@ +--- +title: "thrust::iterator_system<::cuda::std::reverse_iterator< Iter > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +**Inherits from:** `thrust::detail::iterator_system_impl< ::cuda::std::reverse_iterator< Iter > >` (public), `thrust::iterator_system< Iter >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_system_cudastrided_iterator_Iter_Stride.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudastrided_iterator_Iter_Stride.mdx new file mode 100644 index 0000000..9d85624 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudastrided_iterator_Iter_Stride.mdx @@ -0,0 +1,22 @@ +--- +title: "thrust::iterator_system<::cuda::strided_iterator< Iter, Stride > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + +**Inherits from:** `thrust::detail::iterator_system_impl< ::cuda::strided_iterator< Iter, Stride > >` (public), `thrust::iterator_system< Iter >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_system_cudatabulate_output_iterator_Fn_Index.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudatabulate_output_iterator_Fn_Index.mdx new file mode 100644 index 0000000..68d946a --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudatabulate_output_iterator_Fn_Index.mdx @@ -0,0 +1,32 @@ +--- +title: "thrust::iterator_system<::cuda::tabulate_output_iterator< Fn, Index > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + +**Inherits from:** `thrust::detail::iterator_system_impl< ::cuda::tabulate_output_iterator< Fn, Index > >` (public) + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `any_system_tag` | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_system_cudatransform_input_output_iterator_InputFn_OutputFn_Iter.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudatransform_input_output_iterator_InputFn_OutputFn_Iter.mdx new file mode 100644 index 0000000..3bdb741 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudatransform_input_output_iterator_InputFn_OutputFn_Iter.mdx @@ -0,0 +1,25 @@ +--- +title: "thrust::iterator_system<::cuda::transform_input_output_iterator< InputFn, OutputFn, Iter > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + + + + +**Inherits from:** `thrust::detail::iterator_system_impl< ::cuda::transform_input_output_iterator< InputFn, OutputFn, Iter > >` (public), `thrust::iterator_system< Iter >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_system_cudatransform_iterator_Fn_Iter.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudatransform_iterator_Fn_Iter.mdx new file mode 100644 index 0000000..13bb971 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudatransform_iterator_Fn_Iter.mdx @@ -0,0 +1,22 @@ +--- +title: "thrust::iterator_system<::cuda::transform_iterator< Fn, Iter > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + +**Inherits from:** `thrust::detail::iterator_system_impl< ::cuda::transform_iterator< Fn, Iter > >` (public), `thrust::iterator_system< Iter >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_system_cudatransform_output_iterator_Fn_Iter.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudatransform_output_iterator_Fn_Iter.mdx new file mode 100644 index 0000000..f42eb89 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudatransform_output_iterator_Fn_Iter.mdx @@ -0,0 +1,22 @@ +--- +title: "thrust::iterator_system<::cuda::transform_output_iterator< Fn, Iter > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + +**Inherits from:** `thrust::detail::iterator_system_impl< ::cuda::transform_output_iterator< Fn, Iter > >` (public), `thrust::iterator_system< Iter >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_system_cudazip_iterator_Iterators.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudazip_iterator_Iterators.mdx new file mode 100644 index 0000000..25180c0 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudazip_iterator_Iterators.mdx @@ -0,0 +1,29 @@ +--- +title: "thrust::iterator_system<::cuda::zip_iterator< Iterators... > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +**Inherits from:** `thrust::detail::iterator_system_impl< ::cuda::zip_iterator< Iterators... > >` (public) + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `detail::minimum_system_t< iterator_system_t< Iterators >... >` | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_system_cudazip_transform_iterator_Fn_Iterators.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudazip_transform_iterator_Fn_Iterators.mdx new file mode 100644 index 0000000..c3672f7 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_system_cudazip_transform_iterator_Fn_Iterators.mdx @@ -0,0 +1,32 @@ +--- +title: "thrust::iterator_system<::cuda::zip_transform_iterator< Fn, Iterators... > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + +**Inherits from:** `thrust::detail::iterator_system_impl< ::cuda::zip_transform_iterator< Fn, Iterators... > >` (public) + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `detail::minimum_system_t< iterator_system_t< Iterators >... >` | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_system_void_ptr.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_system_void_ptr.mdx new file mode 100644 index 0000000..55a43c9 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_system_void_ptr.mdx @@ -0,0 +1,10 @@ +--- +title: "thrust::iterator_system< void * >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `thrust::detail::iterator_system_impl< void * >` (public), `thrust::iterator_system< int * >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_traversal.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_traversal.mdx new file mode 100644 index 0000000..cd49f57 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_traversal.mdx @@ -0,0 +1,21 @@ +--- +title: thrust::iterator_traversal +description: "Trait obtaining the iterator traversal category of an iterator type." +--- + +Trait obtaining the iterator traversal category of an iterator type. + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +**Inherits from:** `detail::iterator_category_to_traversal< iterator_traits< Iterator >::iterator_category >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudaconstant_iterator_T_Index.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudaconstant_iterator_T_Index.mdx new file mode 100644 index 0000000..7c87b05 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudaconstant_iterator_T_Index.mdx @@ -0,0 +1,32 @@ +--- +title: "thrust::iterator_traversal<::cuda::constant_iterator< T, Index > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + +**Inherits from:** `detail::iterator_category_to_traversal< iterator_traits< ::cuda::constant_iterator< T, Index > >::iterator_category >` (public) + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `random_access_traversal_tag` | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudacounting_iterator_Start.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudacounting_iterator_Start.mdx new file mode 100644 index 0000000..e253e23 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudacounting_iterator_Start.mdx @@ -0,0 +1,29 @@ +--- +title: "thrust::iterator_traversal<::cuda::counting_iterator< Start > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +**Inherits from:** `detail::iterator_category_to_traversal< iterator_traits< ::cuda::counting_iterator< Start > >::iterator_category >` (public) + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `random_access_traversal_tag` | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudadiscard_iterator.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudadiscard_iterator.mdx new file mode 100644 index 0000000..5ffaa33 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudadiscard_iterator.mdx @@ -0,0 +1,20 @@ +--- +title: "thrust::iterator_traversal<::cuda::discard_iterator >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `detail::iterator_category_to_traversal< iterator_traits< ::cuda::discard_iterator >::iterator_category >` (public) + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `random_access_traversal_tag` | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudapermutation_iterator_Iter_Offset.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudapermutation_iterator_Iter_Offset.mdx new file mode 100644 index 0000000..befade6 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudapermutation_iterator_Iter_Offset.mdx @@ -0,0 +1,32 @@ +--- +title: "thrust::iterator_traversal<::cuda::permutation_iterator< Iter, Offset > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + +**Inherits from:** `detail::iterator_category_to_traversal< iterator_traits< ::cuda::permutation_iterator< Iter, Offset > >::iterator_category >` (public) + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `random_access_traversal_tag` | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudashuffle_iterator_IndexType_Bijection.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudashuffle_iterator_IndexType_Bijection.mdx new file mode 100644 index 0000000..4fd94d0 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudashuffle_iterator_IndexType_Bijection.mdx @@ -0,0 +1,32 @@ +--- +title: "thrust::iterator_traversal<::cuda::shuffle_iterator< IndexType, Bijection > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + +**Inherits from:** `detail::iterator_category_to_traversal< iterator_traits< ::cuda::shuffle_iterator< IndexType, Bijection > >::iterator_category >` (public) + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `random_access_traversal_tag` | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudastdreverse_iterator_Iter.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudastdreverse_iterator_Iter.mdx new file mode 100644 index 0000000..a8f017c --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudastdreverse_iterator_Iter.mdx @@ -0,0 +1,19 @@ +--- +title: "thrust::iterator_traversal<::cuda::std::reverse_iterator< Iter > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +**Inherits from:** `detail::iterator_category_to_traversal< iterator_traits< ::cuda::std::reverse_iterator< Iter > >::iterator_category >` (public), `thrust::iterator_traversal< Iter >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudastrided_iterator_Iter_Stride.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudastrided_iterator_Iter_Stride.mdx new file mode 100644 index 0000000..4d29725 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudastrided_iterator_Iter_Stride.mdx @@ -0,0 +1,22 @@ +--- +title: "thrust::iterator_traversal<::cuda::strided_iterator< Iter, Stride > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + +**Inherits from:** `detail::iterator_category_to_traversal< iterator_traits< ::cuda::strided_iterator< Iter, Stride > >::iterator_category >` (public), `thrust::iterator_traversal< Iter >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudatabulate_output_iterator_Fn_Index.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudatabulate_output_iterator_Fn_Index.mdx new file mode 100644 index 0000000..6626f96 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudatabulate_output_iterator_Fn_Index.mdx @@ -0,0 +1,32 @@ +--- +title: "thrust::iterator_traversal<::cuda::tabulate_output_iterator< Fn, Index > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + +**Inherits from:** `detail::iterator_category_to_traversal< iterator_traits< ::cuda::tabulate_output_iterator< Fn, Index > >::iterator_category >` (public) + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `random_access_traversal_tag` | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudatransform_input_output_iterator_InputFn_OutputFn_Iter.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudatransform_input_output_iterator_InputFn_OutputFn_Iter.mdx new file mode 100644 index 0000000..7b1a839 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudatransform_input_output_iterator_InputFn_OutputFn_Iter.mdx @@ -0,0 +1,25 @@ +--- +title: "thrust::iterator_traversal<::cuda::transform_input_output_iterator< InputFn, OutputFn, Iter > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + + + + +**Inherits from:** `detail::iterator_category_to_traversal< iterator_traits< ::cuda::transform_input_output_iterator< InputFn, OutputFn, Iter > >::iterator_category >` (public), `thrust::iterator_traversal< Iter >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudatransform_iterator_Fn_Iter.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudatransform_iterator_Fn_Iter.mdx new file mode 100644 index 0000000..9f62284 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudatransform_iterator_Fn_Iter.mdx @@ -0,0 +1,22 @@ +--- +title: "thrust::iterator_traversal<::cuda::transform_iterator< Fn, Iter > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + +**Inherits from:** `detail::iterator_category_to_traversal< iterator_traits< ::cuda::transform_iterator< Fn, Iter > >::iterator_category >` (public), `thrust::iterator_traversal< Iter >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudatransform_output_iterator_Fn_Iter.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudatransform_output_iterator_Fn_Iter.mdx new file mode 100644 index 0000000..53f19e0 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudatransform_output_iterator_Fn_Iter.mdx @@ -0,0 +1,22 @@ +--- +title: "thrust::iterator_traversal<::cuda::transform_output_iterator< Fn, Iter > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + +**Inherits from:** `detail::iterator_category_to_traversal< iterator_traits< ::cuda::transform_output_iterator< Fn, Iter > >::iterator_category >` (public), `thrust::iterator_traversal< Iter >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudazip_iterator_Iterators.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudazip_iterator_Iterators.mdx new file mode 100644 index 0000000..f6b676d --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudazip_iterator_Iterators.mdx @@ -0,0 +1,29 @@ +--- +title: "thrust::iterator_traversal<::cuda::zip_iterator< Iterators... > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +**Inherits from:** `detail::iterator_category_to_traversal< iterator_traits< ::cuda::zip_iterator< Iterators... > >::iterator_category >` (public) + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `detail::minimum_type< iterator_traversal_t< Iterators >... >` | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudazip_transform_iterator_Fn_Iterators.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudazip_transform_iterator_Fn_Iterators.mdx new file mode 100644 index 0000000..4c02d16 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_traversal_cudazip_transform_iterator_Fn_Iterators.mdx @@ -0,0 +1,32 @@ +--- +title: "thrust::iterator_traversal<::cuda::zip_transform_iterator< Fn, Iterators... > >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + +**Inherits from:** `detail::iterator_category_to_traversal< iterator_traits< ::cuda::zip_transform_iterator< Fn, Iterators... > >::iterator_category >` (public) + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `detail::minimum_type< iterator_traversal_t< Iterators >... >` | diff --git a/fern/cudapages/thrust/thrust/thrust/iterator_value.mdx b/fern/cudapages/thrust/thrust/thrust/iterator_value.mdx new file mode 100644 index 0000000..e3a9d35 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/iterator_value.mdx @@ -0,0 +1,29 @@ +--- +title: thrust::iterator_value +description: "deprecated [Since 3.0]" +--- + +deprecated [Since 3.0] + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `type` | `typename iterator_traits< Iterator >::value_type` | diff --git a/fern/cudapages/thrust/thrust/thrust/linear_congruential_engine.mdx b/fern/cudapages/thrust/thrust/thrust/linear_congruential_engine.mdx new file mode 100644 index 0000000..aff4595 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/linear_congruential_engine.mdx @@ -0,0 +1,200 @@ +--- +title: thrust::linear_congruential_engine +description: "A [`linear_congruential_engine`](/library/api/thrust::linear_congruential_engine) random number engine produces unsigned integer random numbers using a linear congruential random number generation algorithm." +--- + +A `linear_congruential_engine` random number engine produces unsigned integer random numbers using a linear congruential random number generation algorithm. + +The generation algorithm has the form `x_i = (a * x_{i-1} + c) mod m`. + +The following code snippet shows examples of use of a `linear_congruential_engine` instance: + +```cpp showLineNumbers={false} +#include +``` + + +Inexperienced users should not use this class template directly. Instead, use `minstd_rand` or `minstd_rand0`. + + +**See also:** +thrust::random::minstd_rand, +thrust::random::minstd_rand0 + +## Example + +```cpp showLineNumbers={false} +#include +#include + +int main() +{ + // create a minstd_rand object, which is an instance of linear_congruential_engine + thrust::minstd_rand rng1; + + // output some random values to cout + std::cout << rng1() << std::endl; + + // a random value is printed + + // create a new minstd_rand from a seed + thrust::minstd_rand rng2(13); + + // discard some random values + rng2.discard(13); + + // stream the object to an iostream + std::cout << rng2 << std::endl; + + // rng2's current state is printed + + // print the minimum and maximum values that minstd_rand can produce + std::cout << thrust::minstd_rand::min << std::endl; + std::cout << thrust::minstd_rand::max << std::endl; + + // the range of minstd_rand is printed + + // save the state of rng2 to a different object + thrust::minstd_rand rng3 = rng2; + + // compare rng2 and rng3 + std::cout << (rng2 == rng3) << std::endl; + + // 1 is printed + + // re-seed rng2 with a different seed + rng2.seed(7); + + // compare rng2 and rng3 + std::cout << (rng2 == rng3) << std::endl; + + // 0 is printed + + return 0; +} +``` + + + + + +The type of unsigned integer to produce. + + + +The multiplier used in the generation algorithm. + + + +The increment used in the generation algorithm. + + + +The modulus used in the generation algorithm. + + + + + +--- + +## Constructors + +### linear_congruential_engine explicit + +This constructor, which optionally accepts a seed, initializes a new `linear_congruential_engine`. + + +```cpp showLineNumbers={false} +thrust::random::linear_congruential_engine::linear_congruential_engine( + result_type s = default_seed +) +``` + + +**Parameters** + + +The seed used to initialize this `linear_congruential_engine``'s` state. + + +--- + +## Methods + +### seed + +This method initializes this `linear_congruential_engine``'s` state, and optionally accepts a seed value. + + +```cpp showLineNumbers={false} +void thrust::random::linear_congruential_engine::seed( + result_type s = default_seed +) +``` + + +**Parameters** + + +The seed used to initializes this `linear_congruential_engine``'s` state. + + +### operator() + +This member function produces a new random value and updates this `linear_congruential_engine``'s` state. + + +```cpp showLineNumbers={false} +result_type thrust::random::linear_congruential_engine::operator()( + void +) +``` + + +**Returns:** A new random number. + +### discard + +This member function advances this `linear_congruential_engine``'s` state a given number of times and discards the results. + + +```cpp showLineNumbers={false} +void thrust::random::linear_congruential_engine::discard( + unsigned long long z +) +``` + + + +This function is provided because an implementation may be able to accelerate it. + + +**Parameters** + + +The number of random values to discard. + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `result_type` | `UIntType` | The type of the unsigned integer produced by this `linear_congruential_engine`. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `multiplier` static | `const result_type` | The multiplier used in the generation algorithm. | +| `increment` static | `const result_type` | The increment used in the generation algorithm. | +| `modulus` static | `const result_type` | The modulus used in the generation algorithm. | +| `min` static | `const result_type` | The smallest value this `linear_congruential_engine` may potentially produce. | +| `max` static | `const result_type` | The largest value this `linear_congruential_engine` may potentially produce. | +| `default_seed` static | `const result_type` | The default seed of this `linear_congruential_engine`. | diff --git a/fern/cudapages/thrust/thrust/thrust/linear_feedback_shift_engine.mdx b/fern/cudapages/thrust/thrust/thrust/linear_feedback_shift_engine.mdx new file mode 100644 index 0000000..4a4ef24 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/linear_feedback_shift_engine.mdx @@ -0,0 +1,144 @@ +--- +title: thrust::linear_feedback_shift_engine +description: "A [`linear_feedback_shift_engine`](/library/api/thrust::linear_feedback_shift_engine) random number engine produces unsigned integer random values using a linear feedback shift random number generation algorithm." +--- + +A `linear_feedback_shift_engine` random number engine produces unsigned integer random values using a linear feedback shift random number generation algorithm. + +```cpp showLineNumbers={false} +#include +``` + + +`linear_feedback_shift_engine` is based on the Boost Template Library's linear_feedback_shift. + + + + + + +The type of unsigned integer to produce. + + + +The word size of the produced values (`w <= sizeof(UIntType)`). + + + +The k parameter of Tausworthe's 1965 algorithm. + + + +The q exponent of Tausworthe's 1965 algorithm. + + + +The step size of Tausworthe's 1965 algorithm. + + + + + +--- + +## Constructors + +### linear_feedback_shift_engine explicit + +This constructor, which optionally accepts a seed, initializes a new `linear_feedback_shift_engine`. + + +```cpp showLineNumbers={false} +thrust::random::linear_feedback_shift_engine::linear_feedback_shift_engine( + result_type value = default_seed +) +``` + + +**Parameters** + + +The seed used to initialize this `linear_feedback_shift_engine``'s` state. + + +--- + +## Methods + +### seed + +This method initializes this `linear_feedback_shift_engine``'s` state, and optionally accepts a seed value. + + +```cpp showLineNumbers={false} +void thrust::random::linear_feedback_shift_engine::seed( + result_type value = default_seed +) +``` + + +**Parameters** + + +The seed used to initializes this `linear_feedback_shift_engine``'s` state. + + +### operator() + +This member function produces a new random value and updates this `linear_feedback_shift_engine``'s` state. + + +```cpp showLineNumbers={false} +result_type thrust::random::linear_feedback_shift_engine::operator()( + void +) +``` + + +**Returns:** A new random number. + +### discard + +This member function advances this `linear_feedback_shift_engine``'s` state a given number of times and discards the results. + + +```cpp showLineNumbers={false} +void thrust::random::linear_feedback_shift_engine::discard( + unsigned long long z +) +``` + + + +This function is provided because an implementation may be able to accelerate it. + + +**Parameters** + + +The number of random values to discard. + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `result_type` | `UIntType` | The type of the unsigned integer produced by this `linear_feedback_shift_engine`. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `word_size` static | `const size_t` | The word size of the produced values. | +| `exponent1` static | `const size_t` | A constant used in the generation algorithm. | +| `exponent2` static | `const size_t` | A constant used in the generation algorithm. | +| `step_size` static | `const size_t` | The step size used in the generation algorithm. | +| `min` static | `const result_type` | The smallest value this `linear_feedback_shift_engine` may potentially produce. | +| `max` static | `const result_type` | The largest value this `linear_feedback_shift_engine` may potentially produce. | +| `default_seed` static | `const result_type` | The default seed of this `linear_feedback_shift_engine`. | diff --git a/fern/cudapages/thrust/thrust/thrust/mr/allocator.mdx b/fern/cudapages/thrust/thrust/thrust/mr/allocator.mdx new file mode 100644 index 0000000..cb06b40 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/mr/allocator.mdx @@ -0,0 +1,185 @@ +--- +title: thrust::mr::allocator +description: "An [`mr::allocator`](/library/api/thrust::mr::allocator) is a template that fulfills the C++ requirements for Allocators, allowing to use the NPA-based memory resources where an Allocator is required." +--- + +An `mr::allocator` is a template that fulfills the C++ requirements for Allocators, allowing to use the NPA-based memory resources where an Allocator is required. + +Unlike memory resources, but like other allocators, `mr::allocator` is typed and bound to allocate object of a specific type, however it can be freely rebound to other types. + +```cpp showLineNumbers={false} +#include +``` + + + + + +The type that will be allocated by this allocator. + + + +The upstream memory resource to use for memory allocation. Must derive from [`thrust::mr::memory_resource`](/library/api/thrust::mr::memory_resource) and must be `final` (in C++11 and beyond). + + + + + +**Inherits from:** `thrust::mr::validator< MR >` (private) + +--- + +## Constructors + +### allocator inline + + + + +Constructor. + + +```cpp showLineNumbers={false} +thrust::mr::allocator::allocator( + MR *resource +) +``` + + +**Parameters** + + +The resource to be used to allocate raw memory. + + + + + +Copy constructor. + +Copies the resource pointer. + + +```cpp showLineNumbers={false} +template +thrust::mr::allocator::allocator( + const allocator &other +) +``` + + + + + +--- + +## Methods + +### max_size inline const + +Calculates the maximum number of elements allocated by this allocator. + + +```cpp showLineNumbers={false} +size_type thrust::mr::allocator::max_size() const +``` + + +**Returns:** the maximum value of `std::size_t`, divided by the size of `T`. + +### allocate inline nodiscard + +Allocates objects of type `T`. + + +```cpp showLineNumbers={false} +pointer thrust::mr::allocator::allocate( + size_type n +) +``` + + +**Returns:** a pointer to the newly allocated storage. + +**Parameters** + + +Number of elements to allocate + + +### deallocate inline noexcept + +Deallocates objects of type `T`. + + +```cpp showLineNumbers={false} +void thrust::mr::allocator::deallocate( + pointer p, + size_type n +) noexcept +``` + + +**Parameters** + + +Pointer returned by a previous call to `allocate` + + + +Number of elements, passed as an argument to the `allocate` call that produced `p` + + +### resource inline const + +Extracts the memory resource used by this allocator. + + +```cpp showLineNumbers={false} +MR * thrust::mr::allocator::resource() const +``` + + +**Returns:** the memory resource used by this allocator. + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `void_pointer` | `typename MR::pointer` | The pointer to void type of this allocator. | +| `value_type` | `T` | The value type allocated by this allocator. | +| `pointer` | `typename thrust::detail::pointer_traits< void_pointer >::template rebind< T >::other` | The pointer type allocated by this allocator. | +| `const_pointer` | `typename thrust::detail::pointer_traits< void_pointer >::template rebind< const T >::other` | The pointer to const type. | +| `reference` | `typename thrust::detail::pointer_traits< pointer >::reference` | The reference to the type allocated by this allocator. | +| `const_reference` | `typename thrust::detail::pointer_traits< const_pointer >::reference` | The const reference to the type allocated by this allocator. | +| `size_type` | `std::size_t` | The size type of this allocator. | +| `difference_type` | `typename thrust::detail::pointer_traits< pointer >::difference_type` | The difference type between pointers allocated by this allocator. | +| `propagate_on_container_copy_assignment` | `detail::true_type` | Specifies that the allocator shall be propagated on container copy assignment. | +| `propagate_on_container_move_assignment` | `detail::true_type` | Specifies that the allocator shall be propagated on container move assignment. | +| `propagate_on_container_swap` | `detail::true_type` | Specifies that the allocator shall be propagated on container swap. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `mem_res` | `MR *` | | + +--- + +## Inner classes + +### rebind + + +```cpp showLineNumbers={false} +struct thrust::mr::allocator::rebind +``` + + +The `rebind` metafunction provides the type of an `allocator` instantiated with another type. diff --git a/fern/cudapages/thrust/thrust/thrust/mr/disjoint_synchronized_pool_resource.mdx b/fern/cudapages/thrust/thrust/thrust/mr/disjoint_synchronized_pool_resource.mdx new file mode 100644 index 0000000..d3e63c7 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/mr/disjoint_synchronized_pool_resource.mdx @@ -0,0 +1,293 @@ +--- +title: thrust::mr::disjoint_synchronized_pool_resource +description: "A mutex-synchronized version of [`disjoint_unsynchronized_pool_resource`](/library/api/thrust::mr::disjoint_unsynchronized_pool_resource)." +--- + +A mutex-synchronized version of [`disjoint_unsynchronized_pool_resource`](/library/api/thrust::mr::disjoint_unsynchronized_pool_resource). + +Uses `std::mutex`, and therefore requires C++11. + +```cpp showLineNumbers={false} +#include +``` + + + + + +The type of memory resources that will be used for allocating memory blocks to be handed off to the user + + + +The type of memory resources that will be used for allocating bookkeeping memory + + + + + +**Inherits from:** `thrust::mr::memory_resource< Upstream::pointer >` (public) + +--- + +## Constructors + +### disjoint_synchronized_pool_resource inline + + + + +Constructor. + + +```cpp showLineNumbers={false} +thrust::mr::disjoint_synchronized_pool_resource::disjoint_synchronized_pool_resource( + Upstream *upstream, + Bookkeeper *bookkeeper, + pool_options options = get_default_options() +) +``` + + +**Parameters** + + +The upstream memory resource for allocations + + + +The upstream memory resource for bookkeeping + + + +Pool options to use + + + + + +Constructor. + +Upstream and bookkeeping resources are obtained by calling `get_global_resource` for their types. + + +```cpp showLineNumbers={false} +thrust::mr::disjoint_synchronized_pool_resource::disjoint_synchronized_pool_resource( + pool_options options = get_default_options() +) +``` + + +**Parameters** + + +Pool options to use + + + + + +--- + +## Methods + +### release inline + +Releases all held memory to upstream. + + +```cpp showLineNumbers={false} +void thrust::mr::disjoint_synchronized_pool_resource::release() +``` + + +### do_allocate inline nodiscard virtual + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +virtual void_ptr thrust::mr::disjoint_synchronized_pool_resource::do_allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) override +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + +**Parameters** + + +Size, in bytes, that is requested from this allocation + + + +Alignment that is requested from this allocation + + +### do_deallocate inline virtual + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +virtual void thrust::mr::disjoint_synchronized_pool_resource::do_deallocate( + void_ptr p, + std::size_t n, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) override +``` + + +**Parameters** + + +Pointer to be deallocated + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The size of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +### allocate inline nodiscard + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +pointer thrust::mr::memory_resource::allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + +**Parameters** + + +Size, in bytes, that is requested from this allocation + + + +Alignment that is requested from this allocation + + +### deallocate inline noexcept + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +void thrust::mr::memory_resource::deallocate( + pointer p, + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) noexcept +``` + + +**Parameters** + + +Pointer to be deallocated + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The alignment of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +### is_equal inline const noexcept + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +bool thrust::mr::memory_resource::is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to + + +### do_is_equal inline const noexcept virtual + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +virtual bool thrust::mr::memory_resource::do_is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to + + +--- + +## Static methods + +### get_default_options inline static + +Get the default options for a disjoint pool. + +These are meant to be a sensible set of values for many use cases, and as such, may be tuned in the future. This function is exposed so that creating a set of options that are just a slight departure from the defaults is easy. + + +```cpp showLineNumbers={false} +static pool_options thrust::mr::disjoint_synchronized_pool_resource::get_default_options() +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `unsync_pool` | `disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >` | | +| `lock_t` | `std::lock_guard< std::mutex >` | | +| `void_ptr` | `typename Upstream::pointer` | | +| `pointer` | `Upstream::pointer` | Alias for the template parameter. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `mtx` | `std::mutex` | | +| `upstream_pool` | `unsync_pool` | | diff --git a/fern/cudapages/thrust/thrust/thrust/mr/disjoint_unsynchronized_pool_resource.mdx b/fern/cudapages/thrust/thrust/thrust/mr/disjoint_unsynchronized_pool_resource.mdx new file mode 100644 index 0000000..eacb5e0 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/mr/disjoint_unsynchronized_pool_resource.mdx @@ -0,0 +1,410 @@ +--- +title: thrust::mr::disjoint_unsynchronized_pool_resource +description: "A memory resource adaptor allowing for pooling and caching allocations from `Upstream`, using `Bookkeeper` for management of that cached and pooled memory, allowing to cache portions of memory inaccessible from the host." +--- + +A memory resource adaptor allowing for pooling and caching allocations from `Upstream`, using `Bookkeeper` for management of that cached and pooled memory, allowing to cache portions of memory inaccessible from the host. + +On a typical memory resource, calls to `allocate` and `deallocate` actually allocate and deallocate memory. Pooling memory resources only allocate and deallocate memory from an external resource (the upstream memory resource) when there's no suitable memory currently cached; otherwise, they use memory they have acquired beforehand, to make memory allocation faster and more efficient. + +The disjoint version of the pool resources uses a separate upstream memory resource, `Bookkeeper`, to allocate memory necessary to manage the cached memory. There may be many reasons to do that; the canonical one is that `Upstream` allocates memory that is inaccessible to the code of the pool resource, which means that it cannot embed the necessary information in memory obtained from `Upstream`; for instance, `Upstream` can be a CUDA non-managed memory resource, or a CUDA managed memory resource whose memory we would prefer to not migrate back and forth between host and device when executing bookkeeping code. + +This is not the only case where it makes sense to use a disjoint pool resource, though. In a multi-core environment it may be beneficial to avoid stealing cache lines from other cores by writing over bookkeeping information embedded in an allocated block of memory. In such a case, one can imagine wanting to use a disjoint pool where both the upstream and the bookkeeper are of the same type, to allocate memory consistently, but separately for those two purposes. + +```cpp showLineNumbers={false} +#include +``` + + + + + +The type of memory resources that will be used for allocating memory blocks to be handed off to the user + + + +The type of memory resources that will be used for allocating bookkeeping memory + + + + + +**Inherits from:** `thrust::mr::memory_resource< Upstream::pointer >` (public), `thrust::mr::validator2< Upstream, Bookkeeper >` (private) + +This class is marked final. + +--- + +## Constructors + +### disjoint_unsynchronized_pool_resource inline + + + + +Constructor. + + +```cpp showLineNumbers={false} +thrust::mr::disjoint_unsynchronized_pool_resource::disjoint_unsynchronized_pool_resource( + Upstream *upstream, + Bookkeeper *bookkeeper, + pool_options options = get_default_options() +) +``` + + +**Parameters** + + +The upstream memory resource for allocations + + + +The upstream memory resource for bookkeeping + + + +Pool options to use + + + + + +Constructor. + +Upstream and bookkeeping resources are obtained by calling `get_global_resource` for their types. + + +```cpp showLineNumbers={false} +thrust::mr::disjoint_unsynchronized_pool_resource::disjoint_unsynchronized_pool_resource( + pool_options options = get_default_options() +) +``` + + +**Parameters** + + +Pool options to use + + + + + +### Destructor + +### ~disjoint_unsynchronized_pool_resource inline + +Destructor. + +Releases all held memory to upstream. + + +```cpp showLineNumbers={false} +thrust::mr::disjoint_unsynchronized_pool_resource::~disjoint_unsynchronized_pool_resource() +``` + + +--- + +## Methods + +### release inline + +Releases all held memory to upstream. + + +```cpp showLineNumbers={false} +void thrust::mr::disjoint_unsynchronized_pool_resource::release() +``` + + +### squeeze inline + + +```cpp showLineNumbers={false} +void thrust::mr::disjoint_unsynchronized_pool_resource::squeeze() +``` + + +### do_allocate inline nodiscard virtual + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +virtual void_ptr thrust::mr::disjoint_unsynchronized_pool_resource::do_allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) override +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + +**Parameters** + + +Size, in bytes, that is requested from this allocation + + + +Alignment that is requested from this allocation + + +### do_allocate_impl inline nodiscard + + +```cpp showLineNumbers={false} +void_ptr thrust::mr::disjoint_unsynchronized_pool_resource::do_allocate_impl( + std::size_t bytes, + std::size_t alignment +) +``` + + +### do_deallocate inline virtual + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +virtual void thrust::mr::disjoint_unsynchronized_pool_resource::do_deallocate( + void_ptr p, + std::size_t n, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) override +``` + + +**Parameters** + + +Pointer to be deallocated + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The size of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +### allocate inline nodiscard + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +pointer thrust::mr::memory_resource::allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + +**Parameters** + + +Size, in bytes, that is requested from this allocation + + + +Alignment that is requested from this allocation + + +### deallocate inline noexcept + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +void thrust::mr::memory_resource::deallocate( + pointer p, + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) noexcept +``` + + +**Parameters** + + +Pointer to be deallocated + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The alignment of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +### is_equal inline const noexcept + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +bool thrust::mr::memory_resource::is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to + + +### do_is_equal inline const noexcept virtual + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +virtual bool thrust::mr::memory_resource::do_is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to + + +--- + +## Static methods + +### get_default_options inline static + +Get the default options for a disjoint pool. + +These are meant to be a sensible set of values for many use cases, and as such, may be tuned in the future. This function is exposed so that creating a set of options that are just a slight departure from the defaults is easy. + + +```cpp showLineNumbers={false} +static pool_options thrust::mr::disjoint_unsynchronized_pool_resource::get_default_options() +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `void_ptr` | `typename Upstream::pointer` | | +| `char_ptr` | `typename thrust::detail::pointer_traits< void_ptr >::template rebind< char >::other` | | +| `chunk_vector` | `thrust::host_vector< chunk_descriptor, allocator< chunk_descriptor, Bookkeeper > >` | | +| `oversized_block_vector` | `thrust::host_vector< oversized_block_descriptor, allocator< oversized_block_descriptor, Bookkeeper > >` | | +| `pointer_vector` | `thrust::host_vector< void_ptr, allocator< void_ptr, Bookkeeper > >` | | +| `pool_vector` | `thrust::host_vector< pool, allocator< pool, Bookkeeper > >` | | +| `pointer` | `Upstream::pointer` | Alias for the template parameter. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `m_upstream` | `Upstream *` | | +| `m_bookkeeper` | `Bookkeeper *` | | +| `m_options` | `pool_options` | | +| `m_smallest_block_log2` | `std::size_t` | | +| `m_pools` | `pool_vector` | | +| `m_allocated` | `chunk_vector` | | +| `m_cached_oversized` | `oversized_block_vector` | | +| `m_oversized` | `oversized_block_vector` | | + +--- + +## Inner classes + +### chunk_descriptor + + +```cpp showLineNumbers={false} +struct thrust::mr::disjoint_unsynchronized_pool_resource::chunk_descriptor +``` + + +| Name | Type | Description | +|---|---|---| +| `size` | `std::size_t` | | +| `pointer` | `void_ptr` | | +| `pool_idx` | `std::size_t` | | + +### oversized_block_descriptor + + +```cpp showLineNumbers={false} +struct thrust::mr::disjoint_unsynchronized_pool_resource::oversized_block_descriptor +``` + + +| Name | Type | Description | +|---|---|---| +| `size` | `std::size_t` | | +| `alignment` | `std::size_t` | | +| `pointer` | `void_ptr` | | + +### equal_pointers + + +```cpp showLineNumbers={false} +struct thrust::mr::disjoint_unsynchronized_pool_resource::equal_pointers +``` + + +| Name | Type | Description | +|---|---|---| +| `p` | `void_ptr` | | + +### matching_alignment + + +```cpp showLineNumbers={false} +struct thrust::mr::disjoint_unsynchronized_pool_resource::matching_alignment +``` + + +| Name | Type | Description | +|---|---|---| +| `requested` | `std::size_t` | | + +### pool + + +```cpp showLineNumbers={false} +struct thrust::mr::disjoint_unsynchronized_pool_resource::pool +``` + + +| Name | Type | Description | +|---|---|---| +| `free_blocks` | `pointer_vector` | | +| `previous_allocated_count` | `std::size_t` | | diff --git a/fern/cudapages/thrust/thrust/thrust/mr/fancy_pointer_resource.mdx b/fern/cudapages/thrust/thrust/thrust/mr/fancy_pointer_resource.mdx new file mode 100644 index 0000000..f53833f --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/mr/fancy_pointer_resource.mdx @@ -0,0 +1,229 @@ +--- +title: thrust::mr::fancy_pointer_resource +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + +**Inherits from:** `thrust::mr::memory_resource< Pointer >` (public), `thrust::mr::validator< Upstream >` (private) + +This class is marked final. + +--- + +## Constructors + +### fancy_pointer_resource inline + + + + + +```cpp showLineNumbers={false} +thrust::mr::fancy_pointer_resource::fancy_pointer_resource() +``` + + + + + + +```cpp showLineNumbers={false} +thrust::mr::fancy_pointer_resource::fancy_pointer_resource( + Upstream *upstream +) +``` + + + + + +--- + +## Methods + +### do_allocate inline nodiscard virtual + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +virtual Pointer thrust::mr::fancy_pointer_resource::do_allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) override +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + +**Parameters** + + +Size, in bytes, that is requested from this allocation + + + +Alignment that is requested from this allocation + + +### do_deallocate inline virtual + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +virtual void thrust::mr::fancy_pointer_resource::do_deallocate( + Pointer p, + std::size_t bytes, + std::size_t alignment +) override +``` + + +**Parameters** + + +Pointer to be deallocated + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The size of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +### allocate inline nodiscard + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +pointer thrust::mr::memory_resource::allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + +**Parameters** + + +Size, in bytes, that is requested from this allocation + + + +Alignment that is requested from this allocation + + +### deallocate inline noexcept + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +void thrust::mr::memory_resource::deallocate( + pointer p, + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) noexcept +``` + + +**Parameters** + + +Pointer to be deallocated + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The alignment of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +### is_equal inline const noexcept + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +bool thrust::mr::memory_resource::is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to + + +### do_is_equal inline const noexcept virtual + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +virtual bool thrust::mr::memory_resource::do_is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `pointer` | `Pointer` | Alias for the template parameter. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `m_upstream` | `Upstream *` | | diff --git a/fern/cudapages/thrust/thrust/thrust/mr/memory_resource.mdx b/fern/cudapages/thrust/thrust/thrust/mr/memory_resource.mdx new file mode 100644 index 0000000..0506d7c --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/mr/memory_resource.mdx @@ -0,0 +1,204 @@ +--- +title: thrust::mr::memory_resource +description: "[`memory_resource`](/library/api/thrust::mr::memory_resource) is the base class for all other memory resources." +--- + +`memory_resource` is the base class for all other memory resources. + +```cpp showLineNumbers={false} +#include +``` + + + + + +The pointer type that is allocated and deallocated by the memory resource derived from this base class. If this is `void *`, this class derives from `std::pmr::memory_resource`. + + + + + +--- + +## Constructors + +### Destructor + +### ~memory_resource virtual + +Virtual destructor, defaulted when possible. + + +```cpp showLineNumbers={false} +virtual thrust::mr::memory_resource::~memory_resource() = default +``` + + +--- + +## Methods + +### allocate inline nodiscard + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +pointer thrust::mr::memory_resource::allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + +**Parameters** + + +Size, in bytes, that is requested from this allocation + + + +Alignment that is requested from this allocation + + +### deallocate inline noexcept + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +void thrust::mr::memory_resource::deallocate( + pointer p, + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) noexcept +``` + + +**Parameters** + + +Pointer to be deallocated + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The alignment of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +### is_equal inline const noexcept + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +bool thrust::mr::memory_resource::is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to + + +### do_allocate virtual + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +virtual pointer thrust::mr::memory_resource::do_allocate( + std::size_t bytes, + std::size_t alignment +) +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + +**Parameters** + + +Size, in bytes, that is requested from this allocation + + + +Alignment that is requested from this allocation + + +### do_deallocate virtual + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +virtual void thrust::mr::memory_resource::do_deallocate( + pointer p, + std::size_t bytes, + std::size_t alignment +) +``` + + +**Parameters** + + +Pointer to be deallocated + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The size of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +### do_is_equal inline const noexcept virtual + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +virtual bool thrust::mr::memory_resource::do_is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `pointer` | `Pointer` | Alias for the template parameter. | diff --git a/fern/cudapages/thrust/thrust/thrust/mr/memory_resource_void_ptr.mdx b/fern/cudapages/thrust/thrust/thrust/mr/memory_resource_void_ptr.mdx new file mode 100644 index 0000000..a007c16 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/mr/memory_resource_void_ptr.mdx @@ -0,0 +1,98 @@ +--- +title: "thrust::mr::memory_resource< void * >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + +--- + +## Methods + +### Destructor + +### ~memory_resource< void * > virtual + +": "/library/api/thrust::mr::memory_resource%3C void * %3E"}}> +```cpp showLineNumbers={false} +virtual thrust::mr::memory_resource::~memory_resource() = default +``` + + +### allocate inline nodiscard + +": "/library/api/thrust::mr::memory_resource%3C void * %3E"}}> +```cpp showLineNumbers={false} +pointer thrust::mr::memory_resource::allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) +``` + + +### deallocate inline noexcept + +": "/library/api/thrust::mr::memory_resource%3C void * %3E"}}> +```cpp showLineNumbers={false} +void thrust::mr::memory_resource::deallocate( + pointer p, + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) noexcept +``` + + +### is_equal inline const noexcept + +": "/library/api/thrust::mr::memory_resource%3C void * %3E", "memory_resource": "/library/api/thrust::mr::memory_resource"}}> +```cpp showLineNumbers={false} +bool thrust::mr::memory_resource::is_equal( + const memory_resource &other +) const noexcept +``` + + +### do_allocate virtual + +": "/library/api/thrust::mr::memory_resource%3C void * %3E"}}> +```cpp showLineNumbers={false} +virtual pointer thrust::mr::memory_resource::do_allocate( + std::size_t bytes, + std::size_t alignment +) +``` + + +### do_deallocate virtual + +": "/library/api/thrust::mr::memory_resource%3C void * %3E"}}> +```cpp showLineNumbers={false} +virtual void thrust::mr::memory_resource::do_deallocate( + pointer p, + std::size_t bytes, + std::size_t alignment +) +``` + + +### do_is_equal inline const noexcept virtual + +": "/library/api/thrust::mr::memory_resource%3C void * %3E", "memory_resource": "/library/api/thrust::mr::memory_resource"}}> +```cpp showLineNumbers={false} +virtual bool thrust::mr::memory_resource::do_is_equal( + const memory_resource &other +) const noexcept +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `pointer` | `void *` | diff --git a/fern/cudapages/thrust/thrust/thrust/mr/new_delete_resource.mdx b/fern/cudapages/thrust/thrust/thrust/mr/new_delete_resource.mdx new file mode 100644 index 0000000..809fc59 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/mr/new_delete_resource.mdx @@ -0,0 +1,207 @@ +--- +title: thrust::mr::new_delete_resource +description: "A memory resource that uses global operators new and delete to allocate and deallocate memory." +--- + +A memory resource that uses global operators new and delete to allocate and deallocate memory. + +Uses alignment-enabled overloads when available, otherwise uses regular overloads and implements alignment requirements by itself. + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `thrust::mr::new_delete_resource_base` (public) + +This class is marked final. + +--- + +## Methods + +### do_allocate inline virtual + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +void * thrust::mr::new_delete_resource::do_allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) override +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + +**Parameters** + + +Size, in bytes, that is requested from this allocation + + + +Alignment that is requested from this allocation + + +### do_deallocate + + + + +inline + + +```cpp showLineNumbers={false} +void thrust::mr::new_delete_resource::do_deallocate( + void *p, + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) override +``` + + + + + +virtual + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +virtual void thrust::mr::memory_resource::do_deallocate( + pointer p, + std::size_t bytes, + std::size_t alignment +) +``` + + +**Parameters** + + +Pointer to be deallocated + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The size of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + + + + +### allocate inline nodiscard + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +pointer thrust::mr::memory_resource::allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + +**Parameters** + + +Size, in bytes, that is requested from this allocation + + + +Alignment that is requested from this allocation + + +### deallocate inline noexcept + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +void thrust::mr::memory_resource::deallocate( + pointer p, + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) noexcept +``` + + +**Parameters** + + +Pointer to be deallocated + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The alignment of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +### is_equal inline const noexcept + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +bool thrust::mr::memory_resource::is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to + + +### do_is_equal inline const noexcept virtual + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +virtual bool thrust::mr::memory_resource::do_is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `pointer` | `Pointer` | Alias for the template parameter. | diff --git a/fern/cudapages/thrust/thrust/thrust/mr/new_delete_resource_base.mdx b/fern/cudapages/thrust/thrust/thrust/mr/new_delete_resource_base.mdx new file mode 100644 index 0000000..70e0590 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/mr/new_delete_resource_base.mdx @@ -0,0 +1,201 @@ +--- +title: thrust::mr::new_delete_resource_base +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `thrust::mr::memory_resource<>` (public) + +--- + +## Methods + +### do_allocate inline virtual + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +void * thrust::mr::new_delete_resource_base::do_allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) override +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + +**Parameters** + + +Size, in bytes, that is requested from this allocation + + + +Alignment that is requested from this allocation + + +### do_deallocate + + + + +inline + + +```cpp showLineNumbers={false} +void thrust::mr::new_delete_resource_base::do_deallocate( + void *p, + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) override +``` + + + + + +virtual + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +virtual void thrust::mr::memory_resource::do_deallocate( + pointer p, + std::size_t bytes, + std::size_t alignment +) +``` + + +**Parameters** + + +Pointer to be deallocated + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The size of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + + + + +### allocate inline nodiscard + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +pointer thrust::mr::memory_resource::allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + +**Parameters** + + +Size, in bytes, that is requested from this allocation + + + +Alignment that is requested from this allocation + + +### deallocate inline noexcept + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +void thrust::mr::memory_resource::deallocate( + pointer p, + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) noexcept +``` + + +**Parameters** + + +Pointer to be deallocated + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The alignment of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +### is_equal inline const noexcept + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +bool thrust::mr::memory_resource::is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to + + +### do_is_equal inline const noexcept virtual + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +virtual bool thrust::mr::memory_resource::do_is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `pointer` | `Pointer` | Alias for the template parameter. | diff --git a/fern/cudapages/thrust/thrust/thrust/mr/polymorphic_adaptor_resource.mdx b/fern/cudapages/thrust/thrust/thrust/mr/polymorphic_adaptor_resource.mdx new file mode 100644 index 0000000..9fd1901 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/mr/polymorphic_adaptor_resource.mdx @@ -0,0 +1,211 @@ +--- +title: thrust::mr::polymorphic_adaptor_resource +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +**Inherits from:** `thrust::mr::memory_resource< void * >` (public) + +This class is marked final. + +--- + +## Constructors + +### polymorphic_adaptor_resource inline + + +```cpp showLineNumbers={false} +thrust::mr::polymorphic_adaptor_resource::polymorphic_adaptor_resource( + memory_resource *t +) +``` + + +--- + +## Methods + +### do_allocate inline virtual + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +virtual Pointer thrust::mr::polymorphic_adaptor_resource::do_allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) override +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + +**Parameters** + + +Size, in bytes, that is requested from this allocation + + + +Alignment that is requested from this allocation + + +### do_deallocate inline virtual + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +virtual void thrust::mr::polymorphic_adaptor_resource::do_deallocate( + Pointer p, + std::size_t bytes, + std::size_t alignment +) override +``` + + +**Parameters** + + +Pointer to be deallocated + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The size of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +### do_is_equal inline const noexcept virtual + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +virtual bool thrust::mr::polymorphic_adaptor_resource::do_is_equal( + const memory_resource &other +) const noexcept override +``` + + +**Returns:** whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to + + +### allocate inline nodiscard + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +pointer thrust::mr::memory_resource::allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + +**Parameters** + + +Size, in bytes, that is requested from this allocation + + + +Alignment that is requested from this allocation + + +### deallocate inline noexcept + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +void thrust::mr::memory_resource::deallocate( + pointer p, + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) noexcept +``` + + +**Parameters** + + +Pointer to be deallocated + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The alignment of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +### is_equal inline const noexcept + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +bool thrust::mr::memory_resource::is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `pointer` | `Pointer` | Alias for the template parameter. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `upstream_resource` | `memory_resource< Pointer > *` | | diff --git a/fern/cudapages/thrust/thrust/thrust/mr/pool_options.mdx b/fern/cudapages/thrust/thrust/thrust/mr/pool_options.mdx new file mode 100644 index 0000000..f781070 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/mr/pool_options.mdx @@ -0,0 +1,43 @@ +--- +title: thrust::mr::pool_options +description: "A type used for configuring pooling resource adaptors, to fine-tune their behavior and parameters." +--- + +A type used for configuring pooling resource adaptors, to fine-tune their behavior and parameters. + +```cpp showLineNumbers={false} +#include +``` + +--- + +## Methods + +### validate inline const + +Checks if the options are self-consistent. + +/returns true if the options are self-consistent, false otherwise. + + +```cpp showLineNumbers={false} +bool thrust::mr::pool_options::validate() const +``` + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `min_blocks_per_chunk` | `std::size_t` | The minimal number of blocks, i.e. | +| `min_bytes_per_chunk` | `std::size_t` | The minimal number of bytes in a single chunk allocated from upstream. | +| `max_blocks_per_chunk` | `std::size_t` | The maximal number of blocks, i.e. | +| `max_bytes_per_chunk` | `std::size_t` | The maximal number of bytes in a single chunk allocated from upstream. | +| `smallest_block_size` | `std::size_t` | The size of blocks in the smallest pool covered by the pool resource. | +| `largest_block_size` | `std::size_t` | The size of blocks in the largest pool covered by the pool resource. | +| `alignment` | `std::size_t` | The alignment of all blocks in internal pools of the pool resource. | +| `cache_oversized` | `bool` | Decides whether oversized and overaligned blocks are cached for later use, or immediately return it to the upstream resource. | +| `cached_size_cutoff_factor` | `std::size_t` | The size factor at which a cached allocation is considered too ridiculously oversized to use to fulfill an allocation request. | +| `cached_alignment_cutoff_factor` | `std::size_t` | The alignment factor at which a cached allocation is considered too ridiculously overaligned to use to fulfill an allocation request. | diff --git a/fern/cudapages/thrust/thrust/thrust/mr/stateless_resource_allocator.mdx b/fern/cudapages/thrust/thrust/thrust/mr/stateless_resource_allocator.mdx new file mode 100644 index 0000000..0ada1c6 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/mr/stateless_resource_allocator.mdx @@ -0,0 +1,221 @@ +--- +title: thrust::mr::stateless_resource_allocator +description: "A helper allocator class that uses global instances of a given upstream memory resource." +--- + +A helper allocator class that uses global instances of a given upstream memory resource. + +Requires the memory resource to be default constructible. + +```cpp showLineNumbers={false} +#include +``` + + + + + +The type that will be allocated by this allocator. + + + +The upstream memory resource to use for memory allocation. Must derive from [`thrust::mr::memory_resource`](/library/api/thrust::mr::memory_resource) and must be `final` (in C++11 and beyond). + + + + + +**Inherits from:** `thrust::mr::allocator< T, Upstream >` (public) + +--- + +## Constructors + +### stateless_resource_allocator inline + + + + +Default constructor. + +Uses `get_global_resource` to get the global instance of `Upstream` and initializes the `allocator` base subobject with that resource. + + +```cpp showLineNumbers={false} +thrust::mr::stateless_resource_allocator::stateless_resource_allocator() +``` + + + + + +Copy constructor. + +Copies the memory resource pointer. + + +```cpp showLineNumbers={false} +thrust::mr::stateless_resource_allocator::stateless_resource_allocator( + const stateless_resource_allocator &other +) +``` + + + + + +Conversion constructor from an allocator of a different type. + +Copies the memory resource pointer. + + +```cpp showLineNumbers={false} +template +thrust::mr::stateless_resource_allocator::stateless_resource_allocator( + const stateless_resource_allocator &other +) +``` + + + + + +### Destructor + +### ~stateless_resource_allocator inline + +Destructor. + + +```cpp showLineNumbers={false} +thrust::mr::stateless_resource_allocator::~stateless_resource_allocator() +``` + + +--- + +## Assignment operators + +### operator= + + +```cpp showLineNumbers={false} +stateless_resource_allocator & thrust::mr::stateless_resource_allocator::operator=( + const stateless_resource_allocator & +) = default +``` + + +--- + +## Methods + +### max_size inline const + +Calculates the maximum number of elements allocated by this allocator. + + +```cpp showLineNumbers={false} +size_type thrust::mr::allocator::max_size() const +``` + + +**Returns:** the maximum value of `std::size_t`, divided by the size of `T`. + +### allocate inline nodiscard + +Allocates objects of type `T`. + + +```cpp showLineNumbers={false} +pointer thrust::mr::allocator::allocate( + size_type n +) +``` + + +**Returns:** a pointer to the newly allocated storage. + +**Parameters** + + +Number of elements to allocate + + +### deallocate inline noexcept + +Deallocates objects of type `T`. + + +```cpp showLineNumbers={false} +void thrust::mr::allocator::deallocate( + pointer p, + size_type n +) noexcept +``` + + +**Parameters** + + +Pointer returned by a previous call to `allocate` + + + +Number of elements, passed as an argument to the `allocate` call that produced `p` + + +### resource inline const + +Extracts the memory resource used by this allocator. + + +```cpp showLineNumbers={false} +Upstream * thrust::mr::allocator::resource() const +``` + + +**Returns:** the memory resource used by this allocator. + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `base` | `thrust::mr::allocator< T, Upstream >` | | +| `void_pointer` | `typename Upstream::pointer` | The pointer to void type of this allocator. | +| `value_type` | `T` | The value type allocated by this allocator. | +| `pointer` | `typename thrust::detail::pointer_traits< void_pointer >::template rebind< T >::other` | The pointer type allocated by this allocator. | +| `const_pointer` | `typename thrust::detail::pointer_traits< void_pointer >::template rebind< const T >::other` | The pointer to const type. | +| `reference` | `typename thrust::detail::pointer_traits< pointer >::reference` | The reference to the type allocated by this allocator. | +| `const_reference` | `typename thrust::detail::pointer_traits< const_pointer >::reference` | The const reference to the type allocated by this allocator. | +| `size_type` | `std::size_t` | The size type of this allocator. | +| `difference_type` | `typename thrust::detail::pointer_traits< pointer >::difference_type` | The difference type between pointers allocated by this allocator. | +| `propagate_on_container_copy_assignment` | `detail::true_type` | Specifies that the allocator shall be propagated on container copy assignment. | +| `propagate_on_container_move_assignment` | `detail::true_type` | Specifies that the allocator shall be propagated on container move assignment. | +| `propagate_on_container_swap` | `detail::true_type` | Specifies that the allocator shall be propagated on container swap. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `mem_res` | `Upstream *` | | + +--- + +## Inner classes + +### rebind + + +```cpp showLineNumbers={false} +struct thrust::mr::stateless_resource_allocator::rebind +``` + + +The `rebind` metafunction provides the type of an `stateless_resource_allocator` instantiated with another type. diff --git a/fern/cudapages/thrust/thrust/thrust/mr/synchronized_pool_resource.mdx b/fern/cudapages/thrust/thrust/thrust/mr/synchronized_pool_resource.mdx new file mode 100644 index 0000000..f537020 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/mr/synchronized_pool_resource.mdx @@ -0,0 +1,284 @@ +--- +title: thrust::mr::synchronized_pool_resource +description: "A mutex-synchronized version of [`unsynchronized_pool_resource`](/library/api/thrust::mr::unsynchronized_pool_resource)." +--- + +A mutex-synchronized version of [`unsynchronized_pool_resource`](/library/api/thrust::mr::unsynchronized_pool_resource). + +Uses `std::mutex`, and therefore requires C++11. + +```cpp showLineNumbers={false} +#include +``` + + + + + +The type of memory resources that will be used for allocating memory + + + + + +**Inherits from:** `thrust::mr::memory_resource< Upstream::pointer >` (public) + +--- + +## Constructors + +### synchronized_pool_resource inline + + + + +Constructor. + + +```cpp showLineNumbers={false} +thrust::mr::synchronized_pool_resource::synchronized_pool_resource( + Upstream *upstream, + pool_options options = get_default_options() +) +``` + + +**Parameters** + + +The upstream memory resource for allocations + + + +Pool options to use + + + + + +Constructor. + +The upstream resource is obtained by calling `get_global_resource`. + + +```cpp showLineNumbers={false} +thrust::mr::synchronized_pool_resource::synchronized_pool_resource( + pool_options options = get_default_options() +) +``` + + +**Parameters** + + +Pool options to use + + + + + +--- + +## Methods + +### release inline + +Releases all held memory to upstream. + + +```cpp showLineNumbers={false} +void thrust::mr::synchronized_pool_resource::release() +``` + + +### do_allocate inline nodiscard virtual + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +virtual void_ptr thrust::mr::synchronized_pool_resource::do_allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) override +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + +**Parameters** + + +Size, in bytes, that is requested from this allocation + + + +Alignment that is requested from this allocation + + +### do_deallocate inline virtual + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +virtual void thrust::mr::synchronized_pool_resource::do_deallocate( + void_ptr p, + std::size_t n, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) override +``` + + +**Parameters** + + +Pointer to be deallocated + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The size of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +### allocate inline nodiscard + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +pointer thrust::mr::memory_resource::allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + +**Parameters** + + +Size, in bytes, that is requested from this allocation + + + +Alignment that is requested from this allocation + + +### deallocate inline noexcept + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +void thrust::mr::memory_resource::deallocate( + pointer p, + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) noexcept +``` + + +**Parameters** + + +Pointer to be deallocated + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The alignment of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +### is_equal inline const noexcept + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +bool thrust::mr::memory_resource::is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to + + +### do_is_equal inline const noexcept virtual + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +virtual bool thrust::mr::memory_resource::do_is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to + + +--- + +## Static methods + +### get_default_options inline static + +Get the default options for a pool. + +These are meant to be a sensible set of values for many use cases, and as such, may be tuned in the future. This function is exposed so that creating a set of options that are just a slight departure from the defaults is easy. + + +```cpp showLineNumbers={false} +static pool_options thrust::mr::synchronized_pool_resource::get_default_options() +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `unsync_pool` | `unsynchronized_pool_resource< Upstream >` | | +| `lock_t` | `std::lock_guard< std::mutex >` | | +| `void_ptr` | `typename Upstream::pointer` | | +| `pointer` | `Upstream::pointer` | Alias for the template parameter. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `mtx` | `std::mutex` | | +| `upstream_pool` | `unsync_pool` | | diff --git a/fern/cudapages/thrust/thrust/thrust/mr/unsynchronized_pool_resource.mdx b/fern/cudapages/thrust/thrust/thrust/mr/unsynchronized_pool_resource.mdx new file mode 100644 index 0000000..3979007 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/mr/unsynchronized_pool_resource.mdx @@ -0,0 +1,373 @@ +--- +title: thrust::mr::unsynchronized_pool_resource +description: "A memory resource adaptor allowing for pooling and caching allocations from `Upstream`, using memory allocated from it for both blocks then allocated to the user and for internal bookkeeping of the cached memory." +--- + +A memory resource adaptor allowing for pooling and caching allocations from `Upstream`, using memory allocated from it for both blocks then allocated to the user and for internal bookkeeping of the cached memory. + +On a typical memory resource, calls to `allocate` and `deallocate` actually allocate and deallocate memory. Pooling memory resources only allocate and deallocate memory from an external resource (the upstream memory resource) when there's no suitable memory currently cached; otherwise, they use memory they have acquired beforehand, to make memory allocation faster and more efficient. + +The non-disjoint version of the pool resource uses a single upstream memory resource. Every allocation is larger than strictly necessary to fulfill the end-user's request, because it needs to account for the memory overhead of tracking the memory blocks and chunks inside those same memory regions. Nevertheless, this version should be more memory-efficient than the [`disjoint_unsynchronized_pool_resource`](/library/api/thrust::mr::disjoint_unsynchronized_pool_resource), because it doesn't need to allocate additional blocks of memory from a separate resource, which in turn would necessitate the bookkeeping overhead in the upstream resource. + +This version requires that memory allocated from Upstream is accessible from device. It supports smart references, meaning that the non-managed CUDA resource, returning a device-tagged pointer, will work, but will be much less efficient than the disjoint version, which wouldn't need to touch device memory at all, and therefore wouldn't need to transfer it back and forth between the host and the device whenever an allocation or a deallocation happens. + +```cpp showLineNumbers={false} +#include +``` + + + + + +The type of memory resources that will be used for allocating memory blocks + + + + + +**Inherits from:** `thrust::mr::memory_resource< Upstream::pointer >` (public), `thrust::mr::validator< Upstream >` (private) + +This class is marked final. + +--- + +## Constructors + +### unsynchronized_pool_resource inline + + + + +Constructor. + + +```cpp showLineNumbers={false} +thrust::mr::unsynchronized_pool_resource::unsynchronized_pool_resource( + Upstream *upstream, + pool_options options = get_default_options() +) +``` + + +**Parameters** + + +The upstream memory resource for allocations + + + +Pool options to use + + + + + +Constructor. + +The upstream resource is obtained by calling `get_global_resource`. + + +```cpp showLineNumbers={false} +thrust::mr::unsynchronized_pool_resource::unsynchronized_pool_resource( + pool_options options = get_default_options() +) +``` + + +**Parameters** + + +Pool options to use + + + + + +### Destructor + +### ~unsynchronized_pool_resource inline + +Destructor. + +Releases all held memory to upstream. + + +```cpp showLineNumbers={false} +thrust::mr::unsynchronized_pool_resource::~unsynchronized_pool_resource() +``` + + +--- + +## Methods + +### release inline + +Releases all held memory to upstream. + + +```cpp showLineNumbers={false} +void thrust::mr::unsynchronized_pool_resource::release() +``` + + +### do_allocate inline nodiscard virtual + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +virtual void_ptr thrust::mr::unsynchronized_pool_resource::do_allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) override +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + +**Parameters** + + +Size, in bytes, that is requested from this allocation + + + +Alignment that is requested from this allocation + + +### do_deallocate inline virtual + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +virtual void thrust::mr::unsynchronized_pool_resource::do_deallocate( + void_ptr p, + std::size_t n, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) override +``` + + +**Parameters** + + +Pointer to be deallocated + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The size of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +### allocate inline nodiscard + +Allocates memory of size at least `bytes` and alignment at least `alignment`. + + +```cpp showLineNumbers={false} +pointer thrust::mr::memory_resource::allocate( + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) +``` + + +**Returns:** A pointer to void to the newly allocated memory. + +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + +**Parameters** + + +Size, in bytes, that is requested from this allocation + + + +Alignment that is requested from this allocation + + +### deallocate inline noexcept + +Deallocates memory pointed to by `p`. + + +```cpp showLineNumbers={false} +void thrust::mr::memory_resource::deallocate( + pointer p, + std::size_t bytes, + std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT +) noexcept +``` + + +**Parameters** + + +Pointer to be deallocated + + + +The size of the allocation. This must be equivalent to the value of `bytes` that was passed to the allocation function that returned `p`. + + + +The alignment of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. + + +### is_equal inline const noexcept + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +bool thrust::mr::memory_resource::is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to + + +### do_is_equal inline const noexcept virtual + +Compares this resource to the other one. + +The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. + + +```cpp showLineNumbers={false} +virtual bool thrust::mr::memory_resource::do_is_equal( + const memory_resource &other +) const noexcept +``` + + +**Returns:** whether the two resources are equivalent. + +**Parameters** + + +The other resource to compare this resource to + + +--- + +## Static methods + +### get_default_options inline static + +Get the default options for a pool. + +These are meant to be a sensible set of values for many use cases, and as such, may be tuned in the future. This function is exposed so that creating a set of options that are just a slight departure from the defaults is easy. + + +```cpp showLineNumbers={false} +static pool_options thrust::mr::unsynchronized_pool_resource::get_default_options() +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `void_ptr` | `typename Upstream::pointer` | | +| `void_ptr_traits` | `thrust::detail::pointer_traits< void_ptr >` | | +| `char_ptr` | `typename void_ptr_traits::template rebind< char >::other` | | +| `block_descriptor_ptr` | `typename void_ptr_traits::template rebind< block_descriptor >::other` | | +| `chunk_descriptor_ptr` | `typename void_ptr_traits::template rebind< chunk_descriptor >::other` | | +| `oversized_block_descriptor_ptr` | `typename void_ptr_traits::template rebind< oversized_block_descriptor >::other` | | +| `oversized_block_ptr_traits` | `thrust::detail::pointer_traits< oversized_block_descriptor_ptr >` | | +| `pool_vector` | `thrust::host_vector< pool, allocator< pool, Upstream > >` | | +| `pointer` | `Upstream::pointer` | Alias for the template parameter. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `m_upstream` | `Upstream *` | | +| `m_options` | `pool_options` | | +| `m_smallest_block_log2` | `std::size_t` | | +| `m_pools` | `pool_vector` | | +| `m_allocated` | `chunk_descriptor_ptr` | | +| `m_oversized` | `oversized_block_descriptor_ptr` | | +| `m_cached_oversized` | `oversized_block_descriptor_ptr` | | + +--- + +## Inner classes + +### block_descriptor + + +```cpp showLineNumbers={false} +struct thrust::mr::unsynchronized_pool_resource::block_descriptor +``` + + +| Name | Type | Description | +|---|---|---| +| `next` | `block_descriptor_ptr` | | + +### chunk_descriptor + + +```cpp showLineNumbers={false} +struct thrust::mr::unsynchronized_pool_resource::chunk_descriptor +``` + + +| Name | Type | Description | +|---|---|---| +| `size` | `std::size_t` | | +| `next` | `chunk_descriptor_ptr` | | + +### oversized_block_descriptor + + +```cpp showLineNumbers={false} +struct thrust::mr::unsynchronized_pool_resource::oversized_block_descriptor +``` + + +| Name | Type | Description | +|---|---|---| +| `size` | `std::size_t` | | +| `alignment` | `std::size_t` | | +| `prev` | `oversized_block_descriptor_ptr` | | +| `next` | `oversized_block_descriptor_ptr` | | +| `next_cached` | `oversized_block_descriptor_ptr` | | +| `current_size` | `std::size_t` | | + +### pool + + +```cpp showLineNumbers={false} +struct thrust::mr::unsynchronized_pool_resource::pool +``` + + +| Name | Type | Description | +|---|---|---| +| `free_list` | `block_descriptor_ptr` | | +| `previous_allocated_count` | `std::size_t` | | diff --git a/fern/cudapages/thrust/thrust/thrust/mr/validator.mdx b/fern/cudapages/thrust/thrust/thrust/mr/validator.mdx new file mode 100644 index 0000000..242fb91 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/mr/validator.mdx @@ -0,0 +1,17 @@ +--- +title: thrust::mr::validator +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + diff --git a/fern/cudapages/thrust/thrust/thrust/mr/validator2.mdx b/fern/cudapages/thrust/thrust/thrust/mr/validator2.mdx new file mode 100644 index 0000000..4e518ca --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/mr/validator2.mdx @@ -0,0 +1,22 @@ +--- +title: thrust::mr::validator2 +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + +**Inherits from:** `thrust::mr::validator< T >` (private), `thrust::mr::validator< U >` (private) diff --git a/fern/cudapages/thrust/thrust/thrust/mr/validator2_T_T.mdx b/fern/cudapages/thrust/thrust/thrust/mr/validator2_T_T.mdx new file mode 100644 index 0000000..d792663 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/mr/validator2_T_T.mdx @@ -0,0 +1,19 @@ +--- +title: "thrust::mr::validator2< T, T >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +**Inherits from:** `thrust::mr::validator< T >` (private), `thrust::mr::validator< T >` (private), `thrust::mr::validator< T >` (private) diff --git a/fern/cudapages/thrust/thrust/thrust/no_traversal_tag.mdx b/fern/cudapages/thrust/thrust/thrust/no_traversal_tag.mdx new file mode 100644 index 0000000..b6cd92a --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/no_traversal_tag.mdx @@ -0,0 +1,10 @@ +--- +title: thrust::no_traversal_tag +description: "Tag type for iterators allowing no traversal." +--- + +Tag type for iterators allowing no traversal. + +```cpp showLineNumbers={false} +#include +``` diff --git a/fern/cudapages/thrust/thrust/thrust/normal_distribution.mdx b/fern/cudapages/thrust/thrust/thrust/normal_distribution.mdx new file mode 100644 index 0000000..f7b61b3 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/normal_distribution.mdx @@ -0,0 +1,271 @@ +--- +title: thrust::normal_distribution +description: "A [`normal_distribution`](/library/api/thrust::normal_distribution) random number distribution produces floating point Normally distributed random numbers." +--- + +A `normal_distribution` random number distribution produces floating point Normally distributed random numbers. + +The following code snippet demonstrates examples of using a `normal_distribution` with a random number engine to produce random values drawn from the Normal distribution with a given mean and variance: + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include +#include + +int main() +{ + // create a minstd_rand object to act as our source of randomness + thrust::minstd_rand rng; + + // create a normal_distribution to produce floats from the Normal distribution + // with mean 2.0 and standard deviation 3.5 + thrust::random::normal_distribution dist(2.0f, 3.5f); + + // write a random number to standard output + std::cout << dist(rng) << std::endl; + + // write the mean of the distribution, just in case we forgot + std::cout << dist.mean() << std::endl; + + // 2.0 is printed + + // and the standard deviation + std::cout << dist.stddev() << std::endl; + + // 3.5 is printed + + return 0; +} +``` + + + + + +The type of floating point number to produce. + + + + + +**Inherits from:** `detail::normal_distribution_base::type` (public) + +--- + +## Constructors + +### normal_distribution explicit + + + + +This constructor creates a new `normal_distribution` from two values defining the half-open interval of the distribution. + + +```cpp showLineNumbers={false} +thrust::random::normal_distribution::normal_distribution( + RealType mean = 0.0, + RealType stddev = 1.0 +) +``` + + +**Parameters** + + +The mean (expected value) of the distribution. Defaults to `0.0`. + + + +The standard deviation of the distribution. Defaults to `1.0`. + + + + + +This constructor creates a new `normal_distribution` from a [`param_type`](/library/api/thrust::random::normal_distribution::param_type) object encapsulating the range of the distribution. + + +```cpp showLineNumbers={false} +thrust::random::normal_distribution::normal_distribution( + const param_type &parm +) +``` + + +**Parameters** + + +A [`param_type`](/library/api/thrust::random::normal_distribution::param_type) object encapsulating the parameters (i.e., the mean and standard deviation) of the distribution. + + + + + +--- + +## Methods + +### reset + +Calling this member function guarantees that subsequent uses of this `normal_distribution` do not depend on values produced by any random number generator prior to invoking this function. + + +```cpp showLineNumbers={false} +void thrust::random::normal_distribution::reset() +``` + + +### operator() + + + + +This method produces a new Normal random integer drawn from this `normal_distribution``'s` range using a `UniformRandomNumberGenerator` as a source of randomness. + + +```cpp showLineNumbers={false} +template +result_type thrust::random::normal_distribution::operator()( + UniformRandomNumberGenerator &urng +) +``` + + +**Parameters** + + +The `UniformRandomNumberGenerator` to use as a source of randomness. + + + + + +This method produces a new Normal random integer as if by creating a new `normal_distribution` from the given [`param_type`](/library/api/thrust::random::normal_distribution::param_type) object, and calling its `operator()` method with the given `UniformRandomNumberGenerator` as a source of randomness. + + +```cpp showLineNumbers={false} +template +result_type thrust::random::normal_distribution::operator()( + UniformRandomNumberGenerator &urng, + const param_type &parm +) +``` + + +**Parameters** + + +The `UniformRandomNumberGenerator` to use as a source of randomness. + + + +A [`param_type`](/library/api/thrust::random::normal_distribution::param_type) object encapsulating the parameters of the `normal_distribution` to draw from. + + + + + +### mean const + +This method returns the value of the parameter with which this `normal_distribution` was constructed. + + +```cpp showLineNumbers={false} +result_type thrust::random::normal_distribution::mean() const +``` + + +**Returns:** The mean (expected value) of this `normal_distribution``'s` output. + +### stddev const + +This method returns the value of the parameter with which this `normal_distribution` was constructed. + + +```cpp showLineNumbers={false} +result_type thrust::random::normal_distribution::stddev() const +``` + + +**Returns:** The standard deviation of this [`uniform_real_distribution`](/library/api/thrust::uniform_real_distribution)`'s` output. + +### param + + + + +This method changes the parameters of this `normal_distribution` using the values encapsulated in a given [`param_type`](/library/api/thrust::random::normal_distribution::param_type) object. + + +```cpp showLineNumbers={false} +void thrust::random::normal_distribution::param( + const param_type &parm +) +``` + + +**Parameters** + + +A [`param_type`](/library/api/thrust::random::normal_distribution::param_type) object encapsulating the new parameters (i.e., the mean and variance) of this `normal_distribution`. + + + + + +const + +This method returns a [`param_type`](/library/api/thrust::random::normal_distribution::param_type) object encapsulating the parameters with which this `normal_distribution` was constructed. + + +```cpp showLineNumbers={false} +param_type thrust::random::normal_distribution::param() const +``` + + +**Returns:** A [`param_type`](/library/api/thrust::random::normal_distribution::param_type) object encapsulating the parameters (i.e., the mean and standard deviation) of this `normal_distribution`. + + + + +### min const + +This method returns the smallest floating point number this `normal_distribution` can potentially produce. + + +```cpp showLineNumbers={false} +result_type thrust::random::normal_distribution::min() const +``` + + +**Returns:** The lower bound of this `normal_distribution``'s` half-open interval. + +### max const + +This method returns the smallest number larger than largest floating point number this [`uniform_real_distribution`](/library/api/thrust::uniform_real_distribution) can potentially produce. + + +```cpp showLineNumbers={false} +result_type thrust::random::normal_distribution::max() const +``` + + +**Returns:** The upper bound of this `normal_distribution``'s` half-open interval. + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `super_t` | `typename detail::normal_distribution_base< RealType >::type` | | +| `result_type` | `RealType` | The type of the floating point number produced by this `normal_distribution`. | +| `param_type` | `::cuda::std::pair< RealType, RealType >` | The type of the object encapsulating this `normal_distribution``'s` parameters. | diff --git a/fern/cudapages/thrust/thrust/thrust/offset_iterator.mdx b/fern/cudapages/thrust/thrust/thrust/offset_iterator.mdx new file mode 100644 index 0000000..795391c --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/offset_iterator.mdx @@ -0,0 +1,166 @@ +--- +title: thrust::offset_iterator +description: "[`offset_iterator`](/library/api/thrust::offset_iterator) wraps another iterator and an integral offset, applies the offset to the iterator when dereferencing, comparing, or computing the distance between two offset_iterators." +--- + +`offset_iterator` wraps another iterator and an integral offset, applies the offset to the iterator when dereferencing, comparing, or computing the distance between two offset_iterators. + +This is useful, when the underlying iterator cannot be incremented, decremented, or advanced (e.g., because those operations are only supported in device code). + +The following code snippet demonstrates how to create an `offset_iterator``:` + +Alternatively, an `offset_iterator` can also use an iterator to retrieve the offset from an iterator. However, such an `offset_iterator` cannot be moved anymore by changing the offset, so it will move the base iterator instead. + +In the above example, the offset is loaded from a device vector, transformed by a [`transform_iterator`](/library/api/thrust::transform_iterator), and then applied to the underlying iterator, when the `offset_iterator` is accessed. + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include +#include +#include + +int main() +{ + thrust::device_vector data{1, 2, 3, 4}; + auto b = offset_iterator{data.begin(), 1}; + auto e = offset_iterator{data.end(), -1}; + thrust::fill(b, e, 42); + // data is now [1, 42, 42, 4] + ++b; // does not call ++ on the underlying iterator + assert(b == e - 1); + + return 0; +} +``` + +```cpp showLineNumbers={false} +#include +#include +#include +#include + +int main() +{ + using thrust::placeholders::_1; + thrust::device_vector data{1, 2, 3, 4}; + + thrust::device_vector offsets{1}; // offset is only available on device + auto offset = thrust::make_transform_iterator(offsets.begin(), _1 * 2); + thrust::offset_iterator iter(data.begin(), offset); // load and transform offset upon access + // iter is at position 2 (= 1 * 2) in data, and would return 3 in device code + + return 0; +} +``` + + + + + + + + + + + + + +**Inherits from:** `thrust::iterator_adaptor< offset_iterator< Iterator, typename ::cuda::std::iterator_traits< Iterator >::difference_type >, Iterator >` (public) + +--- + +## Constructors + +### offset_iterator inline + + +```cpp showLineNumbers={false} +thrust::offset_iterator::offset_iterator( + Iterator it = {}, + Offset offset = {} +) +``` + + +--- + +## Methods + +### offset inline + + + + + +```cpp showLineNumbers={false} +Offset & thrust::offset_iterator::offset() +``` + + + + + +const + + +```cpp showLineNumbers={false} +const Offset & thrust::offset_iterator::offset() const +``` + + + + + +### base inline const + + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor::base() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + +### base_reference inline + + + + + +```cpp showLineNumbers={false} +Base & thrust::iterator_adaptor::base_reference() +``` + + +**Returns:** A mutable reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +const + + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor::base_reference() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `base_type` | `Base` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)`'s` `adapts`. | diff --git a/fern/cudapages/thrust/thrust/thrust/output_device_iterator_tag.mdx b/fern/cudapages/thrust/thrust/thrust/output_device_iterator_tag.mdx new file mode 100644 index 0000000..b7790e7 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/output_device_iterator_tag.mdx @@ -0,0 +1,17 @@ +--- +title: thrust::output_device_iterator_tag +description: "[`output_device_iterator_tag`](/library/api/thrust::output_device_iterator_tag) is an empty class: it has no member functions, member variables, or nested types." +--- + +`output_device_iterator_tag` is an empty class: it has no member functions, member variables, or nested types. + +It is used solely as a "tag": a representation of the Output Device Iterator concept within the C++ type system. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +[https://en.cppreference.com/w/cpp/iterator/iterator_tags](https://en.cppreference.com/w/cpp/iterator/iterator_tags) iterator_traits, [input_device_iterator_tag](/library/api/thrust::input_device_iterator_tag), [forward_device_iterator_tag](/library/api/thrust::forward_device_iterator_tag), [bidirectional_device_iterator_tag](/library/api/thrust::bidirectional_device_iterator_tag), [random_access_device_iterator_tag](/library/api/thrust::random_access_device_iterator_tag), input_host_iterator_tag, output_host_iterator_tag, forward_host_iterator_tag, bidirectional_host_iterator_tag, random_access_host_iterator_tag + +**Inherits from:** `detail::iterator_category_with_system_and_traversal<::cuda::std::output_iterator_tag, device_system_tag, single_pass_traversal_tag >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/per_device_allocator.mdx b/fern/cudapages/thrust/thrust/thrust/per_device_allocator.mdx new file mode 100644 index 0000000..92e2bc8 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/per_device_allocator.mdx @@ -0,0 +1,208 @@ +--- +title: thrust::per_device_allocator +description: "A helper allocator class that uses global per device instances of a given upstream memory resource." +--- + +A helper allocator class that uses global per device instances of a given upstream memory resource. + +Requires the memory resource to be default constructible. + +```cpp showLineNumbers={false} +#include +``` + + + + + +The type that will be allocated by this allocator. + + + + + + +The execution policy of the system to be used to retrieve the resource for the current device. + + + + + +**Inherits from:** `thrust::mr::allocator< T, Upstream >` (public) + +--- + +## Constructors + +### per_device_allocator inline + + + + +Default constructor. + + +```cpp showLineNumbers={false} +thrust::per_device_allocator::per_device_allocator() +``` + + + + + +Copy constructor. + +Copies the memory resource pointer. + + +```cpp showLineNumbers={false} +thrust::per_device_allocator::per_device_allocator( + const per_device_allocator &other +) +``` + + + + + +Conversion constructor from an allocator of a different type. + +Copies the memory resource pointer. + + +```cpp showLineNumbers={false} +template +thrust::per_device_allocator::per_device_allocator( + const per_device_allocator &other +) +``` + + + + + +### Destructor + +### ~per_device_allocator inline + +Destructor. + + +```cpp showLineNumbers={false} +thrust::per_device_allocator::~per_device_allocator() +``` + + +--- + +## Methods + +### max_size inline const + +Calculates the maximum number of elements allocated by this allocator. + + +```cpp showLineNumbers={false} +size_type thrust::mr::allocator::max_size() const +``` + + +**Returns:** the maximum value of `std::size_t`, divided by the size of `T`. + +### allocate inline nodiscard + +Allocates objects of type `T`. + + +```cpp showLineNumbers={false} +pointer thrust::mr::allocator::allocate( + size_type n +) +``` + + +**Returns:** a pointer to the newly allocated storage. + +**Parameters** + + +Number of elements to allocate + + +### deallocate inline noexcept + +Deallocates objects of type `T`. + + +```cpp showLineNumbers={false} +void thrust::mr::allocator::deallocate( + pointer p, + size_type n +) noexcept +``` + + +**Parameters** + + +Pointer returned by a previous call to `allocate` + + + +Number of elements, passed as an argument to the `allocate` call that produced `p` + + +### resource inline const + +Extracts the memory resource used by this allocator. + + +```cpp showLineNumbers={false} +Upstream * thrust::mr::allocator::resource() const +``` + + +**Returns:** the memory resource used by this allocator. + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `base` | `thrust::mr::allocator< T, Upstream >` | | +| `void_pointer` | `typename Upstream::pointer` | The pointer to void type of this allocator. | +| `value_type` | `T` | The value type allocated by this allocator. | +| `pointer` | `typename thrust::detail::pointer_traits< void_pointer >::template rebind< T >::other` | The pointer type allocated by this allocator. | +| `const_pointer` | `typename thrust::detail::pointer_traits< void_pointer >::template rebind< const T >::other` | The pointer to const type. | +| `reference` | `typename thrust::detail::pointer_traits< pointer >::reference` | The reference to the type allocated by this allocator. | +| `const_reference` | `typename thrust::detail::pointer_traits< const_pointer >::reference` | The const reference to the type allocated by this allocator. | +| `size_type` | `std::size_t` | The size type of this allocator. | +| `difference_type` | `typename thrust::detail::pointer_traits< pointer >::difference_type` | The difference type between pointers allocated by this allocator. | +| `propagate_on_container_copy_assignment` | `detail::true_type` | Specifies that the allocator shall be propagated on container copy assignment. | +| `propagate_on_container_move_assignment` | `detail::true_type` | Specifies that the allocator shall be propagated on container move assignment. | +| `propagate_on_container_swap` | `detail::true_type` | Specifies that the allocator shall be propagated on container swap. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `mem_res` | `Upstream *` | | + +--- + +## Inner classes + +### rebind + + +```cpp showLineNumbers={false} +struct thrust::per_device_allocator::rebind +``` + + +The `rebind` metafunction provides the type of an `per_device_allocator` instantiated with another type. diff --git a/fern/cudapages/thrust/thrust/thrust/permutation_iterator.mdx b/fern/cudapages/thrust/thrust/thrust/permutation_iterator.mdx new file mode 100644 index 0000000..e770885 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/permutation_iterator.mdx @@ -0,0 +1,190 @@ +--- +title: thrust::permutation_iterator +description: "[`permutation_iterator`](/library/api/thrust::permutation_iterator) is an iterator which represents a pointer into a reordered view of a given range." +--- + +`permutation_iterator` is an iterator which represents a pointer into a reordered view of a given range. + +`permutation_iterator` is an imprecise name; the reordered view need not be a strict permutation. This iterator is useful for fusing a scatter or gather operation with other algorithms. + +This iterator takes two arguments: + +- an iterator to the range `V` on which the "permutation" will be applied +- the reindexing scheme that defines how the elements of `V` will be permuted. + +Note that `permutation_iterator` is not limited to strict permutations of the given range `V`. The distance between begin and end of the reindexing iterators is allowed to be smaller compared to the size of the range `V`, in which case the `permutation_iterator` only provides a "permutation" of a subrange of `V`. The indices neither need to be unique. In this same context, it must be noted that the past-the-end `permutation_iterator` is completely defined by means of the past-the-end iterator to the indices. + +The following code snippet demonstrates how to create a `permutation_iterator` which represents a reordering of the contents of a [`device_vector`](/library/api/thrust::device_vector). + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +make_permutation_iterator + +## Example + +```cpp showLineNumbers={false} +#include +#include +... +thrust::device_vector values{10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f, 70.0f, 80.0f}; +thrust::device_vector indices{2, 6, 1, 3}; + +using ElementIterator = thrust::device_vector::iterator; +using IndexIterator = thrust::device_vector::iterator ; + +thrust::permutation_iterator iter(values.begin(), indices.begin()); + +*iter; // returns 30.0f; +iter[0]; // returns 30.0f; +iter[1]; // returns 70.0f; +iter[2]; // returns 20.0f; +iter[3]; // returns 40.0f; + +// iter[4] is an out-of-bounds error + +*iter = -1.0f; // sets values[2] to -1.0f; +iter[0] = -1.0f; // sets values[2] to -1.0f; +iter[1] = -1.0f; // sets values[6] to -1.0f; +iter[2] = -1.0f; // sets values[1] to -1.0f; +iter[3] = -1.0f; // sets values[3] to -1.0f; + +// values is now {10, -1, -1, -1, 50, 60, -1, 80} +``` + + + + + + + + + + + + + +**Inherits from:** `thrust::iterator_adaptor< permutation_iterator< ElementIterator, IndexIterator >, IndexIterator, it_value_t< ElementIterator >, minimum_system_t< System1, System2 >, use_default, it_reference_t< ElementIterator > >` (public) + +--- + +## Constructors + +### permutation_iterator + + + + +Null constructor calls the null constructor of this `permutation_iterator``'s` element iterator. + + +```cpp showLineNumbers={false} +thrust::permutation_iterator::permutation_iterator() = default +``` + + + + + +inline explicit + +Constructor accepts an `ElementIterator` into a range of values and an `IndexIterator` into a range of indices defining the indexing scheme on the values. + + +```cpp showLineNumbers={false} +thrust::permutation_iterator::permutation_iterator( + ElementIterator x, + IndexIterator y +) +``` + + +**Parameters** + + +An `ElementIterator` pointing this `permutation_iterator``'s` range of values. + + + +An `IndexIterator` pointing to an indexing scheme to use on `x`. + + + + + +inline + +Copy constructor accepts a related `permutation_iterator`. + + +```cpp showLineNumbers={false} +template +thrust::permutation_iterator::permutation_iterator( + permutation_iterator const &rhs +) +``` + + +**Parameters** + + +A compatible `permutation_iterator` to copy from. + + + + + +--- + +## Methods + +### base inline const + + +```cpp showLineNumbers={false} +IndexIterator const & thrust::iterator_adaptor, IndexIterator, it_value_t, minimum_system_t, use_default, it_reference_t, use_default>::base() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + +### base_reference inline + + + + + +```cpp showLineNumbers={false} +IndexIterator & thrust::iterator_adaptor, IndexIterator, it_value_t, minimum_system_t, use_default, it_reference_t, use_default>::base_reference() +``` + + +**Returns:** A mutable reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +const + + +```cpp showLineNumbers={false} +IndexIterator const & thrust::iterator_adaptor, IndexIterator, it_value_t, minimum_system_t, use_default, it_reference_t, use_default>::base_reference() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `base_type` | `IndexIterator` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)`'s` `adapts`. | diff --git a/fern/cudapages/thrust/thrust/thrust/pointer.mdx b/fern/cudapages/thrust/thrust/thrust/pointer.mdx new file mode 100644 index 0000000..3a49e50 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/pointer.mdx @@ -0,0 +1,289 @@ +--- +title: thrust::pointer +description: "`pointer` stores a pointer to an object allocated in memory." +--- + +`pointer` stores a pointer to an object allocated in memory. + +Like [`device_ptr`](/library/api/thrust::device_ptr), this type ensures type safety when dispatching standard algorithms on ranges resident in memory. + +`pointer` generalizes [`device_ptr`](/library/api/thrust::device_ptr) by relaxing the backend system associated with the `pointer`. Instead of the backend system specified by `THRUST_DEVICE_SYSTEM`, `pointer's` system is given by its second template parameter, `Tag`. For the purpose of Thrust dispatch, [`device_ptr`](/library/api/thrust::device_ptr) and [`pointer`](/library/api/thrust::pointer::pointer%3CElement,device_system_tag%3E) are considered equivalent. + +The raw pointer encapsulated by a `pointer` may be obtained through its [`get`](/library/api/thrust::pointer::get) member function or the `raw_pointer_cast` free function. + +```cpp showLineNumbers={false} +#include +``` + + +`pointer` is not a smart pointer; it is the client's responsibility to deallocate memory pointer to by `pointer`. + + +**See also:** +[device_ptr](/library/api/thrust::device_ptr), +reference, +raw_pointer_cast + + + + + +Specifies the type of the pointed-to object. + + + +Specifies the system with which this `pointer` is associated. This may be any Thrust backend system, or a user-defined tag. + + + +Allows the client to specify the reference type returned upon derereference. By default, this type is `reference`. + + + +Allows the client to specify the name of the derived type when `pointer` is used as a base class. This is useful to ensure that arithmetic on values of the derived type return values of the derived type as a result. By default, this type is [`pointer`](/library/api/thrust::pointer::pointer%3CElement,Tag,Reference%3E). + + + + + +**Inherits from:** `thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >` (public) + +--- + +## Constructors + +### pointer inline + + + + +`pointer's` default constructor initializes its encapsulated pointer to `0` + + +```cpp showLineNumbers={false} +thrust::pointer::pointer() +``` + + + + + + +```cpp showLineNumbers={false} +thrust::pointer::pointer( + ::cuda::std::nullptr_t +) +``` + + + + + +explicit + +This constructor allows construction of a [`pointer`](/library/api/thrust::pointer::pointer%3Cconst T, ...%3E) from a `T*`. + + +```cpp showLineNumbers={false} +template +thrust::pointer::pointer( + OtherElement *ptr +) +``` + + +**Template parameters** + + +`OtherElement` shall be convertible to `Element`. + + +**Parameters** + + +A raw pointer to copy from, presumed to point to a location in `Tag's` memory. + + + + + +This constructor allows initialization from another pointer-like object. + + +```cpp showLineNumbers={false} +template +thrust::pointer::pointer( + const OtherPointer &other +) +``` + + +**Template parameters** + + +The tag associated with `OtherPointer` shall be convertible to `Tag`, and its element type shall be convertible to `Element`. + + +**Parameters** + + +The `OtherPointer` to copy. + + + + + +--- + +## Assignment operators + +### operator= inline + + + + + +```cpp showLineNumbers={false} +derived_type & thrust::pointer::operator=( + ::cuda::std::nullptr_t +) +``` + + + + + +Assignment operator allows assigning from another pointer-like object whose element type is convertible to `Element`. + + +```cpp showLineNumbers={false} +template +detail::enable_if_pointer_is_convertible_t thrust::pointer::operator=( + const OtherPointer &other +) +``` + + +**Returns:** `*this` + +**Template parameters** + + +The tag associated with `OtherPointer` shall be convertible to `Tag`, and its element type shall be convertible to `Element`. + + +**Parameters** + + +The other pointer-like object to assign from. + + + + + +--- + +## Methods + +### dereference inline const + + +```cpp showLineNumbers={false} +template +SuperRef thrust::pointer::dereference() const +``` + + +### get inline const + +`get` returns this `pointer's` encapsulated raw pointer. + + +```cpp showLineNumbers={false} +Element * thrust::pointer::get() const +``` + + +**Returns:** This `pointer's` raw pointer. + +### operator-> inline const + + +```cpp showLineNumbers={false} +Element * thrust::pointer::operator->() const +``` + + +### operator bool inline explicit const + + +```cpp showLineNumbers={false} +thrust::pointer::operator bool() const +``` + + +### base inline const + + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor::base() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + +### base_reference inline + + + + + +```cpp showLineNumbers={false} +Base & thrust::iterator_adaptor::base_reference() +``` + + +**Returns:** A mutable reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +const + + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor::base_reference() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +--- + +## Static methods + +### pointer_to inline static + + +```cpp showLineNumbers={false} +static derived_type thrust::pointer::pointer_to( + typename detail::pointer_traits_detail::pointer_to_param::type r +) +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `super_t` | `typename detail::pointer_base< Element, Tag, Reference, Derived >::type` | | +| `derived_type` | `typename detail::pointer_base< Element, Tag, Reference, Derived >::derived_type` | | +| `raw_pointer` | `typename super_t::base_type` | The type of the raw pointer. | +| `base_type` | `Base` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)`'s` `adapts`. | diff --git a/fern/cudapages/thrust/thrust/thrust/proclaim_contiguous_iterator.mdx b/fern/cudapages/thrust/thrust/thrust/proclaim_contiguous_iterator.mdx new file mode 100644 index 0000000..537e7c2 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/proclaim_contiguous_iterator.mdx @@ -0,0 +1,25 @@ +--- +title: thrust::proclaim_contiguous_iterator +description: "Customization point that can be customized to indicate that an iterator type `Iterator` satisfies [ContiguousIterator](https://en.cppreference.com/w/cpp/named_req/ContiguousIterator), aka it points to elements that are contiguous in memory." +--- + +Customization point that can be customized to indicate that an iterator type `Iterator` satisfies [ContiguousIterator](https://en.cppreference.com/w/cpp/named_req/ContiguousIterator), aka it points to elements that are contiguous in memory. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +is_contiguous_iterator, +THRUST_PROCLAIM_CONTIGUOUS_ITERATOR + + + + + + + + + + +**Inherits from:** `cuda::std::false_type` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/project1st.mdx b/fern/cudapages/thrust/thrust/thrust/project1st.mdx new file mode 100644 index 0000000..b2ba183 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/project1st.mdx @@ -0,0 +1,57 @@ +--- +title: thrust::project1st +description: "[`project1st`](/library/api/thrust::project1st) is a function object that takes two arguments and returns its first argument; the second argument is unused." +--- + +`project1st` is a function object that takes two arguments and returns its first argument; the second argument is unused. + +It is essentially a generalization of identity to the case of a Binary Function. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +identity, +[project2nd](/library/api/thrust::project2nd) + +## Example + +```cpp showLineNumbers={false} +#include +#include +... +int x = 137; +int y = -137; +thrust::project1st pj1; +assert(x == pj1(x,y)); +``` + + + + + + + + + + + + + +--- + +## Methods + +### operator() inline constexpr const + +Function call operator. + + +```cpp showLineNumbers={false} +const T1 & thrust::project1st::operator()( + const T1 &lhs, + const T2 & +) const +``` + diff --git a/fern/cudapages/thrust/thrust/thrust/project1st_void_void.mdx b/fern/cudapages/thrust/thrust/thrust/project1st_void_void.mdx new file mode 100644 index 0000000..8664f82 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/project1st_void_void.mdx @@ -0,0 +1,33 @@ +--- +title: "thrust::project1st< void, void >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + +--- + +## Methods + +### operator() inline constexpr const noexcept + +": "/library/api/thrust::project1st%3C void, void %3E"}}> +```cpp showLineNumbers={false} +template +decltype( + t1 +) const noexcept(t1) +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `is_transparent` | `void` | diff --git a/fern/cudapages/thrust/thrust/thrust/project2nd.mdx b/fern/cudapages/thrust/thrust/thrust/project2nd.mdx new file mode 100644 index 0000000..af9b9c6 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/project2nd.mdx @@ -0,0 +1,57 @@ +--- +title: thrust::project2nd +description: "[`project2nd`](/library/api/thrust::project2nd) is a function object that takes two arguments and returns its second argument; the first argument is unused." +--- + +`project2nd` is a function object that takes two arguments and returns its second argument; the first argument is unused. + +It is essentially a generalization of identity to the case of a Binary Function. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +identity, +[project1st](/library/api/thrust::project1st) + +## Example + +```cpp showLineNumbers={false} +#include +#include +... +int x = 137; +int y = -137; +thrust::project2nd pj2; +assert(y == pj2(x,y)); +``` + + + + + + + + + + + + + +--- + +## Methods + +### operator() inline constexpr const + +Function call operator. + + +```cpp showLineNumbers={false} +const T2 & thrust::project2nd::operator()( + const T1 &, + const T2 &rhs +) const +``` + diff --git a/fern/cudapages/thrust/thrust/thrust/project2nd_void_void.mdx b/fern/cudapages/thrust/thrust/thrust/project2nd_void_void.mdx new file mode 100644 index 0000000..3b66f6a --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/project2nd_void_void.mdx @@ -0,0 +1,33 @@ +--- +title: "thrust::project2nd< void, void >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + +--- + +## Methods + +### operator() inline constexpr const noexcept + +": "/library/api/thrust::project2nd%3C void, void %3E"}}> +```cpp showLineNumbers={false} +template +decltype( + t2 +) const noexcept(t2) +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `is_transparent` | `void` | diff --git a/fern/cudapages/thrust/thrust/thrust/random/discard_block_engine.mdx b/fern/cudapages/thrust/thrust/thrust/random/discard_block_engine.mdx new file mode 100644 index 0000000..2f86d3c --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/random/discard_block_engine.mdx @@ -0,0 +1,223 @@ +--- +title: thrust::random::discard_block_engine +description: "A [`discard_block_engine`](/library/api/thrust::random::discard_block_engine) adapts an existing base random number engine and produces random values by discarding some of the values returned by its base engine." +--- + +A `discard_block_engine` adapts an existing base random number engine and produces random values by discarding some of the values returned by its base engine. + +Each cycle of the compound engine begins by returning `r` values successively produced by the base engine and ends by discarding `p-r` such values. The engine's state is the state of its base engine followed by the number of calls to `operator()` that have occurred since the beginning of the current cycle. + +The following code snippet shows an example of using a `discard_block_engine` instance: + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include +#include +#include + +int main() +{ + // create a discard_block_engine from minstd_rand, with a cycle length of 13 + // keep every first 10 values, and discard the next 3 + thrust::discard_block_engine rng; + + // print a random number to standard output + std::cout << rng() << std::endl; + + return 0; +} +``` + + + + + +The type of the base random number engine to adapt. + + + +The discard cycle length. + + + +The number of values to return of the base engine. Because `p-r` will be discarded, `r <= p`. + + + + + +--- + +## Constructors + +### discard_block_engine + + + + +This constructor constructs a new `discard_block_engine` and constructs its [`base_type`](/library/api/thrust::random::discard_block_engine::base_type) engine using its null constructor. + + +```cpp showLineNumbers={false} +thrust::random::discard_block_engine::discard_block_engine() +``` + + + + + +explicit + +This constructor constructs a new `discard_block_engine` using a given [`base_type`](/library/api/thrust::random::discard_block_engine::base_type) engine to initialize its adapted base engine. + + +```cpp showLineNumbers={false} +thrust::random::discard_block_engine::discard_block_engine( + const base_type &urng +) +``` + + +**Parameters** + + +A [`base_type`](/library/api/thrust::random::discard_block_engine::base_type) to use to initialize this `discard_block_engine``'s` adapted base engine. + + + + + +explicit + +This constructor initializes a new `discard_block_engine` with a given seed. + + +```cpp showLineNumbers={false} +thrust::random::discard_block_engine::discard_block_engine( + result_type s +) +``` + + +**Parameters** + + +The seed used to initialize this `discard_block_engine``'s` adapted base engine. + + + + + +--- + +## Methods + +### seed + + + + +This method initializes the state of this `discard_block_engine``'s` adapted base engine by using its `default_seed` value. + + +```cpp showLineNumbers={false} +void thrust::random::discard_block_engine::seed() +``` + + + + + +This method initializes the state of this `discard_block_engine``'s` adapted base engine by using the given seed. + + +```cpp showLineNumbers={false} +void thrust::random::discard_block_engine::seed( + result_type s +) +``` + + +**Parameters** + + +The seed with which to initialize this `discard_block_engine``'s` adapted base engine. + + + + + +### operator() + +This member function produces a new random value and updates this `discard_block_engine``'s` state. + + +```cpp showLineNumbers={false} +result_type thrust::random::discard_block_engine::operator()( + void +) +``` + + +**Returns:** A new random number. + +### discard + +This member function advances this `discard_block_engine``'s` state a given number of times and discards the results. + + +```cpp showLineNumbers={false} +void thrust::random::discard_block_engine::discard( + unsigned long long z +) +``` + + + +This function is provided because an implementation may be able to accelerate it. + + +**Parameters** + + +The number of random values to discard. + + +### base const + +This member function returns a const reference to this `discard_block_engine``'s` adapted base engine. + + +```cpp showLineNumbers={false} +const base_type & thrust::random::discard_block_engine::base() const +``` + + +**Returns:** A const reference to the base engine this `discard_block_engine` adapts. + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `base_type` | `Engine` | The type of the adapted base random number engine. | +| `result_type` | `typename base_type::result_type` | The type of the unsigned integer produced by this [`linear_congruential_engine`](/library/api/thrust::random::linear_congruential_engine). | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `block_size` static | `const size_t` | The length of the production cycle. | +| `used_block` static | `const size_t` | The number of used numbers per production cycle. | +| `min` static | `const result_type` | The smallest value this `discard_block_engine` may potentially produce. | +| `max` static | `const result_type` | The largest value this `discard_block_engine` may potentially produce. | diff --git a/fern/cudapages/thrust/thrust/thrust/random/linear_congruential_engine.mdx b/fern/cudapages/thrust/thrust/thrust/random/linear_congruential_engine.mdx new file mode 100644 index 0000000..6f34b6d --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/random/linear_congruential_engine.mdx @@ -0,0 +1,200 @@ +--- +title: thrust::random::linear_congruential_engine +description: "A [`linear_congruential_engine`](/library/api/thrust::random::linear_congruential_engine) random number engine produces unsigned integer random numbers using a linear congruential random number generation algorithm." +--- + +A `linear_congruential_engine` random number engine produces unsigned integer random numbers using a linear congruential random number generation algorithm. + +The generation algorithm has the form `x_i = (a * x_{i-1} + c) mod m`. + +The following code snippet shows examples of use of a `linear_congruential_engine` instance: + +```cpp showLineNumbers={false} +#include +``` + + +Inexperienced users should not use this class template directly. Instead, use `minstd_rand` or `minstd_rand0`. + + +**See also:** +thrust::random::minstd_rand, +thrust::random::minstd_rand0 + +## Example + +```cpp showLineNumbers={false} +#include +#include + +int main() +{ + // create a minstd_rand object, which is an instance of linear_congruential_engine + thrust::minstd_rand rng1; + + // output some random values to cout + std::cout << rng1() << std::endl; + + // a random value is printed + + // create a new minstd_rand from a seed + thrust::minstd_rand rng2(13); + + // discard some random values + rng2.discard(13); + + // stream the object to an iostream + std::cout << rng2 << std::endl; + + // rng2's current state is printed + + // print the minimum and maximum values that minstd_rand can produce + std::cout << thrust::minstd_rand::min << std::endl; + std::cout << thrust::minstd_rand::max << std::endl; + + // the range of minstd_rand is printed + + // save the state of rng2 to a different object + thrust::minstd_rand rng3 = rng2; + + // compare rng2 and rng3 + std::cout << (rng2 == rng3) << std::endl; + + // 1 is printed + + // re-seed rng2 with a different seed + rng2.seed(7); + + // compare rng2 and rng3 + std::cout << (rng2 == rng3) << std::endl; + + // 0 is printed + + return 0; +} +``` + + + + + +The type of unsigned integer to produce. + + + +The multiplier used in the generation algorithm. + + + +The increment used in the generation algorithm. + + + +The modulus used in the generation algorithm. + + + + + +--- + +## Constructors + +### linear_congruential_engine explicit + +This constructor, which optionally accepts a seed, initializes a new `linear_congruential_engine`. + + +```cpp showLineNumbers={false} +thrust::random::linear_congruential_engine::linear_congruential_engine( + result_type s = default_seed +) +``` + + +**Parameters** + + +The seed used to initialize this `linear_congruential_engine``'s` state. + + +--- + +## Methods + +### seed + +This method initializes this `linear_congruential_engine``'s` state, and optionally accepts a seed value. + + +```cpp showLineNumbers={false} +void thrust::random::linear_congruential_engine::seed( + result_type s = default_seed +) +``` + + +**Parameters** + + +The seed used to initializes this `linear_congruential_engine``'s` state. + + +### operator() + +This member function produces a new random value and updates this `linear_congruential_engine``'s` state. + + +```cpp showLineNumbers={false} +result_type thrust::random::linear_congruential_engine::operator()( + void +) +``` + + +**Returns:** A new random number. + +### discard + +This member function advances this `linear_congruential_engine``'s` state a given number of times and discards the results. + + +```cpp showLineNumbers={false} +void thrust::random::linear_congruential_engine::discard( + unsigned long long z +) +``` + + + +This function is provided because an implementation may be able to accelerate it. + + +**Parameters** + + +The number of random values to discard. + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `result_type` | `UIntType` | The type of the unsigned integer produced by this `linear_congruential_engine`. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `multiplier` static | `const result_type` | The multiplier used in the generation algorithm. | +| `increment` static | `const result_type` | The increment used in the generation algorithm. | +| `modulus` static | `const result_type` | The modulus used in the generation algorithm. | +| `min` static | `const result_type` | The smallest value this `linear_congruential_engine` may potentially produce. | +| `max` static | `const result_type` | The largest value this `linear_congruential_engine` may potentially produce. | +| `default_seed` static | `const result_type` | The default seed of this `linear_congruential_engine`. | diff --git a/fern/cudapages/thrust/thrust/thrust/random/linear_feedback_shift_engine.mdx b/fern/cudapages/thrust/thrust/thrust/random/linear_feedback_shift_engine.mdx new file mode 100644 index 0000000..2b0db8a --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/random/linear_feedback_shift_engine.mdx @@ -0,0 +1,144 @@ +--- +title: thrust::random::linear_feedback_shift_engine +description: "A [`linear_feedback_shift_engine`](/library/api/thrust::random::linear_feedback_shift_engine) random number engine produces unsigned integer random values using a linear feedback shift random number generation algorithm." +--- + +A `linear_feedback_shift_engine` random number engine produces unsigned integer random values using a linear feedback shift random number generation algorithm. + +```cpp showLineNumbers={false} +#include +``` + + +`linear_feedback_shift_engine` is based on the Boost Template Library's linear_feedback_shift. + + + + + + +The type of unsigned integer to produce. + + + +The word size of the produced values (`w <= sizeof(UIntType)`). + + + +The k parameter of Tausworthe's 1965 algorithm. + + + +The q exponent of Tausworthe's 1965 algorithm. + + + +The step size of Tausworthe's 1965 algorithm. + + + + + +--- + +## Constructors + +### linear_feedback_shift_engine explicit + +This constructor, which optionally accepts a seed, initializes a new `linear_feedback_shift_engine`. + + +```cpp showLineNumbers={false} +thrust::random::linear_feedback_shift_engine::linear_feedback_shift_engine( + result_type value = default_seed +) +``` + + +**Parameters** + + +The seed used to initialize this `linear_feedback_shift_engine``'s` state. + + +--- + +## Methods + +### seed + +This method initializes this `linear_feedback_shift_engine``'s` state, and optionally accepts a seed value. + + +```cpp showLineNumbers={false} +void thrust::random::linear_feedback_shift_engine::seed( + result_type value = default_seed +) +``` + + +**Parameters** + + +The seed used to initializes this `linear_feedback_shift_engine``'s` state. + + +### operator() + +This member function produces a new random value and updates this `linear_feedback_shift_engine``'s` state. + + +```cpp showLineNumbers={false} +result_type thrust::random::linear_feedback_shift_engine::operator()( + void +) +``` + + +**Returns:** A new random number. + +### discard + +This member function advances this `linear_feedback_shift_engine``'s` state a given number of times and discards the results. + + +```cpp showLineNumbers={false} +void thrust::random::linear_feedback_shift_engine::discard( + unsigned long long z +) +``` + + + +This function is provided because an implementation may be able to accelerate it. + + +**Parameters** + + +The number of random values to discard. + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `result_type` | `UIntType` | The type of the unsigned integer produced by this `linear_feedback_shift_engine`. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `word_size` static | `const size_t` | The word size of the produced values. | +| `exponent1` static | `const size_t` | A constant used in the generation algorithm. | +| `exponent2` static | `const size_t` | A constant used in the generation algorithm. | +| `step_size` static | `const size_t` | The step size used in the generation algorithm. | +| `min` static | `const result_type` | The smallest value this `linear_feedback_shift_engine` may potentially produce. | +| `max` static | `const result_type` | The largest value this `linear_feedback_shift_engine` may potentially produce. | +| `default_seed` static | `const result_type` | The default seed of this `linear_feedback_shift_engine`. | diff --git a/fern/cudapages/thrust/thrust/thrust/random/normal_distribution.mdx b/fern/cudapages/thrust/thrust/thrust/random/normal_distribution.mdx new file mode 100644 index 0000000..2891dbd --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/random/normal_distribution.mdx @@ -0,0 +1,271 @@ +--- +title: thrust::random::normal_distribution +description: "A [`normal_distribution`](/library/api/thrust::random::normal_distribution) random number distribution produces floating point Normally distributed random numbers." +--- + +A `normal_distribution` random number distribution produces floating point Normally distributed random numbers. + +The following code snippet demonstrates examples of using a `normal_distribution` with a random number engine to produce random values drawn from the Normal distribution with a given mean and variance: + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include +#include + +int main() +{ + // create a minstd_rand object to act as our source of randomness + thrust::minstd_rand rng; + + // create a normal_distribution to produce floats from the Normal distribution + // with mean 2.0 and standard deviation 3.5 + thrust::random::normal_distribution dist(2.0f, 3.5f); + + // write a random number to standard output + std::cout << dist(rng) << std::endl; + + // write the mean of the distribution, just in case we forgot + std::cout << dist.mean() << std::endl; + + // 2.0 is printed + + // and the standard deviation + std::cout << dist.stddev() << std::endl; + + // 3.5 is printed + + return 0; +} +``` + + + + + +The type of floating point number to produce. + + + + + +**Inherits from:** `detail::normal_distribution_base::type` (public) + +--- + +## Constructors + +### normal_distribution explicit + + + + +This constructor creates a new `normal_distribution` from two values defining the half-open interval of the distribution. + + +```cpp showLineNumbers={false} +thrust::random::normal_distribution::normal_distribution( + RealType mean = 0.0, + RealType stddev = 1.0 +) +``` + + +**Parameters** + + +The mean (expected value) of the distribution. Defaults to `0.0`. + + + +The standard deviation of the distribution. Defaults to `1.0`. + + + + + +This constructor creates a new `normal_distribution` from a [`param_type`](/library/api/thrust::random::normal_distribution::param_type) object encapsulating the range of the distribution. + + +```cpp showLineNumbers={false} +thrust::random::normal_distribution::normal_distribution( + const param_type &parm +) +``` + + +**Parameters** + + +A [`param_type`](/library/api/thrust::random::normal_distribution::param_type) object encapsulating the parameters (i.e., the mean and standard deviation) of the distribution. + + + + + +--- + +## Methods + +### reset + +Calling this member function guarantees that subsequent uses of this `normal_distribution` do not depend on values produced by any random number generator prior to invoking this function. + + +```cpp showLineNumbers={false} +void thrust::random::normal_distribution::reset() +``` + + +### operator() + + + + +This method produces a new Normal random integer drawn from this `normal_distribution``'s` range using a `UniformRandomNumberGenerator` as a source of randomness. + + +```cpp showLineNumbers={false} +template +result_type thrust::random::normal_distribution::operator()( + UniformRandomNumberGenerator &urng +) +``` + + +**Parameters** + + +The `UniformRandomNumberGenerator` to use as a source of randomness. + + + + + +This method produces a new Normal random integer as if by creating a new `normal_distribution` from the given [`param_type`](/library/api/thrust::random::normal_distribution::param_type) object, and calling its `operator()` method with the given `UniformRandomNumberGenerator` as a source of randomness. + + +```cpp showLineNumbers={false} +template +result_type thrust::random::normal_distribution::operator()( + UniformRandomNumberGenerator &urng, + const param_type &parm +) +``` + + +**Parameters** + + +The `UniformRandomNumberGenerator` to use as a source of randomness. + + + +A [`param_type`](/library/api/thrust::random::normal_distribution::param_type) object encapsulating the parameters of the `normal_distribution` to draw from. + + + + + +### mean const + +This method returns the value of the parameter with which this `normal_distribution` was constructed. + + +```cpp showLineNumbers={false} +result_type thrust::random::normal_distribution::mean() const +``` + + +**Returns:** The mean (expected value) of this `normal_distribution``'s` output. + +### stddev const + +This method returns the value of the parameter with which this `normal_distribution` was constructed. + + +```cpp showLineNumbers={false} +result_type thrust::random::normal_distribution::stddev() const +``` + + +**Returns:** The standard deviation of this [`uniform_real_distribution`](/library/api/thrust::random::uniform_real_distribution)`'s` output. + +### param + + + + +This method changes the parameters of this `normal_distribution` using the values encapsulated in a given [`param_type`](/library/api/thrust::random::normal_distribution::param_type) object. + + +```cpp showLineNumbers={false} +void thrust::random::normal_distribution::param( + const param_type &parm +) +``` + + +**Parameters** + + +A [`param_type`](/library/api/thrust::random::normal_distribution::param_type) object encapsulating the new parameters (i.e., the mean and variance) of this `normal_distribution`. + + + + + +const + +This method returns a [`param_type`](/library/api/thrust::random::normal_distribution::param_type) object encapsulating the parameters with which this `normal_distribution` was constructed. + + +```cpp showLineNumbers={false} +param_type thrust::random::normal_distribution::param() const +``` + + +**Returns:** A [`param_type`](/library/api/thrust::random::normal_distribution::param_type) object encapsulating the parameters (i.e., the mean and standard deviation) of this `normal_distribution`. + + + + +### min const + +This method returns the smallest floating point number this `normal_distribution` can potentially produce. + + +```cpp showLineNumbers={false} +result_type thrust::random::normal_distribution::min() const +``` + + +**Returns:** The lower bound of this `normal_distribution``'s` half-open interval. + +### max const + +This method returns the smallest number larger than largest floating point number this [`uniform_real_distribution`](/library/api/thrust::random::uniform_real_distribution) can potentially produce. + + +```cpp showLineNumbers={false} +result_type thrust::random::normal_distribution::max() const +``` + + +**Returns:** The upper bound of this `normal_distribution``'s` half-open interval. + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `super_t` | `typename detail::normal_distribution_base< RealType >::type` | | +| `result_type` | `RealType` | The type of the floating point number produced by this `normal_distribution`. | +| `param_type` | `::cuda::std::pair< RealType, RealType >` | The type of the object encapsulating this `normal_distribution``'s` parameters. | diff --git a/fern/cudapages/thrust/thrust/thrust/random/subtract_with_carry_engine.mdx b/fern/cudapages/thrust/thrust/thrust/random/subtract_with_carry_engine.mdx new file mode 100644 index 0000000..ca4e15f --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/random/subtract_with_carry_engine.mdx @@ -0,0 +1,152 @@ +--- +title: thrust::random::subtract_with_carry_engine +description: "A [`subtract_with_carry_engine`](/library/api/thrust::random::subtract_with_carry_engine) random number engine produces unsigned integer random numbers using the subtract with carry algorithm of Marsaglia & Zaman." +--- + +A `subtract_with_carry_engine` random number engine produces unsigned integer random numbers using the subtract with carry algorithm of Marsaglia & Zaman. + +The generation algorithm is performed as follows: + +1. Let `Y = X_{i-s}- X_{i-r} - c`. +2. Set `X_i` to `y = T mod m`. Set `c` to `1` if `Y < 0`, otherwise set `c` to `0`. + +This algorithm corresponds to a modular linear function of the form + +`TA(x_i) = (a * x_i) mod b`, where `b` is of the form `m^r - m^s + 1` and `a = b - (b-1)/m`. + +```cpp showLineNumbers={false} +#include +``` + + +Inexperienced users should not use this class template directly. Instead, use `ranlux24_base` or `ranlux48_base`, which are instances of `subtract_with_carry_engine`. + + +**See also:** +thrust::random::ranlux24_base, +thrust::random::ranlux48_base + + + + + +The type of unsigned integer to produce. + + + +The word size of the produced values (` w <= sizeof(UIntType)`). + + + +The short lag of the generation algorithm. + + + +The long lag of the generation algorithm. + + + + + +--- + +## Constructors + +### subtract_with_carry_engine explicit + +This constructor, which optionally accepts a seed, initializes a new `subtract_with_carry_engine`. + + +```cpp showLineNumbers={false} +thrust::random::subtract_with_carry_engine::subtract_with_carry_engine( + result_type value = default_seed +) +``` + + +**Parameters** + + +The seed used to initialize this `subtract_with_carry_engine``'s` state. + + +--- + +## Methods + +### seed + +This method initializes this `subtract_with_carry_engine``'s` state, and optionally accepts a seed value. + + +```cpp showLineNumbers={false} +void thrust::random::subtract_with_carry_engine::seed( + result_type value = default_seed +) +``` + + +**Parameters** + + +The seed used to initializes this `subtract_with_carry_engine``'s` state. + + +### operator() + +This member function produces a new random value and updates this `subtract_with_carry_engine``'s` state. + + +```cpp showLineNumbers={false} +result_type thrust::random::subtract_with_carry_engine::operator()( + void +) +``` + + +**Returns:** A new random number. + +### discard + +This member function advances this `subtract_with_carry_engine``'s` state a given number of times and discards the results. + + +```cpp showLineNumbers={false} +void thrust::random::subtract_with_carry_engine::discard( + unsigned long long z +) +``` + + + +This function is provided because an implementation may be able to accelerate it. + + +**Parameters** + + +The number of random values to discard. + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `result_type` | `UIntType` | The type of the unsigned integer produced by this `subtract_with_carry_engine`. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `word_size` static | `const size_t` | The word size of the produced values. | +| `short_lag` static | `const size_t` | The size of the short lag used in the generation algorithm. | +| `long_lag` static | `const size_t` | The size of the long lag used in the generation algorithm. | +| `min` static | `const result_type` | The smallest value this `subtract_with_carry_engine` may potentially produce. | +| `max` static | `const result_type` | The largest value this `subtract_with_carry_engine` may potentially produce. | +| `default_seed` static | `const result_type` | The default seed of this `subtract_with_carry_engine`. | diff --git a/fern/cudapages/thrust/thrust/thrust/random/uniform_int_distribution.mdx b/fern/cudapages/thrust/thrust/thrust/random/uniform_int_distribution.mdx new file mode 100644 index 0000000..be6f490 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/random/uniform_int_distribution.mdx @@ -0,0 +1,277 @@ +--- +title: thrust::random::uniform_int_distribution +description: "A [`uniform_int_distribution`](/library/api/thrust::random::uniform_int_distribution) random number distribution produces signed or unsigned integer uniform random numbers from a given range." +--- + +A `uniform_int_distribution` random number distribution produces signed or unsigned integer uniform random numbers from a given range. + +The following code snippet demonstrates examples of using a `uniform_int_distribution` with a random number engine to produce random integers drawn from a given range: + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include +#include + +int main() +{ + // create a minstd_rand object to act as our source of randomness + thrust::minstd_rand rng; + + // create a uniform_int_distribution to produce ints from [-7,13] + thrust::uniform_int_distribution dist(-7,13); + + // write a random number from the range [-7,13] to standard output + std::cout << dist(rng) << std::endl; + + // write the range of the distribution, just in case we forgot + std::cout << dist.min() << std::endl; + + // -7 is printed + + std::cout << dist.max() << std::endl; + + // 13 is printed + + // write the parameters of the distribution (which happen to be the bounds) to standard output + std::cout << dist.a() << std::endl; + + // -7 is printed + + std::cout << dist.b() << std::endl; + + // 13 is printed + + return 0; +} +``` + + + + + +The type of integer to produce. + + + + + +--- + +## Constructors + +### uniform_int_distribution explicit + + + + +This constructor creates a new `uniform_int_distribution` from two values defining the range of the distribution. + + +```cpp showLineNumbers={false} +thrust::random::uniform_int_distribution::uniform_int_distribution( + IntType a = 0, + IntType b = ::cuda::std::numeric_limits::max() +) +``` + + +**Parameters** + + +The smallest integer to potentially produce. Defaults to `0`. + + + +The largest integer to potentially produce. Defaults to the largest representable integer in the platform. + + + + + +This constructor creates a new `uniform_int_distribution` from a [`param_type`](/library/api/thrust::random::uniform_int_distribution::param_type) object encapsulating the range of the distribution. + + +```cpp showLineNumbers={false} +thrust::random::uniform_int_distribution::uniform_int_distribution( + const param_type &parm +) +``` + + +**Parameters** + + +A [`param_type`](/library/api/thrust::random::uniform_int_distribution::param_type) object encapsulating the parameters (i.e., the range) of the distribution. + + + + + +--- + +## Methods + +### reset + +This does nothing. + +It is included to conform to the requirements of the RandomDistribution concept. + + +```cpp showLineNumbers={false} +void thrust::random::uniform_int_distribution::reset() +``` + + +### operator() + + + + +This method produces a new uniform random integer drawn from this `uniform_int_distribution``'s` range using a `UniformRandomNumberGenerator` as a source of randomness. + + +```cpp showLineNumbers={false} +template +result_type thrust::random::uniform_int_distribution::operator()( + UniformRandomNumberGenerator &urng +) +``` + + +**Parameters** + + +The `UniformRandomNumberGenerator` to use as a source of randomness. + + + + + +This method produces a new uniform random integer as if by creating a new `uniform_int_distribution` from the given [`param_type`](/library/api/thrust::random::uniform_int_distribution::param_type) object, and calling its `operator()` method with the given `UniformRandomNumberGenerator` as a source of randomness. + + +```cpp showLineNumbers={false} +template +result_type thrust::random::uniform_int_distribution::operator()( + UniformRandomNumberGenerator &urng, + const param_type &parm +) +``` + + +**Parameters** + + +The `UniformRandomNumberGenerator` to use as a source of randomness. + + + +A [`param_type`](/library/api/thrust::random::uniform_int_distribution::param_type) object encapsulating the parameters of the `uniform_int_distribution` to draw from. + + + + + +### a const + +This method returns the value of the parameter with which this `uniform_int_distribution` was constructed. + + +```cpp showLineNumbers={false} +result_type thrust::random::uniform_int_distribution::a() const +``` + + +**Returns:** The lower bound of this `uniform_int_distribution``'s` range. + +### b const + +This method returns the value of the parameter with which this `uniform_int_distribution` was constructed. + + +```cpp showLineNumbers={false} +result_type thrust::random::uniform_int_distribution::b() const +``` + + +**Returns:** The upper bound of this `uniform_int_distribution``'s` range. + +### param + + + + +This method changes the parameters of this `uniform_int_distribution` using the values encapsulated in a given [`param_type`](/library/api/thrust::random::uniform_int_distribution::param_type) object. + + +```cpp showLineNumbers={false} +void thrust::random::uniform_int_distribution::param( + const param_type &parm +) +``` + + +**Parameters** + + +A [`param_type`](/library/api/thrust::random::uniform_int_distribution::param_type) object encapsulating the new range of this `uniform_int_distribution`. + + + + + +const + +This method returns a [`param_type`](/library/api/thrust::random::uniform_int_distribution::param_type) object encapsulating the parameters with which this `uniform_int_distribution` was constructed. + + +```cpp showLineNumbers={false} +param_type thrust::random::uniform_int_distribution::param() const +``` + + +**Returns:** A [`param_type`](/library/api/thrust::random::uniform_int_distribution::param_type) object enapsulating the range of this `uniform_int_distribution`. + + + + +### min const + +This method returns the smallest integer this `uniform_int_distribution` can potentially produce. + + +```cpp showLineNumbers={false} +result_type thrust::random::uniform_int_distribution::min() const +``` + + +**Returns:** The lower bound of this `uniform_int_distribution``'s` range. + +### max const + +This method returns the largest integer this `uniform_int_distribution` can potentially produce. + + +```cpp showLineNumbers={false} +result_type thrust::random::uniform_int_distribution::max() const +``` + + +**Returns:** The upper bound of this `uniform_int_distribution``'s` range. + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `result_type` | `IntType` | The type of the integer produced by this `uniform_int_distribution`. | +| `param_type` | `::cuda::std::pair< IntType, IntType >` | The type of the object encapsulating this `uniform_int_distribution``'s` parameters. | diff --git a/fern/cudapages/thrust/thrust/thrust/random/uniform_real_distribution.mdx b/fern/cudapages/thrust/thrust/thrust/random/uniform_real_distribution.mdx new file mode 100644 index 0000000..1dc78c4 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/random/uniform_real_distribution.mdx @@ -0,0 +1,277 @@ +--- +title: thrust::random::uniform_real_distribution +description: "A [`uniform_real_distribution`](/library/api/thrust::random::uniform_real_distribution) random number distribution produces floating point uniform random numbers from a half-open interval." +--- + +A `uniform_real_distribution` random number distribution produces floating point uniform random numbers from a half-open interval. + +The following code snippet demonstrates examples of using a `uniform_real_distribution` with a random number engine to produce random integers drawn from a given range: + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include +#include + +int main() +{ + // create a minstd_rand object to act as our source of randomness + thrust::minstd_rand rng; + + // create a uniform_real_distribution to produce floats from [-7,13) + thrust::uniform_real_distribution dist(-7,13); + + // write a random number from the range [-7,13) to standard output + std::cout << dist(rng) << std::endl; + + // write the range of the distribution, just in case we forgot + std::cout << dist.min() << std::endl; + + // -7.0 is printed + + std::cout << dist.max() << std::endl; + + // 13.0 is printed + + // write the parameters of the distribution (which happen to be the bounds) to standard output + std::cout << dist.a() << std::endl; + + // -7.0 is printed + + std::cout << dist.b() << std::endl; + + // 13.0 is printed + + return 0; +} +``` + + + + + +The type of floating point number to produce. + + + + + +--- + +## Constructors + +### uniform_real_distribution explicit + + + + +This constructor creates a new `uniform_real_distribution` from two values defining the half-open interval of the distribution. + + +```cpp showLineNumbers={false} +thrust::random::uniform_real_distribution::uniform_real_distribution( + RealType a = 0.0, + RealType b = 1.0 +) +``` + + +**Parameters** + + +The smallest floating point number to potentially produce. Defaults to `0.0`. + + + +The smallest number larger than the largest floating point number to potentially produce. Defaults to `1.0`. + + + + + +This constructor creates a new `uniform_real_distribution` from a [`param_type`](/library/api/thrust::random::uniform_real_distribution::param_type) object encapsulating the range of the distribution. + + +```cpp showLineNumbers={false} +thrust::random::uniform_real_distribution::uniform_real_distribution( + const param_type &parm +) +``` + + +**Parameters** + + +A [`param_type`](/library/api/thrust::random::uniform_real_distribution::param_type) object encapsulating the parameters (i.e., the range) of the distribution. + + + + + +--- + +## Methods + +### reset + +This does nothing. + +It is included to conform to the requirements of the RandomDistribution concept. + + +```cpp showLineNumbers={false} +void thrust::random::uniform_real_distribution::reset() +``` + + +### operator() + + + + +This method produces a new uniform random integer drawn from this `uniform_real_distribution``'s` range using a `UniformRandomNumberGenerator` as a source of randomness. + + +```cpp showLineNumbers={false} +template +result_type thrust::random::uniform_real_distribution::operator()( + UniformRandomNumberGenerator &urng +) +``` + + +**Parameters** + + +The `UniformRandomNumberGenerator` to use as a source of randomness. + + + + + +This method produces a new uniform random integer as if by creating a new `uniform_real_distribution` from the given [`param_type`](/library/api/thrust::random::uniform_real_distribution::param_type) object, and calling its `operator()` method with the given `UniformRandomNumberGenerator` as a source of randomness. + + +```cpp showLineNumbers={false} +template +result_type thrust::random::uniform_real_distribution::operator()( + UniformRandomNumberGenerator &urng, + const param_type &parm +) +``` + + +**Parameters** + + +The `UniformRandomNumberGenerator` to use as a source of randomness. + + + +A [`param_type`](/library/api/thrust::random::uniform_real_distribution::param_type) object encapsulating the parameters of the `uniform_real_distribution` to draw from. + + + + + +### a const + +This method returns the value of the parameter with which this `uniform_real_distribution` was constructed. + + +```cpp showLineNumbers={false} +result_type thrust::random::uniform_real_distribution::a() const +``` + + +**Returns:** The lower bound of this `uniform_real_distribution``'s` half-open interval. + +### b const + +This method returns the value of the parameter with which this `uniform_real_distribution` was constructed. + + +```cpp showLineNumbers={false} +result_type thrust::random::uniform_real_distribution::b() const +``` + + +**Returns:** The upper bound of this `uniform_real_distribution``'s` half-open interval. + +### param + + + + +This method changes the parameters of this `uniform_real_distribution` using the values encapsulated in a given [`param_type`](/library/api/thrust::random::uniform_real_distribution::param_type) object. + + +```cpp showLineNumbers={false} +void thrust::random::uniform_real_distribution::param( + const param_type &parm +) +``` + + +**Parameters** + + +A [`param_type`](/library/api/thrust::random::uniform_real_distribution::param_type) object encapsulating the new half-open interval of this `uniform_real_distribution`. + + + + + +const + +This method returns a [`param_type`](/library/api/thrust::random::uniform_real_distribution::param_type) object encapsulating the parameters with which this `uniform_real_distribution` was constructed. + + +```cpp showLineNumbers={false} +param_type thrust::random::uniform_real_distribution::param() const +``` + + +**Returns:** A [`param_type`](/library/api/thrust::random::uniform_real_distribution::param_type) object enapsulating the half-open interval of this `uniform_real_distribution`. + + + + +### min const + +This method returns the smallest floating point number this `uniform_real_distribution` can potentially produce. + + +```cpp showLineNumbers={false} +result_type thrust::random::uniform_real_distribution::min() const +``` + + +**Returns:** The lower bound of this `uniform_real_distribution``'s` half-open interval. + +### max const + +This method returns the smallest number larger than largest floating point number this `uniform_real_distribution` can potentially produce. + + +```cpp showLineNumbers={false} +result_type thrust::random::uniform_real_distribution::max() const +``` + + +**Returns:** The upper bound of this `uniform_real_distribution``'s` half-open interval. + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `result_type` | `RealType` | The type of the floating point number produced by this `uniform_real_distribution`. | +| `param_type` | `::cuda::std::pair< RealType, RealType >` | The type of the object encapsulating this `uniform_real_distribution``'s` parameters. | diff --git a/fern/cudapages/thrust/thrust/thrust/random/xor_combine_engine.mdx b/fern/cudapages/thrust/thrust/thrust/random/xor_combine_engine.mdx new file mode 100644 index 0000000..48d9114 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/random/xor_combine_engine.mdx @@ -0,0 +1,239 @@ +--- +title: thrust::random::xor_combine_engine +description: "An [`xor_combine_engine`](/library/api/thrust::random::xor_combine_engine) adapts two existing base random number engines and produces random values by combining the values produced by each." +--- + +An `xor_combine_engine` adapts two existing base random number engines and produces random values by combining the values produced by each. + +The following code snippet shows an example of using an `xor_combine_engine` instance: + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include +#include +#include + +int main() +{ + // create an xor_combine_engine from minstd_rand and minstd_rand0 + // use a shift of 0 for each + thrust::xor_combine_engine rng; + + // print a random number to standard output + std::cout << rng() << std::endl; + + return 0; +} +``` + + + + + +The type of the first base random number engine to adapt. + + + +The size of the first shift to use in the generation algorithm. + + + +The type of the second base random number engine to adapt. + + + +The second of the second shift to use in the generation algorithm. Defaults to `0`. + + + + + +--- + +## Constructors + +### xor_combine_engine + + + + +This constructor constructs a new `xor_combine_engine` and constructs its adapted engines using their null constructors. + + +```cpp showLineNumbers={false} +thrust::random::xor_combine_engine::xor_combine_engine() +``` + + + + + +This constructor constructs a new `xor_combine_engine` using given [`base1_type`](/library/api/thrust::random::xor_combine_engine::base1_type) and [`base2_type`](/library/api/thrust::random::xor_combine_engine::base2_type) engines to initialize its adapted base engines. + + +```cpp showLineNumbers={false} +thrust::random::xor_combine_engine::xor_combine_engine( + const base1_type &urng1, + const base2_type &urng2 +) +``` + + +**Parameters** + + +A [`base1_type`](/library/api/thrust::random::xor_combine_engine::base1_type) to use to initialize this `xor_combine_engine``'s` first adapted base engine. + + + +A [`base2_type`](/library/api/thrust::random::xor_combine_engine::base2_type) to use to initialize this `xor_combine_engine``'s` first adapted base engine. + + + + + +This constructor initializes a new `xor_combine_engine` with a given seed. + + +```cpp showLineNumbers={false} +thrust::random::xor_combine_engine::xor_combine_engine( + result_type s +) +``` + + +**Parameters** + + +The seed used to initialize this `xor_combine_engine``'s` adapted base engines. + + + + + +--- + +## Methods + +### seed + + + + +This method initializes the state of this `xor_combine_engine``'s` adapted base engines by using their `default_seed` values. + + +```cpp showLineNumbers={false} +void thrust::random::xor_combine_engine::seed() +``` + + + + + +This method initializes the state of this `xor_combine_engine``'s` adapted base engines by using the given seed. + + +```cpp showLineNumbers={false} +void thrust::random::xor_combine_engine::seed( + result_type s +) +``` + + +**Parameters** + + +The seed with which to initialize this `xor_combine_engine``'s` adapted base engines. + + + + + +### operator() + +This member function produces a new random value and updates this `xor_combine_engine``'s` state. + + +```cpp showLineNumbers={false} +result_type thrust::random::xor_combine_engine::operator()( + void +) +``` + + +**Returns:** A new random number. + +### discard + +This member function advances this `xor_combine_engine``'s` state a given number of times and discards the results. + + +```cpp showLineNumbers={false} +void thrust::random::xor_combine_engine::discard( + unsigned long long z +) +``` + + + +This function is provided because an implementation may be able to accelerate it. + + +**Parameters** + + +The number of random values to discard. + + +### base1 const + +This member function returns a const reference to this `xor_combine_engine``'s` first adapted base engine. + + +```cpp showLineNumbers={false} +const base1_type & thrust::random::xor_combine_engine::base1() const +``` + + +**Returns:** A const reference to the first base engine this `xor_combine_engine` adapts. + +### base2 const + +This member function returns a const reference to this `xor_combine_engine``'s` second adapted base engine. + + +```cpp showLineNumbers={false} +const base2_type & thrust::random::xor_combine_engine::base2() const +``` + + +**Returns:** A const reference to the second base engine this `xor_combine_engine` adapts. + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `base1_type` | `Engine1` | The type of the first adapted base random number engine. | +| `base2_type` | `Engine2` | The type of the second adapted base random number engine. | +| `result_type` | `typename thrust::detail::eval_if<(sizeof(typename base2_type::result_type) > sizeof(typename base1_type::result_type)), ::cuda::std::type_identity< typename base2_type::result_type >, ::cuda::std::type_identity< typename base1_type::result_type > >::type` | The type of the unsigned integer produced by this `xor_combine_engine`. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `shift1` static | `const size_t` | The size of the first shift used in the generation algorithm. | +| `shift2` static | `const size_t` | The size of the second shift used in the generation algorithm. | +| `min` static | `const result_type` | The smallest value this `xor_combine_engine` may potentially produce. | +| `max` static | `const result_type` | The largest value this `xor_combine_engine` may potentially produce. | diff --git a/fern/cudapages/thrust/thrust/thrust/random_access_device_iterator_tag.mdx b/fern/cudapages/thrust/thrust/thrust/random_access_device_iterator_tag.mdx new file mode 100644 index 0000000..eba276a --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/random_access_device_iterator_tag.mdx @@ -0,0 +1,17 @@ +--- +title: thrust::random_access_device_iterator_tag +description: "[`random_access_device_iterator_tag`](/library/api/thrust::random_access_device_iterator_tag) is an empty class: it has no member functions, member variables, or nested types." +--- + +`random_access_device_iterator_tag` is an empty class: it has no member functions, member variables, or nested types. + +It is used solely as a "tag": a representation of the Random Access Device Iterator concept within the C++ type system. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +[https://en.cppreference.com/w/cpp/iterator/iterator_tags](https://en.cppreference.com/w/cpp/iterator/iterator_tags) iterator_traits, [input_device_iterator_tag](/library/api/thrust::input_device_iterator_tag), [output_device_iterator_tag](/library/api/thrust::output_device_iterator_tag), [forward_device_iterator_tag](/library/api/thrust::forward_device_iterator_tag), [bidirectional_device_iterator_tag](/library/api/thrust::bidirectional_device_iterator_tag), input_host_iterator_tag, output_host_iterator_tag, forward_host_iterator_tag, bidirectional_host_iterator_tag, random_access_host_iterator_tag + +**Inherits from:** `detail::iterator_category_with_system_and_traversal<::cuda::std::random_access_iterator_tag, device_system_tag, random_access_traversal_tag >` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/random_access_traversal_tag.mdx b/fern/cudapages/thrust/thrust/thrust/random_access_traversal_tag.mdx new file mode 100644 index 0000000..05f4144 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/random_access_traversal_tag.mdx @@ -0,0 +1,12 @@ +--- +title: thrust::random_access_traversal_tag +description: "Tag type for iterators allowing random access traversal." +--- + +Tag type for iterators allowing random access traversal. + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `thrust::bidirectional_traversal_tag` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/runtime_value.mdx b/fern/cudapages/thrust/thrust/thrust/runtime_value.mdx new file mode 100644 index 0000000..47fbbc2 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/runtime_value.mdx @@ -0,0 +1,27 @@ +--- +title: thrust::runtime_value +description: "Holds a runtime value." +--- + +Holds a runtime value. + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `value` | `T` | | diff --git a/fern/cudapages/thrust/thrust/thrust/shuffle_iterator.mdx b/fern/cudapages/thrust/thrust/thrust/shuffle_iterator.mdx new file mode 100644 index 0000000..7534174 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/shuffle_iterator.mdx @@ -0,0 +1,127 @@ +--- +title: thrust::shuffle_iterator +description: "[`shuffle_iterator`](/library/api/thrust::shuffle_iterator) is an iterator which generates a sequence of values representing a random permutation." +--- + +`shuffle_iterator` is an iterator which generates a sequence of values representing a random permutation. + +This iterator is useful for working with random permutations of a range without explicitly storing them in memory. The shuffle iterator is also useful for sampling from a range by selecting only a subset of the elements in the permutation. + +The following code snippet demonstrates how to create a `shuffle_iterator` which generates a random permutation of a vector. + +This next example demonstrates how to use a `shuffle_iterator` to randomly sample from a vector. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +make_shuffle_iterator + +## Example + +```cpp showLineNumbers={false} +#include +... +// create a shuffle iterator +thrust::shuffle_iterator iterator(4, thrust::default_random_engine(0xDEADBEEF)); +// iterator[0] returns 1 +// iterator[1] returns 3 +// iterator[2] returns 2 +// iterator[3] returns 0 + +thrust::device_vector vec = {0, 10, 20, 30}; +thrust::device_vector shuffled(4); +thrust::gather(iterator, iterator + 4, vec.begin(), shuffled.begin()); +// shuffled returns {10, 30, 20, 0} +``` + +```cpp showLineNumbers={false} +#include +... +// create a shuffle iterator +thrust::shuffle_iterator iterator(100, thrust::default_random_engine(0xDEADBEEF)); + +// iterator[0] returns 38 +// iterator[1] returns 50 +// iterator[2] returns 18 +// iterator[3] returns 12 + +// create a vector of size 100 +thrust::device_vector vec(100); +thrust::device_vector sample(4); + +// fill vec with random values +thrust::sequence(vec.begin(), vec.end(), 100); + +// sample 4 random values from vec +thrust::gather(iterator, iterator + 4, vec.begin(), sample.begin()); +// sample returns {138, 150, 118, 112} +``` + + + + + + + + + + + + + +--- + +## Constructors + +### shuffle_iterator inline + + + + +Constructs a `shuffle_iterator` with a given number of elements and a `URBG`. + +The parameters will be forwarded to the bijection constructor. + + +```cpp showLineNumbers={false} +template +thrust::shuffle_iterator::shuffle_iterator( + IndexType n, + URBG &&g +) +``` + + +**Parameters** + + +The number of elements in the permutation. + + + +The `URBG` used to generate the random permutation. This is only invoked during construction of the `shuffle_iterator`. + + + + + +Constructs a `shuffle_iterator` with a given bijection. + + +```cpp showLineNumbers={false} +thrust::shuffle_iterator::shuffle_iterator( + BijectionFunc bijection +) +``` + + +**Parameters** + + +The bijection to use. + + + + diff --git a/fern/cudapages/thrust/thrust/thrust/single_pass_traversal_tag.mdx b/fern/cudapages/thrust/thrust/thrust/single_pass_traversal_tag.mdx new file mode 100644 index 0000000..70b9368 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/single_pass_traversal_tag.mdx @@ -0,0 +1,12 @@ +--- +title: thrust::single_pass_traversal_tag +description: "Tag type for iterators allowing single pass traversal." +--- + +Tag type for iterators allowing single pass traversal. + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `thrust::incrementable_traversal_tag` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/square.mdx b/fern/cudapages/thrust/thrust/thrust/square.mdx new file mode 100644 index 0000000..06cfd6a --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/square.mdx @@ -0,0 +1,57 @@ +--- +title: thrust::square +description: "`square` is a function object." +--- + +`square` is a function object. + +Specifically, it is an Adaptable Unary Function. If `f` is an object of class `square`, and `x` is an object of class `T`, then `f(x)` returns `x*x`. + +The following code snippet demonstrates how to use `square` to square the elements of a [device_vector](/library/api/thrust::device_vector) of `floats`. + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include +#include +#include +#include +... +const int N = 1000; +thrust::device_vector V1(N); +thrust::device_vector V2(N); + +thrust::sequence(V1.begin(), V1.end(), 1); + +thrust::transform(V1.begin(), V1.end(), V2.begin(), + thrust::square()); +// V2 is now {1, 4, 9, ..., 1000000} +``` + + + + + +Is a model of [Assignable](https://en.cppreference.com/w/cpp/named_req/CopyAssignable), and if `x` is an object of type `T`, then `x*x` must be defined and must have a return type that is convertible to `T`. + + + + + +--- + +## Methods + +### operator() inline constexpr const + + +```cpp showLineNumbers={false} +T thrust::square::operator()( + const T &x +) const +``` + diff --git a/fern/cudapages/thrust/thrust/thrust/square_void.mdx b/fern/cudapages/thrust/thrust/thrust/square_void.mdx new file mode 100644 index 0000000..ca001e1 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/square_void.mdx @@ -0,0 +1,33 @@ +--- +title: "thrust::square< void >" +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + +--- + +## Methods + +### operator() inline constexpr const noexcept + +": "/library/api/thrust::square%3C void %3E"}}> +```cpp showLineNumbers={false} +template +T thrust::square::operator()( + const T &x +) const noexcept(x *x) +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `is_transparent` | `void` | diff --git a/fern/cudapages/thrust/thrust/thrust/strided_iterator.mdx b/fern/cudapages/thrust/thrust/thrust/strided_iterator.mdx new file mode 100644 index 0000000..fdf9857 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/strided_iterator.mdx @@ -0,0 +1,126 @@ +--- +title: thrust::strided_iterator +description: "A [`strided_iterator`](/library/api/thrust::strided_iterator) wraps another iterator and moves it by a specified stride each time it is incremented or decremented." +--- + +A `strided_iterator` wraps another iterator and moves it by a specified stride each time it is incremented or decremented. + +```cpp showLineNumbers={false} +#include +``` + + +Use `cuda::strided_iterator` instead + + + + + + +A random access iterator + + + +Either a [runtime_value](/library/api/thrust::runtime_value) or a [compile_time_value](/library/api/thrust::compile_time_value) specifying the stride + + + + + +**Inherits from:** `thrust::iterator_adaptor< strided_iterator< RandomAccessIterator, StrideHolder >, RandomAccessIterator >` (public), `StrideHolder` (private) + +--- + +## Constructors + +### strided_iterator + + + + + +```cpp showLineNumbers={false} +thrust::strided_iterator::strided_iterator() = default +``` + + + + + +inline + +Creates a `strided_iterator` from an existing iterator and a stride. + + +```cpp showLineNumbers={false} +thrust::strided_iterator::strided_iterator( + RandomAccessIterator it, + StrideHolder stride = {} +) +``` + + + + + +--- + +## Methods + +### stride_holder inline const + +Returns either the [runtime_value](/library/api/thrust::runtime_value) or the [compile_time_value](/library/api/thrust::compile_time_value) holding the stride's value. + + +```cpp showLineNumbers={false} +const auto & thrust::strided_iterator::stride_holder() const +``` + + +### stride inline const + +Returns the stride's value. + + +```cpp showLineNumbers={false} +difference_type thrust::strided_iterator::stride() const +``` + + +### base inline const + + +```cpp showLineNumbers={false} +RandomAccessIterator const & thrust::iterator_adaptor, RandomAccessIterator, use_default, use_default, use_default, use_default, use_default>::base() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + +### base_reference inline const + + +```cpp showLineNumbers={false} +RandomAccessIterator const & thrust::iterator_adaptor, RandomAccessIterator, use_default, use_default, use_default, use_default, use_default>::base_reference() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `base_type` | `RandomAccessIterator` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)`'s` `adapts`. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `has_static_stride` static constexpr | `bool` | | diff --git a/fern/cudapages/thrust/thrust/thrust/subtract_with_carry_engine.mdx b/fern/cudapages/thrust/thrust/thrust/subtract_with_carry_engine.mdx new file mode 100644 index 0000000..4222806 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/subtract_with_carry_engine.mdx @@ -0,0 +1,152 @@ +--- +title: thrust::subtract_with_carry_engine +description: "A [`subtract_with_carry_engine`](/library/api/thrust::subtract_with_carry_engine) random number engine produces unsigned integer random numbers using the subtract with carry algorithm of Marsaglia & Zaman." +--- + +A `subtract_with_carry_engine` random number engine produces unsigned integer random numbers using the subtract with carry algorithm of Marsaglia & Zaman. + +The generation algorithm is performed as follows: + +1. Let `Y = X_{i-s}- X_{i-r} - c`. +2. Set `X_i` to `y = T mod m`. Set `c` to `1` if `Y < 0`, otherwise set `c` to `0`. + +This algorithm corresponds to a modular linear function of the form + +`TA(x_i) = (a * x_i) mod b`, where `b` is of the form `m^r - m^s + 1` and `a = b - (b-1)/m`. + +```cpp showLineNumbers={false} +#include +``` + + +Inexperienced users should not use this class template directly. Instead, use `ranlux24_base` or `ranlux48_base`, which are instances of `subtract_with_carry_engine`. + + +**See also:** +thrust::random::ranlux24_base, +thrust::random::ranlux48_base + + + + + +The type of unsigned integer to produce. + + + +The word size of the produced values (` w <= sizeof(UIntType)`). + + + +The short lag of the generation algorithm. + + + +The long lag of the generation algorithm. + + + + + +--- + +## Constructors + +### subtract_with_carry_engine explicit + +This constructor, which optionally accepts a seed, initializes a new `subtract_with_carry_engine`. + + +```cpp showLineNumbers={false} +thrust::random::subtract_with_carry_engine::subtract_with_carry_engine( + result_type value = default_seed +) +``` + + +**Parameters** + + +The seed used to initialize this `subtract_with_carry_engine``'s` state. + + +--- + +## Methods + +### seed + +This method initializes this `subtract_with_carry_engine``'s` state, and optionally accepts a seed value. + + +```cpp showLineNumbers={false} +void thrust::random::subtract_with_carry_engine::seed( + result_type value = default_seed +) +``` + + +**Parameters** + + +The seed used to initializes this `subtract_with_carry_engine``'s` state. + + +### operator() + +This member function produces a new random value and updates this `subtract_with_carry_engine``'s` state. + + +```cpp showLineNumbers={false} +result_type thrust::random::subtract_with_carry_engine::operator()( + void +) +``` + + +**Returns:** A new random number. + +### discard + +This member function advances this `subtract_with_carry_engine``'s` state a given number of times and discards the results. + + +```cpp showLineNumbers={false} +void thrust::random::subtract_with_carry_engine::discard( + unsigned long long z +) +``` + + + +This function is provided because an implementation may be able to accelerate it. + + +**Parameters** + + +The number of random values to discard. + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `result_type` | `UIntType` | The type of the unsigned integer produced by this `subtract_with_carry_engine`. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `word_size` static | `const size_t` | The word size of the produced values. | +| `short_lag` static | `const size_t` | The size of the short lag used in the generation algorithm. | +| `long_lag` static | `const size_t` | The size of the long lag used in the generation algorithm. | +| `min` static | `const result_type` | The smallest value this `subtract_with_carry_engine` may potentially produce. | +| `max` static | `const result_type` | The largest value this `subtract_with_carry_engine` may potentially produce. | +| `default_seed` static | `const result_type` | The default seed of this `subtract_with_carry_engine`. | diff --git a/fern/cudapages/thrust/thrust/thrust/system/error_category.mdx b/fern/cudapages/thrust/thrust/thrust/system/error_category.mdx new file mode 100644 index 0000000..99c2e3d --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/system/error_category.mdx @@ -0,0 +1,139 @@ +--- +title: thrust::system::error_category +description: "The class [`error_category`](/library/api/thrust::system::error_category) serves as a base class for types used to identify the source and encoding of a particular category of error code." +--- + +The class `error_category` serves as a base class for types used to identify the source and encoding of a particular category of error code. + +Classes may be derived from `error_category` to support categories of errors in addition to those defined in the C++ International Standard. + +```cpp showLineNumbers={false} +#include +``` + +--- + +## Constructors + +### Destructor + +### ~error_category inline virtual + +Destructor does nothing. + + +```cpp showLineNumbers={false} +virtual thrust::system::error_category::~error_category() +``` + + +--- + +## Methods + +### name inline const virtual + + +```cpp showLineNumbers={false} +virtual const char * thrust::system::error_category::name() const +``` + + +**Returns:** A string naming the error category. + +### default_error_condition inline const virtual + + +```cpp showLineNumbers={false} +virtual error_condition thrust::system::error_category::default_error_condition( + int ev +) const +``` + + +**Returns:** `error_condition(ev, *this)`. + +### equivalent inline const virtual + + + + + +```cpp showLineNumbers={false} +virtual bool thrust::system::error_category::equivalent( + int code, + const error_condition &condition +) const +``` + + +**Returns:** `default_error_condition(code) == condition` + + + + + +```cpp showLineNumbers={false} +virtual bool thrust::system::error_category::equivalent( + const error_code &code, + int condition +) const +``` + + +**Returns:** `*this == code.category() && code.value() == condition` + + + + +### message const virtual + + +```cpp showLineNumbers={false} +virtual std::string thrust::system::error_category::message( + int ev +) const +``` + + +**Returns:** A string that describes the error condition denoted by `ev`. + +### operator== inline const + + +```cpp showLineNumbers={false} +bool thrust::system::error_category::operator==( + const error_category &rhs +) const +``` + + +**Returns:** `*this == &rhs` + +### operator!= inline const + + +```cpp showLineNumbers={false} +bool thrust::system::error_category::operator!=( + const error_category &rhs +) const +``` + + +**Returns:** `!(*this == rhs)` + +### operator< inline const + + +```cpp showLineNumbers={false} +bool thrust::system::error_category::operator<( + const error_category &rhs +) const +``` + + + +`less` provides a total ordering for pointers. + + +**Returns:** `less()``(this, &rhs)` diff --git a/fern/cudapages/thrust/thrust/thrust/system/error_code.mdx b/fern/cudapages/thrust/thrust/thrust/system/error_code.mdx new file mode 100644 index 0000000..a44372f --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/system/error_code.mdx @@ -0,0 +1,175 @@ +--- +title: thrust::system::error_code +description: "The class [`error_code`](/library/api/thrust::system::error_code) describes an object used to hold error code values, such as those originating from the operating system or other low-level application program interfaces." +--- + +The class `error_code` describes an object used to hold error code values, such as those originating from the operating system or other low-level application program interfaces. + +```cpp showLineNumbers={false} +#include +``` + +--- + +## Constructors + +### error_code + + + + +inline + +Effects: Constructs an object of type `error_code`. + + +```cpp showLineNumbers={false} +thrust::system::error_code::error_code() +``` + + + +[`value()`](/library/api/thrust::system::error_code::value())` == 0` and [`category()`](/library/api/thrust::system::error_code::category())` == &``system_category()`. + + + + + +inline + +Effects: Constructs an object of type `error_code`. + + +```cpp showLineNumbers={false} +thrust::system::error_code::error_code( + int val, + const error_category &cat +) +``` + + + +[`value()`](/library/api/thrust::system::error_code::value())` == val` and [`category()`](/library/api/thrust::system::error_code::category())` == &cat`. + + + + + +Effects: Constructs an object of type `error_code`. + + +```cpp showLineNumbers={false} +template +thrust::system::error_code::error_code( + ErrorCodeEnum e, + ::cuda::std::enable_if_t::value> * = 0 +) +``` + + + +`*this == make_error_code(e)`. + + + + + +--- + +## Assignment operators + +### operator= + + +```cpp showLineNumbers={false} +template +::cuda::std::enable_if_t::value, error_code> & thrust::system::error_code::operator=( + ErrorCodeEnum e +) +``` + + + +`*this == make_error_code(e)`. + + +--- + +## Methods + +### assign inline + + +```cpp showLineNumbers={false} +void thrust::system::error_code::assign( + int val, + const error_category &cat +) +``` + + + +[`value()`](/library/api/thrust::system::error_code::value())` == val` and [`category()`](/library/api/thrust::system::error_code::category())` == &cat`. + + +### clear inline + + +```cpp showLineNumbers={false} +void thrust::system::error_code::clear() +``` + + + +[`value()`](/library/api/thrust::system::error_code::value())` == 0` and [`category()`](/library/api/thrust::system::error_code::category())` == ``system_category()`. + + +### value inline const + + +```cpp showLineNumbers={false} +int thrust::system::error_code::value() const +``` + + +**Returns:** An integral value of this `error_code` object. + +### category inline const + + +```cpp showLineNumbers={false} +const error_category & thrust::system::error_code::category() const +``` + + +**Returns:** An [`error_category`](/library/api/thrust::system::error_category) describing the category of this `error_code` object. + +### default_error_condition inline const + + +```cpp showLineNumbers={false} +error_condition thrust::system::error_code::default_error_condition() const +``` + + +**Returns:** [`category()`](/library/api/thrust::system::error_code::category())`.`[`default_error_condition()`](/library/api/thrust::system::error_code::default_error_condition()). + +### message inline const + + +```cpp showLineNumbers={false} +std::string thrust::system::error_code::message() const +``` + + +**Returns:** [`category()`](/library/api/thrust::system::error_code::category())`.message(value())`. + +### operator bool inline const + + +```cpp showLineNumbers={false} +thrust::system::error_code::operator bool() const +``` + + +**Returns:** [`value()`](/library/api/thrust::system::error_code::value())` != 0`. diff --git a/fern/cudapages/thrust/thrust/thrust/system/error_condition.mdx b/fern/cudapages/thrust/thrust/thrust/system/error_condition.mdx new file mode 100644 index 0000000..941dc2e --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/system/error_condition.mdx @@ -0,0 +1,211 @@ +--- +title: thrust::system::error_condition +description: "The class [`error_condition`](/library/api/thrust::system::error_condition) describes an object used to hold values identifying error conditions." +--- + +The class `error_condition` describes an object used to hold values identifying error conditions. + +```cpp showLineNumbers={false} +#include +``` + + +`error_condition` values are portable abstractions, while [`error_code`](/library/api/thrust::system::error_code) values are implementation specific. + + +--- + +## Constructors + +### error_condition + + + + +inline + +Constructs an object of type `error_condition`. + + +```cpp showLineNumbers={false} +thrust::system::error_condition::error_condition() +``` + + + +[`value()`](/library/api/thrust::system::error_condition::value())` == 0`. + + + +[`category()`](/library/api/thrust::system::error_condition::category())` == ``generic_category()`. + + + + + +inline + +Constructs an object of type `error_condition`. + + +```cpp showLineNumbers={false} +thrust::system::error_condition::error_condition( + int val, + const error_category &cat +) +``` + + + +[`value()`](/library/api/thrust::system::error_condition::value())` == val`. + + + +[`category()`](/library/api/thrust::system::error_condition::category())` == cat`. + + + + + +Constructs an object of type `error_condition`. + + +```cpp showLineNumbers={false} +template +thrust::system::error_condition::error_condition( + ErrorConditionEnum e, + ::cuda::std::enable_if_t::value> * = 0 +) +``` + + + +This constructor shall not participate in overload resolution unless `is_error_condition_enum::value` is `true`. + + + +`*this == make_error_condition(e)`. + + + + + +--- + +## Assignment operators + +### operator= + +Assigns to this [`error_code`](/library/api/thrust::system::error_code) object from an error condition enumeration. + + +```cpp showLineNumbers={false} +template +::cuda::std::enable_if_t::value, error_condition> & thrust::system::error_condition::operator=( + ErrorConditionEnum e +) +``` + + + +This operator shall not participate in overload resolution unless `is_error_condition_enum::value` is `true`. + + + +`*this == make_error_condition(e)`. + + +**Returns:** *this + +--- + +## Methods + +### assign inline + +Assigns to this [`error_code`](/library/api/thrust::system::error_code) object from an error value and an [`error_category`](/library/api/thrust::system::error_category). + + +```cpp showLineNumbers={false} +void thrust::system::error_condition::assign( + int val, + const error_category &cat +) +``` + + + +[`value()`](/library/api/thrust::system::error_condition::value())` == val`. + + + +[`category()`](/library/api/thrust::system::error_condition::category())` == cat`. + + +**Parameters** + + +The new value to return from [`value()`](/library/api/thrust::system::error_condition::value()). + + + +The new [`error_category`](/library/api/thrust::system::error_category) to return from [`category()`](/library/api/thrust::system::error_condition::category()). + + +### clear inline + +Clears this [`error_code`](/library/api/thrust::system::error_code) object. + + +```cpp showLineNumbers={false} +void thrust::system::error_condition::clear() +``` + + + +[`value`](/library/api/thrust::system::error_condition::value)` == 0` + + + +[`category()`](/library/api/thrust::system::error_condition::category())` == ``generic_category()`. + + +### value inline const + + +```cpp showLineNumbers={false} +int thrust::system::error_condition::value() const +``` + + +**Returns:** The value encoded by this `error_condition`. + +### category inline const + + +```cpp showLineNumbers={false} +const error_category & thrust::system::error_condition::category() const +``` + + +**Returns:** A `const` reference to the [`error_category`](/library/api/thrust::system::error_category) encoded by this `error_condition`. + +### message inline const + + +```cpp showLineNumbers={false} +std::string thrust::system::error_condition::message() const +``` + + +**Returns:** [`category()`](/library/api/thrust::system::error_condition::category())`.message(value())`. + +### operator bool inline const + + +```cpp showLineNumbers={false} +thrust::system::error_condition::operator bool() const +``` + + +**Returns:** [`value()`](/library/api/thrust::system::error_condition::value())` != 0`. diff --git a/fern/cudapages/thrust/thrust/thrust/system/is_error_code_enum.mdx b/fern/cudapages/thrust/thrust/thrust/system/is_error_code_enum.mdx new file mode 100644 index 0000000..2fc9861 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/system/is_error_code_enum.mdx @@ -0,0 +1,21 @@ +--- +title: thrust::system::is_error_code_enum +description: "A metafunction returning whether or not the parameter is an [`error_code`](/library/api/thrust::system::error_code) enum." +--- + +A metafunction returning whether or not the parameter is an [`error_code`](/library/api/thrust::system::error_code) enum. + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +**Inherits from:** `thrust::detail::false_type` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/system/is_error_code_enum_cudaerrcerrc_t.mdx b/fern/cudapages/thrust/thrust/thrust/system/is_error_code_enum_cudaerrcerrc_t.mdx new file mode 100644 index 0000000..7992d01 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/system/is_error_code_enum_cudaerrcerrc_t.mdx @@ -0,0 +1,12 @@ +--- +title: "thrust::system::is_error_code_enum< cuda::errc::errc_t >" +description: "Specialization of [`is_error_code_enum`](/library/api/thrust::system::is_error_code_enum) for [`cuda::errc::errc_t`](/library/api/thrust::system::cuda::errc::errc_t)." +--- + +Specialization of [`is_error_code_enum`](/library/api/thrust::system::is_error_code_enum) for [`cuda::errc::errc_t`](/library/api/thrust::system::cuda::errc::errc_t). + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `thrust::detail::false_type` (public), `thrust::detail::true_type` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/system/is_error_condition_enum.mdx b/fern/cudapages/thrust/thrust/thrust/system/is_error_condition_enum.mdx new file mode 100644 index 0000000..9922b95 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/system/is_error_condition_enum.mdx @@ -0,0 +1,21 @@ +--- +title: thrust::system::is_error_condition_enum +description: "A metafunction returning whether or not the parameter is an [`error_condition`](/library/api/thrust::system::error_condition) enum." +--- + +A metafunction returning whether or not the parameter is an [`error_condition`](/library/api/thrust::system::error_condition) enum. + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + +**Inherits from:** `thrust::detail::false_type` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/system/is_error_condition_enum_errcerrc_t.mdx b/fern/cudapages/thrust/thrust/thrust/system/is_error_condition_enum_errcerrc_t.mdx new file mode 100644 index 0000000..38b6872 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/system/is_error_condition_enum_errcerrc_t.mdx @@ -0,0 +1,12 @@ +--- +title: "thrust::system::is_error_condition_enum< errc::errc_t >" +description: "Specialization of [`is_error_condition_enum`](/library/api/thrust::system::is_error_condition_enum) for [`errc::errc_t`](/library/api/thrust::system::errc::errc_t)." +--- + +Specialization of [`is_error_condition_enum`](/library/api/thrust::system::is_error_condition_enum) for [`errc::errc_t`](/library/api/thrust::system::errc::errc_t). + +```cpp showLineNumbers={false} +#include +``` + +**Inherits from:** `thrust::detail::false_type` (public), `thrust::detail::true_type` (public) diff --git a/fern/cudapages/thrust/thrust/thrust/system/system_error.mdx b/fern/cudapages/thrust/thrust/thrust/system/system_error.mdx new file mode 100644 index 0000000..2a2895c --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/system/system_error.mdx @@ -0,0 +1,290 @@ +--- +title: thrust::system::system_error +description: "The class [`system_error`](/library/api/thrust::system::system_error) describes an exception object used to report error conditions that have an associated [`error_code`](/library/api/thrust::system::error_code)." +--- + +The class `system_error` describes an exception object used to report error conditions that have an associated [`error_code`](/library/api/thrust::system::error_code). + +Such error conditions typically originate from the operating system or other low-level application program interfaces. + +Thrust uses `system_error` to report the error codes returned from device backends such as the CUDA runtime. + +The following code listing demonstrates how to catch a `system_error` to recover from an error. + +```cpp showLineNumbers={false} +#include +``` + + +If an error represents an out-of-memory condition, implementations are encouraged to throw an exception object of type `std::bad_alloc` rather than `system_error`. + + +## Example + +```cpp showLineNumbers={false} +#include +#include +#include + +void terminate_gracefully() +{ + // application-specific termination code here + ... +} + +int main() +{ + try + { + thrust::device_vector vec; + thrust::sort(vec.begin(), vec.end()); + } + catch(thrust::system_error e) + { + std::cerr << "Error inside sort: " << e.what() << std::endl; + terminate_gracefully(); + } + + return 0; +} +``` + +**Inherits from:** `std::runtime_error` (public) + +--- + +## Constructors + +### system_error inline + + + + +Constructs an object of class `system_error`. + + +```cpp showLineNumbers={false} +thrust::system::system_error::system_error( + error_code ec, + const std::string &what_arg +) +``` + + + +[`code()`](/library/api/thrust::system::system_error::code())` == ec`. + + + +`std::string(what()).find(what_arg) != string::npos`. + + +**Parameters** + + +The value returned by [`code()`](/library/api/thrust::system::system_error::code()). + + + +A string to include in the result returned by [`what()`](/library/api/thrust::system::system_error::what()). + + + + + +Constructs an object of class `system_error`. + + +```cpp showLineNumbers={false} +thrust::system::system_error::system_error( + error_code ec, + const char *what_arg +) +``` + + + +[`code()`](/library/api/thrust::system::system_error::code())` == ec`. + + + +`std::string(what()).find(what_arg) != string::npos`. + + +**Parameters** + + +The value returned by [`code()`](/library/api/thrust::system::system_error::code()). + + + +A string to include in the result returned by [`what()`](/library/api/thrust::system::system_error::what()). + + + + + +Constructs an object of class `system_error`. + + +```cpp showLineNumbers={false} +thrust::system::system_error::system_error( + error_code ec +) +``` + + + +[`code()`](/library/api/thrust::system::system_error::code())` == ec`. + + +**Parameters** + + +The value returned by [`code()`](/library/api/thrust::system::system_error::code()). + + + + + +Constructs an object of class `system_error`. + + +```cpp showLineNumbers={false} +thrust::system::system_error::system_error( + int ev, + const error_category &ecat, + const std::string &what_arg +) +``` + + + +[`code()`](/library/api/thrust::system::system_error::code())` == error_code(ev, ecat)`. + + + +`std::string(what()).find(what_arg) != string::npos`. + + +**Parameters** + + +The error value used to create an [`error_code`](/library/api/thrust::system::error_code). + + + +The [`error_category`](/library/api/thrust::system::error_category) used to create an [`error_code`](/library/api/thrust::system::error_code). + + + +A string to include in the result returned by [`what()`](/library/api/thrust::system::system_error::what()). + + + + + +Constructs an object of class `system_error`. + + +```cpp showLineNumbers={false} +thrust::system::system_error::system_error( + int ev, + const error_category &ecat, + const char *what_arg +) +``` + + + +[`code()`](/library/api/thrust::system::system_error::code())` == error_code(ev, ecat)`. + + + +`std::string(what()).find(what_arg) != string::npos`. + + +**Parameters** + + +The error value used to create an [`error_code`](/library/api/thrust::system::error_code). + + + +The [`error_category`](/library/api/thrust::system::error_category) used to create an [`error_code`](/library/api/thrust::system::error_code). + + + +A string to include in the result returned by [`what()`](/library/api/thrust::system::system_error::what()). + + + + + +Constructs an object of class `system_error`. + + +```cpp showLineNumbers={false} +thrust::system::system_error::system_error( + int ev, + const error_category &ecat +) +``` + + + +[`code()`](/library/api/thrust::system::system_error::code())` == error_code(ev, ecat)`. + + +**Parameters** + + +The error value used to create an [`error_code`](/library/api/thrust::system::error_code). + + + +The [`error_category`](/library/api/thrust::system::error_category) used to create an [`error_code`](/library/api/thrust::system::error_code). + + + + + +### Destructor + +### ~system_error inline noexcept virtual + +Destructor does not throw. + + +```cpp showLineNumbers={false} +virtual thrust::system::system_error::~system_error() noexcept +``` + + +--- + +## Methods + +### code inline const noexcept + +Returns an object encoding the error. + + +```cpp showLineNumbers={false} +const error_code & thrust::system::system_error::code() const noexcept +``` + + +**Returns:** `ec` or `error_code(ev, ecat)`, from the constructor, as appropriate. + +### what inline const noexcept + +Returns a human-readable string indicating the nature of the error. + + +```cpp showLineNumbers={false} +const char * thrust::system::system_error::what() const noexcept +``` + + +**Returns:** a string incorporating [`code()`](/library/api/thrust::system::system_error::code())`.message()` and the arguments supplied in the constructor. diff --git a/fern/cudapages/thrust/thrust/thrust/system_error.mdx b/fern/cudapages/thrust/thrust/thrust/system_error.mdx new file mode 100644 index 0000000..cb42423 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/system_error.mdx @@ -0,0 +1,290 @@ +--- +title: thrust::system_error +description: "The class [`system_error`](/library/api/thrust::system_error) describes an exception object used to report error conditions that have an associated [`error_code`](/library/api/thrust::error_code)." +--- + +The class `system_error` describes an exception object used to report error conditions that have an associated [`error_code`](/library/api/thrust::error_code). + +Such error conditions typically originate from the operating system or other low-level application program interfaces. + +Thrust uses `system_error` to report the error codes returned from device backends such as the CUDA runtime. + +The following code listing demonstrates how to catch a `system_error` to recover from an error. + +```cpp showLineNumbers={false} +#include +``` + + +If an error represents an out-of-memory condition, implementations are encouraged to throw an exception object of type `std::bad_alloc` rather than `system_error`. + + +## Example + +```cpp showLineNumbers={false} +#include +#include +#include + +void terminate_gracefully() +{ + // application-specific termination code here + ... +} + +int main() +{ + try + { + thrust::device_vector vec; + thrust::sort(vec.begin(), vec.end()); + } + catch(thrust::system_error e) + { + std::cerr << "Error inside sort: " << e.what() << std::endl; + terminate_gracefully(); + } + + return 0; +} +``` + +**Inherits from:** `std::runtime_error` (public) + +--- + +## Constructors + +### system_error inline + + + + +Constructs an object of class `system_error`. + + +```cpp showLineNumbers={false} +thrust::system::system_error::system_error( + error_code ec, + const std::string &what_arg +) +``` + + + +[`code()`](/library/api/thrust::system::system_error::code())` == ec`. + + + +`std::string(what()).find(what_arg) != string::npos`. + + +**Parameters** + + +The value returned by [`code()`](/library/api/thrust::system::system_error::code()). + + + +A string to include in the result returned by [`what()`](/library/api/thrust::system::system_error::what()). + + + + + +Constructs an object of class `system_error`. + + +```cpp showLineNumbers={false} +thrust::system::system_error::system_error( + error_code ec, + const char *what_arg +) +``` + + + +[`code()`](/library/api/thrust::system::system_error::code())` == ec`. + + + +`std::string(what()).find(what_arg) != string::npos`. + + +**Parameters** + + +The value returned by [`code()`](/library/api/thrust::system::system_error::code()). + + + +A string to include in the result returned by [`what()`](/library/api/thrust::system::system_error::what()). + + + + + +Constructs an object of class `system_error`. + + +```cpp showLineNumbers={false} +thrust::system::system_error::system_error( + error_code ec +) +``` + + + +[`code()`](/library/api/thrust::system::system_error::code())` == ec`. + + +**Parameters** + + +The value returned by [`code()`](/library/api/thrust::system::system_error::code()). + + + + + +Constructs an object of class `system_error`. + + +```cpp showLineNumbers={false} +thrust::system::system_error::system_error( + int ev, + const error_category &ecat, + const std::string &what_arg +) +``` + + + +[`code()`](/library/api/thrust::system::system_error::code())` == error_code(ev, ecat)`. + + + +`std::string(what()).find(what_arg) != string::npos`. + + +**Parameters** + + +The error value used to create an [`error_code`](/library/api/thrust::error_code). + + + +The [`error_category`](/library/api/thrust::error_category) used to create an [`error_code`](/library/api/thrust::error_code). + + + +A string to include in the result returned by [`what()`](/library/api/thrust::system::system_error::what()). + + + + + +Constructs an object of class `system_error`. + + +```cpp showLineNumbers={false} +thrust::system::system_error::system_error( + int ev, + const error_category &ecat, + const char *what_arg +) +``` + + + +[`code()`](/library/api/thrust::system::system_error::code())` == error_code(ev, ecat)`. + + + +`std::string(what()).find(what_arg) != string::npos`. + + +**Parameters** + + +The error value used to create an [`error_code`](/library/api/thrust::error_code). + + + +The [`error_category`](/library/api/thrust::error_category) used to create an [`error_code`](/library/api/thrust::error_code). + + + +A string to include in the result returned by [`what()`](/library/api/thrust::system::system_error::what()). + + + + + +Constructs an object of class `system_error`. + + +```cpp showLineNumbers={false} +thrust::system::system_error::system_error( + int ev, + const error_category &ecat +) +``` + + + +[`code()`](/library/api/thrust::system::system_error::code())` == error_code(ev, ecat)`. + + +**Parameters** + + +The error value used to create an [`error_code`](/library/api/thrust::error_code). + + + +The [`error_category`](/library/api/thrust::error_category) used to create an [`error_code`](/library/api/thrust::error_code). + + + + + +### Destructor + +### ~system_error inline noexcept virtual + +Destructor does not throw. + + +```cpp showLineNumbers={false} +virtual thrust::system::system_error::~system_error() noexcept +``` + + +--- + +## Methods + +### code inline const noexcept + +Returns an object encoding the error. + + +```cpp showLineNumbers={false} +const error_code & thrust::system::system_error::code() const noexcept +``` + + +**Returns:** `ec` or `error_code(ev, ecat)`, from the constructor, as appropriate. + +### what inline const noexcept + +Returns a human-readable string indicating the nature of the error. + + +```cpp showLineNumbers={false} +const char * thrust::system::system_error::what() const noexcept +``` + + +**Returns:** a string incorporating [`code()`](/library/api/thrust::system::system_error::code())`.message()` and the arguments supplied in the constructor. diff --git a/fern/cudapages/thrust/thrust/thrust/tabulate_output_iterator.mdx b/fern/cudapages/thrust/thrust/thrust/tabulate_output_iterator.mdx new file mode 100644 index 0000000..84a7ffb --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/tabulate_output_iterator.mdx @@ -0,0 +1,149 @@ +--- +title: thrust::tabulate_output_iterator +description: "[`tabulate_output_iterator`](/library/api/thrust::tabulate_output_iterator) is a special kind of output iterator which, whenever a value is assigned to a dereferenced iterator, calls the given callable with the index that corresponds to the offset of the dereferenced iterator and the assigned value." +--- + +`tabulate_output_iterator` is a special kind of output iterator which, whenever a value is assigned to a dereferenced iterator, calls the given callable with the index that corresponds to the offset of the dereferenced iterator and the assigned value. + +The following code snippet demonstrated how to create a `tabulate_output_iterator` which prints the index and the assigned value. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +make_tabulate_output_iterator + +## Example + +```cpp showLineNumbers={false} +#include + +// note: functor inherits form binary function +struct print_op +{ + __host__ __device__ + void operator()(int index, float value) const + { + printf("%d: %f\n", index, value); + } +}; + +int main() +{ + auto tabulate_it = thrust::make_tabulate_output_iterator(print_op{}); + + tabulate_it[0] = 1.0f; // prints: 0: 1.0 + tabulate_it[1] = 3.0f; // prints: 1: 3.0 + tabulate_it[9] = 5.0f; // prints: 9: 5.0 +} +``` + + + + + + + + + + + + + + + + +**Inherits from:** `thrust::iterator_adaptor< tabulate_output_iterator< BinaryFunction, System, DifferenceT >, counting_iterator< DifferenceT >, void, System, use_default, tabulate_output_iterator_proxy< BinaryFunction, DifferenceT > >` (public) + +--- + +## Constructors + +### tabulate_output_iterator + + + + + +```cpp showLineNumbers={false} +thrust::tabulate_output_iterator::tabulate_output_iterator() = default +``` + + + + + +inline + +This constructor takes as argument a `BinaryFunction` and copies it to a new `tabulate_output_iterator`. + + +```cpp showLineNumbers={false} +thrust::tabulate_output_iterator::tabulate_output_iterator( + BinaryFunction fun +) +``` + + +**Parameters** + + +A `BinaryFunction` called whenever a value is assigned to this `tabulate_output_iterator`. + + + + + +--- + +## Methods + +### base inline const + + +```cpp showLineNumbers={false} +counting_iterator const & thrust::iterator_adaptor, counting_iterator, void, System, use_default, tabulate_output_iterator_proxy, use_default>::base() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + +### base_reference inline + + + + + +```cpp showLineNumbers={false} +counting_iterator & thrust::iterator_adaptor, counting_iterator, void, System, use_default, tabulate_output_iterator_proxy, use_default>::base_reference() +``` + + +**Returns:** A mutable reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +const + + +```cpp showLineNumbers={false} +counting_iterator const & thrust::iterator_adaptor, counting_iterator, void, System, use_default, tabulate_output_iterator_proxy, use_default>::base_reference() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `base_type` | `counting_iterator< DifferenceT >` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)`'s` `adapts`. | diff --git a/fern/cudapages/thrust/thrust/thrust/tagged_deleter.mdx b/fern/cudapages/thrust/thrust/thrust/tagged_deleter.mdx new file mode 100644 index 0000000..71f7e6c --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/tagged_deleter.mdx @@ -0,0 +1,46 @@ +--- +title: thrust::tagged_deleter +description: "" +--- + +```cpp showLineNumbers={false} +#include +``` + + + + + + + + + + + + + +**Inherits from:** `Lambda` (public) + +--- + +## Constructors + +### tagged_deleter inline + + +```cpp showLineNumbers={false} +thrust::tagged_deleter::tagged_deleter( + Lambda &&l +) +``` + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `pointer` | `Pointer` | diff --git a/fern/cudapages/thrust/thrust/thrust/transform_input_output_iterator.mdx b/fern/cudapages/thrust/thrust/thrust/transform_input_output_iterator.mdx new file mode 100644 index 0000000..5ca22e8 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/transform_input_output_iterator.mdx @@ -0,0 +1,172 @@ +--- +title: thrust::transform_input_output_iterator +description: "[`transform_input_output_iterator`](/library/api/thrust::transform_input_output_iterator) is a special kind of iterator which applies transform functions when reading from or writing to dereferenced values." +--- + +`transform_input_output_iterator` is a special kind of iterator which applies transform functions when reading from or writing to dereferenced values. + +This iterator is useful for algorithms that operate on a type that needs to be serialized/deserialized from values in another iterator, avoiding the need to materialize intermediate results in memory. This also enables the transform functions to be fused with the operations that read and write to the `transform_input_output_iterator`. + +The following code snippet demonstrates how to create a `transform_input_output_iterator` which performs different transformations when reading from and writing to the iterator. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +make_transform_input_output_iterator + +## Example + +```cpp showLineNumbers={false} +#include +#include + + int main() + { + const size_t size = 4; + thrust::device_vector v(size); + + // Write 1.0f, 2.0f, 3.0f, 4.0f to vector + thrust::sequence(v.begin(), v.end(), 1); + + // Iterator that returns negated values and writes squared values + auto iter = thrust::make_transform_input_output_iterator(v.begin(), + ::cuda::std::negate{}, thrust::square{}); + + // Iterator negates values when reading + std::cout << iter[0] << " "; // -1.0f; + std::cout << iter[1] << " "; // -2.0f; + std::cout << iter[2] << " "; // -3.0f; + std::cout << iter[3] << "\n"; // -4.0f; + + // Write 1.0f, 2.0f, 3.0f, 4.0f to iterator + thrust::sequence(iter, iter + size, 1); + + // Values were squared before writing to vector + std::cout << v[0] << " "; // 1.0f; + std::cout << v[1] << " "; // 4.0f; + std::cout << v[2] << " "; // 9.0f; + std::cout << v[3] << "\n"; // 16.0f; + + } +``` + + + + + + + + + + + + + + + + +**Inherits from:** `thrust::iterator_adaptor< transform_input_output_iterator< InputFunction, OutputFunction, Iterator >, Iterator, invoke_result_t< InputFunction, iterator_value_type >, use_default, use_default, transform_input_output_iterator_proxy< InputFunction, OutputFunction, Iterator > >` (public) + +--- + +## Constructors + +### transform_input_output_iterator + + + + + +```cpp showLineNumbers={false} +thrust::transform_input_output_iterator::transform_input_output_iterator() = default +``` + + + + + +inline + +This constructor takes as argument a `Iterator` an `InputFunction` and an `OutputFunction` and copies them to a new `transform_input_output_iterator`. + + +```cpp showLineNumbers={false} +thrust::transform_input_output_iterator::transform_input_output_iterator( + Iterator const &io, + InputFunction input_function, + OutputFunction output_function +) +``` + + +**Parameters** + + +An `Iterator` pointing to where the input to `InputFunction` will be read from and the result of `OutputFunction` will be written to + + + +An `InputFunction` to be executed on values read from the iterator + + + +An `OutputFunction` to be executed on values written to the iterator + + + + + +--- + +## Methods + +### base inline const + + +```cpp showLineNumbers={false} +Iterator const & thrust::iterator_adaptor, Iterator, invoke_result_t, use_default, use_default, transform_input_output_iterator_proxy, use_default>::base() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + +### base_reference inline + + + + + +```cpp showLineNumbers={false} +Iterator & thrust::iterator_adaptor, Iterator, invoke_result_t, use_default, use_default, transform_input_output_iterator_proxy, use_default>::base_reference() +``` + + +**Returns:** A mutable reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +const + + +```cpp showLineNumbers={false} +Iterator const & thrust::iterator_adaptor, Iterator, invoke_result_t, use_default, use_default, transform_input_output_iterator_proxy, use_default>::base_reference() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `base_type` | `Iterator` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)`'s` `adapts`. | diff --git a/fern/cudapages/thrust/thrust/thrust/transform_iterator.mdx b/fern/cudapages/thrust/thrust/thrust/transform_iterator.mdx new file mode 100644 index 0000000..23b3817 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/transform_iterator.mdx @@ -0,0 +1,317 @@ +--- +title: thrust::transform_iterator +description: "[`transform_iterator`](/library/api/thrust::transform_iterator) is an iterator which represents a pointer into a range of values after transformation by a function." +--- + +`transform_iterator` is an iterator which represents a pointer into a range of values after transformation by a function. + +This iterator is useful for creating a range filled with the result of applying an operation to another range without either explicitly storing it in memory, or explicitly executing the transformation. Using `transform_iterator` facilitates kernel fusion by deferring the execution of a transformation until the value is needed while saving both memory capacity and bandwidth. + +The following code snippet demonstrates how to create a `transform_iterator` which represents the result of `sqrtf` applied to the contents of a [`device_vector`](/library/api/thrust::device_vector). + +This next example demonstrates how to use a `transform_iterator` with the `thrust::reduce` function to compute the sum of squares of a sequence. We will create temporary `transform_iterators` with the `make_transform_iterator` function in order to avoid explicitly specifying their type: + +The following example illustrates how to use the third template argument to explicitly specify the return type of the function. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +make_transform_iterator + +## Example + +```cpp showLineNumbers={false} +#include +#include + +struct square_root +{ + __host__ __device__ + float operator()(float x) const + { + return sqrtf(x); + } +}; + +int main() +{ + thrust::device_vector v{1.0f, 4.0f, 9.0f, 16.0f}; + + using FloatIterator = thrust::device_vector::iterator; + + thrust::transform_iterator iter(v.begin(), square_root()); + + *iter; // returns 1.0f + iter[0]; // returns 1.0f; + iter[1]; // returns 2.0f; + iter[2]; // returns 3.0f; + iter[3]; // returns 4.0f; + + // iter[4] is an out-of-bounds error +} +``` + +```cpp showLineNumbers={false} +#include +#include +#include +#include + +struct square +{ + __host__ __device__ + float operator()(float x) const + { + return x * x; + } +}; + +int main() +{ + // initialize a device array + thrust::device_vector v{1.0f, 2.0f, 3.0f, 4.0f}; + + float sum_of_squares = + thrust::reduce(thrust::make_transform_iterator(v.begin(), square()), + thrust::make_transform_iterator(v.end(), square())); + + std::cout << "sum of squares: " << sum_of_squares << std::endl; + return 0; +} +``` + +```cpp showLineNumbers={false} +#include +#include + +struct square_root +{ + __host__ __device__ + float operator()(float x) const + { + return sqrtf(x); + } +}; + +int main() +{ + thrust::device_vector v{1.0f, 4.0f, 9.0f, 16.0f}; + + using FloatIterator = thrust::device_vector::iterator; + + // note: float result_type is specified explicitly + thrust::transform_iterator iter(v.begin(), square_root()); + + *iter; // returns 1.0f + iter[0]; // returns 1.0f; + iter[1]; // returns 2.0f; + iter[2]; // returns 3.0f; + iter[3]; // returns 4.0f; + + // iter[4] is an out-of-bounds error +} +``` + + + + + + + + + + + + + + + + + + + +**Inherits from:** `thrust::iterator_adaptor< transform_iterator< AdaptableUnaryFunction, Iterator, use_default, use_default >, Iterator, value_type, use_default, use_default, reference >` (public) + +--- + +## Constructors + +### transform_iterator + + + + + +```cpp showLineNumbers={false} +thrust::transform_iterator::transform_iterator() = default +``` + + + + + + +```cpp showLineNumbers={false} +thrust::transform_iterator::transform_iterator( + transform_iterator const & +) = default +``` + + + + + +inline + +This constructor takes as arguments an `Iterator` and an `AdaptableUnaryFunction` and copies them to a new `transform_iterator`. + + +```cpp showLineNumbers={false} +thrust::transform_iterator::transform_iterator( + Iterator const &x, + AdaptableUnaryFunction f +) +``` + + +**Parameters** + + +An `Iterator` pointing to the input to this `transform_iterator``'s` `AdaptableUnaryFunction`. + + + +An `AdaptableUnaryFunction` used to transform the objects pointed to by `x`. + + + + + +inline explicit + +This explicit constructor copies the value of a given `Iterator` and creates this `transform_iterator``'s` `AdaptableUnaryFunction` using its null constructor. + + +```cpp showLineNumbers={false} +thrust::transform_iterator::transform_iterator( + Iterator const &x +) +``` + + +**Parameters** + + +An `Iterator` to copy. + + + + + +inline + +This copy constructor creates a new `transform_iterator` from another `transform_iterator`. + + +```cpp showLineNumbers={false} +template +thrust::transform_iterator::transform_iterator( + const transform_iterator &other, + detail::enable_if_convertible_t * = 0, + detail::enable_if_convertible_t * = 0 +) +``` + + +**Parameters** + + +The `transform_iterator` to copy. + + + + + +--- + +## Assignment operators + +### operator= inline + + +```cpp showLineNumbers={false} +transform_iterator & thrust::transform_iterator::operator=( + transform_iterator const &other +) +``` + + +--- + +## Methods + +### functor inline const + +This method returns a copy of this `transform_iterator``'s` `AdaptableUnaryFunction`. + + +```cpp showLineNumbers={false} +AdaptableUnaryFunction thrust::transform_iterator::functor() const +``` + + +**Returns:** A copy of this `transform_iterator``'s` `AdaptableUnaryFunction`. + +### base inline const + + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor::base() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + +### base_reference inline + + + + + +```cpp showLineNumbers={false} +Base & thrust::iterator_adaptor::base_reference() +``` + + +**Returns:** A mutable reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +const + + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor::base_reference() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `base_type` | `Base` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)`'s` `adapts`. | diff --git a/fern/cudapages/thrust/thrust/thrust/transform_output_iterator.mdx b/fern/cudapages/thrust/thrust/thrust/transform_output_iterator.mdx new file mode 100644 index 0000000..e9c57bd --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/transform_output_iterator.mdx @@ -0,0 +1,164 @@ +--- +title: thrust::transform_output_iterator +description: "[`transform_output_iterator`](/library/api/thrust::transform_output_iterator) is a special kind of output iterator which transforms a value written upon dereference." +--- + +`transform_output_iterator` is a special kind of output iterator which transforms a value written upon dereference. + +This iterator is useful for transforming an output from algorithms without explicitly storing the intermediate result in the memory and applying subsequent transformation, thereby avoiding wasting memory capacity and bandwidth. Using [`transform_iterator`](/library/api/thrust::transform_iterator) facilitates kernel fusion by deferring execution of transformation until the value is written while saving both memory capacity and bandwidth. + +The following code snippet demonstrated how to create a `transform_output_iterator` which applies `sqrtf` to the assigning value. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +make_transform_output_iterator + +## Example + +```cpp showLineNumbers={false} +#include +#include + +struct square_root +{ + __host__ __device__ + float operator()(float x) const + { + return sqrtf(x); + } +}; + +int main() +{ + thrust::device_vector v(4); + + using FloatIterator = thrust::device_vector::iterator; + thrust::transform_output_iterator iter(v.begin(), square_root()); + + iter[0] = 1.0f; // stores sqrtf( 1.0f) + iter[1] = 4.0f; // stores sqrtf( 4.0f) + iter[2] = 9.0f; // stores sqrtf( 9.0f) + iter[3] = 16.0f; // stores sqrtf(16.0f) + // iter[4] is an out-of-bounds error + + v[0]; // returns 1.0f; + v[1]; // returns 2.0f; + v[2]; // returns 3.0f; + v[3]; // returns 4.0f; + +} +``` + + + + + + + + + + + + + +**Inherits from:** `thrust::iterator_adaptor< transform_output_iterator< UnaryFunction, OutputIterator >, OutputIterator, use_default, use_default, use_default, transform_output_iterator_proxy< UnaryFunction, OutputIterator > >` (public) + +--- + +## Constructors + +### transform_output_iterator + + + + + +```cpp showLineNumbers={false} +thrust::transform_output_iterator::transform_output_iterator() = default +``` + + + + + +inline + +This constructor takes as argument an `OutputIterator` and an `UnaryFunction` and copies them to a new `transform_output_iterator`. + + +```cpp showLineNumbers={false} +thrust::transform_output_iterator::transform_output_iterator( + OutputIterator const &out, + UnaryFunction fun +) +``` + + +**Parameters** + + +An `OutputIterator` pointing to the output range whereto the result of `transform_output_iterator``'s` `UnaryFunction` will be written. + + + +An `UnaryFunction` used to transform the objects assigned to this `transform_output_iterator`. + + + + + +--- + +## Methods + +### base inline const + + +```cpp showLineNumbers={false} +OutputIterator const & thrust::iterator_adaptor, OutputIterator, use_default, use_default, use_default, transform_output_iterator_proxy, use_default>::base() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + +### base_reference inline + + + + + +```cpp showLineNumbers={false} +OutputIterator & thrust::iterator_adaptor, OutputIterator, use_default, use_default, use_default, transform_output_iterator_proxy, use_default>::base_reference() +``` + + +**Returns:** A mutable reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +const + + +```cpp showLineNumbers={false} +OutputIterator const & thrust::iterator_adaptor, OutputIterator, use_default, use_default, use_default, transform_output_iterator_proxy, use_default>::base_reference() const +``` + + +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `base_type` | `OutputIterator` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)`'s` `adapts`. | diff --git a/fern/cudapages/thrust/thrust/thrust/uniform_int_distribution.mdx b/fern/cudapages/thrust/thrust/thrust/uniform_int_distribution.mdx new file mode 100644 index 0000000..8009597 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/uniform_int_distribution.mdx @@ -0,0 +1,277 @@ +--- +title: thrust::uniform_int_distribution +description: "A [`uniform_int_distribution`](/library/api/thrust::uniform_int_distribution) random number distribution produces signed or unsigned integer uniform random numbers from a given range." +--- + +A `uniform_int_distribution` random number distribution produces signed or unsigned integer uniform random numbers from a given range. + +The following code snippet demonstrates examples of using a `uniform_int_distribution` with a random number engine to produce random integers drawn from a given range: + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include +#include + +int main() +{ + // create a minstd_rand object to act as our source of randomness + thrust::minstd_rand rng; + + // create a uniform_int_distribution to produce ints from [-7,13] + thrust::uniform_int_distribution dist(-7,13); + + // write a random number from the range [-7,13] to standard output + std::cout << dist(rng) << std::endl; + + // write the range of the distribution, just in case we forgot + std::cout << dist.min() << std::endl; + + // -7 is printed + + std::cout << dist.max() << std::endl; + + // 13 is printed + + // write the parameters of the distribution (which happen to be the bounds) to standard output + std::cout << dist.a() << std::endl; + + // -7 is printed + + std::cout << dist.b() << std::endl; + + // 13 is printed + + return 0; +} +``` + + + + + +The type of integer to produce. + + + + + +--- + +## Constructors + +### uniform_int_distribution explicit + + + + +This constructor creates a new `uniform_int_distribution` from two values defining the range of the distribution. + + +```cpp showLineNumbers={false} +thrust::random::uniform_int_distribution::uniform_int_distribution( + IntType a = 0, + IntType b = ::cuda::std::numeric_limits::max() +) +``` + + +**Parameters** + + +The smallest integer to potentially produce. Defaults to `0`. + + + +The largest integer to potentially produce. Defaults to the largest representable integer in the platform. + + + + + +This constructor creates a new `uniform_int_distribution` from a [`param_type`](/library/api/thrust::random::uniform_int_distribution::param_type) object encapsulating the range of the distribution. + + +```cpp showLineNumbers={false} +thrust::random::uniform_int_distribution::uniform_int_distribution( + const param_type &parm +) +``` + + +**Parameters** + + +A [`param_type`](/library/api/thrust::random::uniform_int_distribution::param_type) object encapsulating the parameters (i.e., the range) of the distribution. + + + + + +--- + +## Methods + +### reset + +This does nothing. + +It is included to conform to the requirements of the RandomDistribution concept. + + +```cpp showLineNumbers={false} +void thrust::random::uniform_int_distribution::reset() +``` + + +### operator() + + + + +This method produces a new uniform random integer drawn from this `uniform_int_distribution``'s` range using a `UniformRandomNumberGenerator` as a source of randomness. + + +```cpp showLineNumbers={false} +template +result_type thrust::random::uniform_int_distribution::operator()( + UniformRandomNumberGenerator &urng +) +``` + + +**Parameters** + + +The `UniformRandomNumberGenerator` to use as a source of randomness. + + + + + +This method produces a new uniform random integer as if by creating a new `uniform_int_distribution` from the given [`param_type`](/library/api/thrust::random::uniform_int_distribution::param_type) object, and calling its `operator()` method with the given `UniformRandomNumberGenerator` as a source of randomness. + + +```cpp showLineNumbers={false} +template +result_type thrust::random::uniform_int_distribution::operator()( + UniformRandomNumberGenerator &urng, + const param_type &parm +) +``` + + +**Parameters** + + +The `UniformRandomNumberGenerator` to use as a source of randomness. + + + +A [`param_type`](/library/api/thrust::random::uniform_int_distribution::param_type) object encapsulating the parameters of the `uniform_int_distribution` to draw from. + + + + + +### a const + +This method returns the value of the parameter with which this `uniform_int_distribution` was constructed. + + +```cpp showLineNumbers={false} +result_type thrust::random::uniform_int_distribution::a() const +``` + + +**Returns:** The lower bound of this `uniform_int_distribution``'s` range. + +### b const + +This method returns the value of the parameter with which this `uniform_int_distribution` was constructed. + + +```cpp showLineNumbers={false} +result_type thrust::random::uniform_int_distribution::b() const +``` + + +**Returns:** The upper bound of this `uniform_int_distribution``'s` range. + +### param + + + + +This method changes the parameters of this `uniform_int_distribution` using the values encapsulated in a given [`param_type`](/library/api/thrust::random::uniform_int_distribution::param_type) object. + + +```cpp showLineNumbers={false} +void thrust::random::uniform_int_distribution::param( + const param_type &parm +) +``` + + +**Parameters** + + +A [`param_type`](/library/api/thrust::random::uniform_int_distribution::param_type) object encapsulating the new range of this `uniform_int_distribution`. + + + + + +const + +This method returns a [`param_type`](/library/api/thrust::random::uniform_int_distribution::param_type) object encapsulating the parameters with which this `uniform_int_distribution` was constructed. + + +```cpp showLineNumbers={false} +param_type thrust::random::uniform_int_distribution::param() const +``` + + +**Returns:** A [`param_type`](/library/api/thrust::random::uniform_int_distribution::param_type) object enapsulating the range of this `uniform_int_distribution`. + + + + +### min const + +This method returns the smallest integer this `uniform_int_distribution` can potentially produce. + + +```cpp showLineNumbers={false} +result_type thrust::random::uniform_int_distribution::min() const +``` + + +**Returns:** The lower bound of this `uniform_int_distribution``'s` range. + +### max const + +This method returns the largest integer this `uniform_int_distribution` can potentially produce. + + +```cpp showLineNumbers={false} +result_type thrust::random::uniform_int_distribution::max() const +``` + + +**Returns:** The upper bound of this `uniform_int_distribution``'s` range. + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `result_type` | `IntType` | The type of the integer produced by this `uniform_int_distribution`. | +| `param_type` | `::cuda::std::pair< IntType, IntType >` | The type of the object encapsulating this `uniform_int_distribution``'s` parameters. | diff --git a/fern/cudapages/thrust/thrust/thrust/uniform_real_distribution.mdx b/fern/cudapages/thrust/thrust/thrust/uniform_real_distribution.mdx new file mode 100644 index 0000000..c019d60 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/uniform_real_distribution.mdx @@ -0,0 +1,277 @@ +--- +title: thrust::uniform_real_distribution +description: "A [`uniform_real_distribution`](/library/api/thrust::uniform_real_distribution) random number distribution produces floating point uniform random numbers from a half-open interval." +--- + +A `uniform_real_distribution` random number distribution produces floating point uniform random numbers from a half-open interval. + +The following code snippet demonstrates examples of using a `uniform_real_distribution` with a random number engine to produce random integers drawn from a given range: + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include +#include + +int main() +{ + // create a minstd_rand object to act as our source of randomness + thrust::minstd_rand rng; + + // create a uniform_real_distribution to produce floats from [-7,13) + thrust::uniform_real_distribution dist(-7,13); + + // write a random number from the range [-7,13) to standard output + std::cout << dist(rng) << std::endl; + + // write the range of the distribution, just in case we forgot + std::cout << dist.min() << std::endl; + + // -7.0 is printed + + std::cout << dist.max() << std::endl; + + // 13.0 is printed + + // write the parameters of the distribution (which happen to be the bounds) to standard output + std::cout << dist.a() << std::endl; + + // -7.0 is printed + + std::cout << dist.b() << std::endl; + + // 13.0 is printed + + return 0; +} +``` + + + + + +The type of floating point number to produce. + + + + + +--- + +## Constructors + +### uniform_real_distribution explicit + + + + +This constructor creates a new `uniform_real_distribution` from two values defining the half-open interval of the distribution. + + +```cpp showLineNumbers={false} +thrust::random::uniform_real_distribution::uniform_real_distribution( + RealType a = 0.0, + RealType b = 1.0 +) +``` + + +**Parameters** + + +The smallest floating point number to potentially produce. Defaults to `0.0`. + + + +The smallest number larger than the largest floating point number to potentially produce. Defaults to `1.0`. + + + + + +This constructor creates a new `uniform_real_distribution` from a [`param_type`](/library/api/thrust::random::uniform_real_distribution::param_type) object encapsulating the range of the distribution. + + +```cpp showLineNumbers={false} +thrust::random::uniform_real_distribution::uniform_real_distribution( + const param_type &parm +) +``` + + +**Parameters** + + +A [`param_type`](/library/api/thrust::random::uniform_real_distribution::param_type) object encapsulating the parameters (i.e., the range) of the distribution. + + + + + +--- + +## Methods + +### reset + +This does nothing. + +It is included to conform to the requirements of the RandomDistribution concept. + + +```cpp showLineNumbers={false} +void thrust::random::uniform_real_distribution::reset() +``` + + +### operator() + + + + +This method produces a new uniform random integer drawn from this `uniform_real_distribution``'s` range using a `UniformRandomNumberGenerator` as a source of randomness. + + +```cpp showLineNumbers={false} +template +result_type thrust::random::uniform_real_distribution::operator()( + UniformRandomNumberGenerator &urng +) +``` + + +**Parameters** + + +The `UniformRandomNumberGenerator` to use as a source of randomness. + + + + + +This method produces a new uniform random integer as if by creating a new `uniform_real_distribution` from the given [`param_type`](/library/api/thrust::random::uniform_real_distribution::param_type) object, and calling its `operator()` method with the given `UniformRandomNumberGenerator` as a source of randomness. + + +```cpp showLineNumbers={false} +template +result_type thrust::random::uniform_real_distribution::operator()( + UniformRandomNumberGenerator &urng, + const param_type &parm +) +``` + + +**Parameters** + + +The `UniformRandomNumberGenerator` to use as a source of randomness. + + + +A [`param_type`](/library/api/thrust::random::uniform_real_distribution::param_type) object encapsulating the parameters of the `uniform_real_distribution` to draw from. + + + + + +### a const + +This method returns the value of the parameter with which this `uniform_real_distribution` was constructed. + + +```cpp showLineNumbers={false} +result_type thrust::random::uniform_real_distribution::a() const +``` + + +**Returns:** The lower bound of this `uniform_real_distribution``'s` half-open interval. + +### b const + +This method returns the value of the parameter with which this `uniform_real_distribution` was constructed. + + +```cpp showLineNumbers={false} +result_type thrust::random::uniform_real_distribution::b() const +``` + + +**Returns:** The upper bound of this `uniform_real_distribution``'s` half-open interval. + +### param + + + + +This method changes the parameters of this `uniform_real_distribution` using the values encapsulated in a given [`param_type`](/library/api/thrust::random::uniform_real_distribution::param_type) object. + + +```cpp showLineNumbers={false} +void thrust::random::uniform_real_distribution::param( + const param_type &parm +) +``` + + +**Parameters** + + +A [`param_type`](/library/api/thrust::random::uniform_real_distribution::param_type) object encapsulating the new half-open interval of this `uniform_real_distribution`. + + + + + +const + +This method returns a [`param_type`](/library/api/thrust::random::uniform_real_distribution::param_type) object encapsulating the parameters with which this `uniform_real_distribution` was constructed. + + +```cpp showLineNumbers={false} +param_type thrust::random::uniform_real_distribution::param() const +``` + + +**Returns:** A [`param_type`](/library/api/thrust::random::uniform_real_distribution::param_type) object enapsulating the half-open interval of this `uniform_real_distribution`. + + + + +### min const + +This method returns the smallest floating point number this `uniform_real_distribution` can potentially produce. + + +```cpp showLineNumbers={false} +result_type thrust::random::uniform_real_distribution::min() const +``` + + +**Returns:** The lower bound of this `uniform_real_distribution``'s` half-open interval. + +### max const + +This method returns the smallest number larger than largest floating point number this `uniform_real_distribution` can potentially produce. + + +```cpp showLineNumbers={false} +result_type thrust::random::uniform_real_distribution::max() const +``` + + +**Returns:** The upper bound of this `uniform_real_distribution``'s` half-open interval. + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `result_type` | `RealType` | The type of the floating point number produced by this `uniform_real_distribution`. | +| `param_type` | `::cuda::std::pair< RealType, RealType >` | The type of the object encapsulating this `uniform_real_distribution``'s` parameters. | diff --git a/fern/cudapages/thrust/thrust/thrust/xor_combine_engine.mdx b/fern/cudapages/thrust/thrust/thrust/xor_combine_engine.mdx new file mode 100644 index 0000000..e3888b9 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/xor_combine_engine.mdx @@ -0,0 +1,239 @@ +--- +title: thrust::xor_combine_engine +description: "An [`xor_combine_engine`](/library/api/thrust::xor_combine_engine) adapts two existing base random number engines and produces random values by combining the values produced by each." +--- + +An `xor_combine_engine` adapts two existing base random number engines and produces random values by combining the values produced by each. + +The following code snippet shows an example of using an `xor_combine_engine` instance: + +```cpp showLineNumbers={false} +#include +``` + +## Example + +```cpp showLineNumbers={false} +#include +#include +#include + +int main() +{ + // create an xor_combine_engine from minstd_rand and minstd_rand0 + // use a shift of 0 for each + thrust::xor_combine_engine rng; + + // print a random number to standard output + std::cout << rng() << std::endl; + + return 0; +} +``` + + + + + +The type of the first base random number engine to adapt. + + + +The size of the first shift to use in the generation algorithm. + + + +The type of the second base random number engine to adapt. + + + +The second of the second shift to use in the generation algorithm. Defaults to `0`. + + + + + +--- + +## Constructors + +### xor_combine_engine + + + + +This constructor constructs a new `xor_combine_engine` and constructs its adapted engines using their null constructors. + + +```cpp showLineNumbers={false} +thrust::random::xor_combine_engine::xor_combine_engine() +``` + + + + + +This constructor constructs a new `xor_combine_engine` using given [`base1_type`](/library/api/thrust::random::xor_combine_engine::base1_type) and [`base2_type`](/library/api/thrust::random::xor_combine_engine::base2_type) engines to initialize its adapted base engines. + + +```cpp showLineNumbers={false} +thrust::random::xor_combine_engine::xor_combine_engine( + const base1_type &urng1, + const base2_type &urng2 +) +``` + + +**Parameters** + + +A [`base1_type`](/library/api/thrust::random::xor_combine_engine::base1_type) to use to initialize this `xor_combine_engine``'s` first adapted base engine. + + + +A [`base2_type`](/library/api/thrust::random::xor_combine_engine::base2_type) to use to initialize this `xor_combine_engine``'s` first adapted base engine. + + + + + +This constructor initializes a new `xor_combine_engine` with a given seed. + + +```cpp showLineNumbers={false} +thrust::random::xor_combine_engine::xor_combine_engine( + result_type s +) +``` + + +**Parameters** + + +The seed used to initialize this `xor_combine_engine``'s` adapted base engines. + + + + + +--- + +## Methods + +### seed + + + + +This method initializes the state of this `xor_combine_engine``'s` adapted base engines by using their `default_seed` values. + + +```cpp showLineNumbers={false} +void thrust::random::xor_combine_engine::seed() +``` + + + + + +This method initializes the state of this `xor_combine_engine``'s` adapted base engines by using the given seed. + + +```cpp showLineNumbers={false} +void thrust::random::xor_combine_engine::seed( + result_type s +) +``` + + +**Parameters** + + +The seed with which to initialize this `xor_combine_engine``'s` adapted base engines. + + + + + +### operator() + +This member function produces a new random value and updates this `xor_combine_engine``'s` state. + + +```cpp showLineNumbers={false} +result_type thrust::random::xor_combine_engine::operator()( + void +) +``` + + +**Returns:** A new random number. + +### discard + +This member function advances this `xor_combine_engine``'s` state a given number of times and discards the results. + + +```cpp showLineNumbers={false} +void thrust::random::xor_combine_engine::discard( + unsigned long long z +) +``` + + + +This function is provided because an implementation may be able to accelerate it. + + +**Parameters** + + +The number of random values to discard. + + +### base1 const + +This member function returns a const reference to this `xor_combine_engine``'s` first adapted base engine. + + +```cpp showLineNumbers={false} +const base1_type & thrust::random::xor_combine_engine::base1() const +``` + + +**Returns:** A const reference to the first base engine this `xor_combine_engine` adapts. + +### base2 const + +This member function returns a const reference to this `xor_combine_engine``'s` second adapted base engine. + + +```cpp showLineNumbers={false} +const base2_type & thrust::random::xor_combine_engine::base2() const +``` + + +**Returns:** A const reference to the second base engine this `xor_combine_engine` adapts. + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `base1_type` | `Engine1` | The type of the first adapted base random number engine. | +| `base2_type` | `Engine2` | The type of the second adapted base random number engine. | +| `result_type` | `typename thrust::detail::eval_if<(sizeof(typename base2_type::result_type) > sizeof(typename base1_type::result_type)), ::cuda::std::type_identity< typename base2_type::result_type >, ::cuda::std::type_identity< typename base1_type::result_type > >::type` | The type of the unsigned integer produced by this `xor_combine_engine`. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `shift1` static | `const size_t` | The size of the first shift used in the generation algorithm. | +| `shift2` static | `const size_t` | The size of the second shift used in the generation algorithm. | +| `min` static | `const result_type` | The smallest value this `xor_combine_engine` may potentially produce. | +| `max` static | `const result_type` | The largest value this `xor_combine_engine` may potentially produce. | diff --git a/fern/cudapages/thrust/thrust/thrust/zip_function.mdx b/fern/cudapages/thrust/thrust/thrust/zip_function.mdx new file mode 100644 index 0000000..f235835 --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/zip_function.mdx @@ -0,0 +1,140 @@ +--- +title: thrust::zip_function +description: "[`zip_function`](/library/api/thrust::zip_function) is a function object that allows the easy use of N-ary function objects with `zip_iterators` without redefining them to take a `tuple` instead of N arguments." +--- + +`zip_function` is a function object that allows the easy use of N-ary function objects with `zip_iterators` without redefining them to take a `tuple` instead of N arguments. + +This means that if a functor that takes 2 arguments which could be used with the `transform` function and `device_iterators` can be extended to take 3 arguments and `zip_iterators` without rewriting the functor in terms of `tuple`. + +The `make_zip_function` convenience function is provided to avoid having to explicitly define the type of the functor when creating a `zip_function`, whic is especially helpful when using lambdas as the functor. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +make_zip_function, +[zip_iterator](/library/api/thrust::zip_iterator) + +## Example + +```cpp showLineNumbers={false} +#include +#include +#include +#include + +struct SumTuple { + float operator()(auto tup) const { + return cuda::std::get<0>(tup) + cuda::std::get<1>(tup) + ::cuda::std::get<2>(tup); + } +}; +struct SumArgs { + float operator()(float a, float b, float c) const { + return a + b + c; + } +}; + +int main() { + thrust::device_vector A{0.f, 1.f, 2.f}; + thrust::device_vector B{1.f, 2.f, 3.f}; + thrust::device_vector C{2.f, 3.f, 4.f}; + thrust::device_vector D(3); + + auto begin = thrust::make_zip_iterator(A.begin(), B.begin(), C.begin()); + auto end = thrust::make_zip_iterator(A.end(), B.end(), C.end()); + + // The following four invocations of transform are equivalent: + // Transform with 3-tuple + thrust::transform(begin, end, D.begin(), SumTuple{}); + + // Transform with 3 parameters + thrust::zip_function adapted{}; + thrust::transform(begin, end, D.begin(), adapted); + + // Transform with 3 parameters with convenience function + thrust::transform(begin, end, D.begin(), thrust::make_zip_function(SumArgs{})); + + // Transform with 3 parameters with convenience function and lambda + thrust::transform(begin, end, D.begin(), thrust::make_zip_function([] (float a, float b, float c) { + return a + b + c; + })); + return 0; +} +``` + + + + + + + + + + +--- + +## Constructors + +### zip_function + + + + +Default constructs the contained function object. + + +```cpp showLineNumbers={false} +thrust::zip_function::zip_function() = default +``` + + + + + +inline + + +```cpp showLineNumbers={false} +thrust::zip_function::zip_function( + Function func +) +``` + + + + + +--- + +## Methods + +### operator() inline const + + +```cpp showLineNumbers={false} +template +decltype( + auto +) const +``` + + +### underlying_function inline const + +Returns a reference to the underlying function. + + +```cpp showLineNumbers={false} +Function & thrust::zip_function::underlying_function() const +``` + + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `func` | `Function` | | diff --git a/fern/cudapages/thrust/thrust/thrust/zip_iterator.mdx b/fern/cudapages/thrust/thrust/thrust/zip_iterator.mdx new file mode 100644 index 0000000..7c23e9e --- /dev/null +++ b/fern/cudapages/thrust/thrust/thrust/zip_iterator.mdx @@ -0,0 +1,200 @@ +--- +title: thrust::zip_iterator +description: "[`zip_iterator`](/library/api/thrust::zip_iterator) is an iterator which represents a pointer into a range of `tuples` whose elements are themselves taken from a `tuple` of input iterators." +--- + +`zip_iterator` is an iterator which represents a pointer into a range of `tuples` whose elements are themselves taken from a `tuple` of input iterators. + +This iterator is useful for creating a virtual array of structures while achieving the same performance and bandwidth as the structure of arrays idiom. `zip_iterator` also facilitates kernel fusion by providing a convenient means of amortizing the execution of the same operation over multiple ranges. + +The following code snippet demonstrates how to create a `zip_iterator` which represents the result of "zipping" multiple ranges together. + +Defining the type of a `zip_iterator` can be complex. The next code example demonstrates how to use the `make_zip_iterator` function with the `make_tuple` function to avoid explicitly specifying the type of the `zip_iterator`. This example shows how to use `zip_iterator` to copy multiple ranges with a single call to `thrust::copy`. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +make_zip_iterator, +make_tuple, +tuple, +get + +## Example + +```cpp showLineNumbers={false} +#include +#include +#include +... +thrust::device_vector int_v{0, 1, 2}; +thrust::device_vector float_v{0.0f, 1.0f, 2.0f}; +thrust::device_vector char_v{'a', 'b', 'c'}; + +// aliases for iterators +using IntIterator = thrust::device_vector::iterator; +using FloatIterator = thrust::device_vector::iterator; +using CharIterator = thrust::device_vector::iterator; + +// alias for a tuple of these iterators +using IteratorTuple = cuda::std::tuple; + +// alias the zip_iterator of this tuple +using ZipIterator = thrust::zip_iterator; + +// finally, create the zip_iterator +ZipIterator iter(cuda::std::make_tuple(int_v.begin(), float_v.begin(), char_v.begin())); + +*iter; // returns (0, 0.0f, 'a') +iter[0]; // returns (0, 0.0f, 'a') +iter[1]; // returns (1, 1.0f, 'b') +iter[2]; // returns (2, 2.0f, 'c') + +cuda::std::get<0>(iter[2]); // returns 2 +cuda::std::get<1>(iter[0]); // returns 0.0f +cuda::std::get<2>(iter[1]); // returns 'b' + +// iter[3] is an out-of-bounds error +``` + +```cpp showLineNumbers={false} +#include +#include +#include + +int main() +{ + thrust::device_vector int_in{0, 1, 2}, int_out(3); + thrust::device_vector float_in{0.0f, 10.0f, 20.0f}, float_out(3); + + thrust::copy(thrust::make_zip_iterator(int_in.begin(), float_in.begin()), + thrust::make_zip_iterator(int_in.end(), float_in.end()), + thrust::make_zip_iterator(int_out.begin(),float_out.begin())); + + // int_out is now [0, 1, 2] + // float_out is now [0.0f, 10.0f, 20.0f] + + return 0; +} +``` + + + + + + + + + + +**Inherits from:** `detail::make_zip_iterator_base::type` (public) + +--- + +## Constructors + +### zip_iterator + + + + + +```cpp showLineNumbers={false} +thrust::zip_iterator::zip_iterator() = default +``` + + + + + +inline + +This constructor creates a new `zip_iterator` from a `tuple` of iterators. + + +```cpp showLineNumbers={false} +thrust::zip_iterator::zip_iterator( + IteratorTuple iterator_tuple +) +``` + + +**Parameters** + + +The `tuple` of iterators to copy from. + + + + + +inline + +This constructor creates a new `zip_iterator` from multiple iterators. + + +```cpp showLineNumbers={false} +template +thrust::zip_iterator::zip_iterator( + Iterators &&... iterators +) +``` + + +**Parameters** + + +The iterators to zip. + + + + + +inline + +This copy constructor creates a new `zip_iterator` from another `zip_iterator`. + + +```cpp showLineNumbers={false} +template +thrust::zip_iterator::zip_iterator( + const zip_iterator &other +) +``` + + +**Parameters** + + +The `zip_iterator` to copy. + + + + + +--- + +## Methods + +### get_iterator_tuple inline const + +This method returns a `const` reference to this `zip_iterator``'s` `tuple` of iterators. + + +```cpp showLineNumbers={false} +const IteratorTuple & thrust::zip_iterator::get_iterator_tuple() const +``` + + +**Returns:** A `const` reference to this `zip_iterator``'s` `tuple` of iterators. + +--- + +## Types + +### Typedefs + +| Name | Definition | Description | +|---|---|---| +| `iterator_tuple` | `IteratorTuple` | The underlying iterator tuple type. Alias to `zip_iterator`'s first template argument. | diff --git a/fern/docs.yml b/fern/docs.yml index 8c9bef7..5a2e078 100644 --- a/fern/docs.yml +++ b/fern/docs.yml @@ -37,6 +37,9 @@ tabs: Django Core Reference: display-name: Django Core Reference icon: book + C++ Full Reference: + display-name: C++ Full Reference + icon: book libraries: nemo-rl: @@ -112,36 +115,66 @@ navigation: contents: [] - tab: C++ Golden Pages layout: - - section: CUB + - section: CUB (v5 - Hand Curated) contents: - page: BlockReduce - path: pages/cub/block_reduce_v3.mdx + path: pages/cub/block_reduce_v5.mdx - page: BlockScan - path: pages/cub/block_scan_v4.mdx + path: pages/cub/block_scan_v5.mdx - page: WarpReduce - path: pages/cub/warp_reduce_v4.mdx + path: pages/cub/warp_reduce_v5.mdx - page: ArgMax - path: pages/cub/simple_struct_v4.mdx - - section: Thrust + path: pages/cub/simple_struct_v5.mdx + - section: CUB (v6 - Renderer Output) + contents: + - page: BlockReduce + path: pages/cub/block_reduce_v6.mdx + - page: BlockScan + path: pages/cub/block_scan_v6.mdx + - page: WarpReduce + path: pages/cub/warp_reduce_v6.mdx + - page: ArgMax + path: pages/cub/simple_struct_v6.mdx + - section: Thrust (v5 - Hand Curated) contents: - page: device_vector - path: pages/thrust/device_vector_v3.mdx + path: pages/thrust/device_vector_v5.mdx - page: pointer - path: pages/thrust/pointer_v4.mdx + path: pages/thrust/pointer_v5.mdx - page: strided_iterator - path: pages/thrust/deprecated_example_v4.mdx + path: pages/thrust/deprecated_example_v5.mdx - page: disjoint_unsynchronized_pool_resource - path: pages/thrust/group_member_example_v4.mdx - - section: libcudacxx + path: pages/thrust/group_member_example_v5.mdx + - section: Thrust (v6 - Renderer Output) + contents: + - page: device_vector + path: pages/thrust/device_vector_v6.mdx + - page: pointer + path: pages/thrust/pointer_v6.mdx + - page: strided_iterator + path: pages/thrust/deprecated_example_v6.mdx + - page: disjoint_unsynchronized_pool_resource + path: pages/thrust/group_member_example_v6.mdx + - section: libcudacxx (v5 - Hand Curated) contents: - page: resource_with - path: pages/libcudacxx/concept_example_v3.mdx + path: pages/libcudacxx/concept_example_v5.mdx - page: counting_iterator - path: pages/libcudacxx/deep_template_class_v4.mdx + path: pages/libcudacxx/deep_template_class_v5.mdx - page: buffer - path: pages/libcudacxx/empty_docstring_class_v4.mdx + path: pages/libcudacxx/empty_docstring_class_v5.mdx - page: stream - path: pages/libcudacxx/raises_example_v4.mdx + path: pages/libcudacxx/raises_example_v5.mdx + - section: libcudacxx (v6 - Renderer Output) + contents: + - page: resource_with + path: pages/libcudacxx/concept_example_v6.mdx + - page: counting_iterator + path: pages/libcudacxx/deep_template_class_v6.mdx + - page: buffer + path: pages/libcudacxx/empty_docstring_class_v6.mdx + - page: stream + path: pages/libcudacxx/raises_example_v6.mdx - tab: Library Reference layout: - library: nemo-rl @@ -151,7 +184,600 @@ navigation: - tab: LangChain Core Reference layout: - library: langchain-core - + - tab: C++ Full Reference + layout: + - section: CUB + contents: + - page: AgentAdjacentDifferencePolicy + path: cudapages/cub/cub/cub/AgentAdjacentDifferencePolicy.mdx + - page: AgentHistogramPolicy + path: cudapages/cub/cub/cub/AgentHistogramPolicy.mdx + - page: AgentMergeSortPolicy + path: cudapages/cub/cub/cub/AgentMergeSortPolicy.mdx + - page: AgentRadixSortDownsweepPolicy + path: cudapages/cub/cub/cub/AgentRadixSortDownsweepPolicy.mdx + - page: AgentRadixSortExclusiveSumPolicy + path: cudapages/cub/cub/cub/AgentRadixSortExclusiveSumPolicy.mdx + - page: AgentRadixSortHistogramPolicy + path: cudapages/cub/cub/cub/AgentRadixSortHistogramPolicy.mdx + - page: AgentRadixSortOnesweepPolicy + path: cudapages/cub/cub/cub/AgentRadixSortOnesweepPolicy.mdx + - page: AgentRadixSortUpsweepPolicy + path: cudapages/cub/cub/cub/AgentRadixSortUpsweepPolicy.mdx + - page: AgentReduceByKeyPolicy + path: cudapages/cub/cub/cub/AgentReduceByKeyPolicy.mdx + - page: AgentReducePolicy + path: cudapages/cub/cub/cub/AgentReducePolicy.mdx + - page: AgentRlePolicy + path: cudapages/cub/cub/cub/AgentRlePolicy.mdx + - page: AgentScanByKeyPolicy + path: cudapages/cub/cub/cub/AgentScanByKeyPolicy.mdx + - page: AgentScanPolicy + path: cudapages/cub/cub/cub/AgentScanPolicy.mdx + - page: AgentSelectIfPolicy + path: cudapages/cub/cub/cub/AgentSelectIfPolicy.mdx + - page: AgentSubWarpMergeSortPolicy + path: cudapages/cub/cub/cub/AgentSubWarpMergeSortPolicy.mdx + - page: AgentThreeWayPartitionPolicy + path: cudapages/cub/cub/cub/AgentThreeWayPartitionPolicy.mdx + - page: AgentUniqueByKeyPolicy + path: cudapages/cub/cub/cub/AgentUniqueByKeyPolicy.mdx + - page: AgentWarpReducePolicy + path: cudapages/cub/cub/cub/AgentWarpReducePolicy.mdx + - page: ArgIndexInputIterator + path: cudapages/cub/cub/cub/ArgIndexInputIterator.mdx + - page: ArgMax + path: cudapages/cub/cub/cub/ArgMax.mdx + - page: ArgMin + path: cudapages/cub/cub/cub/ArgMin.mdx + - page: BFEDigitExtractor + path: cudapages/cub/cub/cub/BFEDigitExtractor.mdx + - page: BaseDigitExtractor + path: cudapages/cub/cub/cub/BaseDigitExtractor.mdx + - page: BaseDigitExtractor_KeyT_true + path: cudapages/cub/cub/cub/BaseDigitExtractor_KeyT_true.mdx + - page: BlockAdjacentDifference + path: cudapages/cub/cub/cub/BlockAdjacentDifference.mdx + - page: BlockDiscontinuity + path: cudapages/cub/cub/cub/BlockDiscontinuity.mdx + - page: BlockExchange + path: cudapages/cub/cub/cub/BlockExchange.mdx + - page: BlockHistogram + path: cudapages/cub/cub/cub/BlockHistogram.mdx + - page: BlockLoad + path: cudapages/cub/cub/cub/BlockLoad.mdx + - page: BlockLoadType + path: cudapages/cub/cub/cub/BlockLoadType.mdx + - page: BlockMergeSort + path: cudapages/cub/cub/cub/BlockMergeSort.mdx + - page: BlockMergeSortStrategy + path: cudapages/cub/cub/cub/BlockMergeSortStrategy.mdx + - page: BlockRadixRank + path: cudapages/cub/cub/cub/BlockRadixRank.mdx + - page: BlockRadixRankEmptyCallback + path: cudapages/cub/cub/cub/BlockRadixRankEmptyCallback.mdx + - page: BlockRadixRankMatch + path: cudapages/cub/cub/cub/BlockRadixRankMatch.mdx + - page: BlockRadixRankMatchEarlyCounts + path: cudapages/cub/cub/cub/BlockRadixRankMatchEarlyCounts.mdx + - page: BlockRadixSort + path: cudapages/cub/cub/cub/BlockRadixSort.mdx + - page: BlockRakingLayout + path: cudapages/cub/cub/cub/BlockRakingLayout.mdx + - page: BlockReduce + path: cudapages/cub/cub/cub/BlockReduce.mdx + - page: BlockRunLengthDecode + path: cudapages/cub/cub/cub/BlockRunLengthDecode.mdx + - page: BlockScan + path: cudapages/cub/cub/cub/BlockScan.mdx + - page: BlockScanRunningPrefixOp + path: cudapages/cub/cub/cub/BlockScanRunningPrefixOp.mdx + - page: BlockShuffle + path: cudapages/cub/cub/cub/BlockShuffle.mdx + - page: BlockStore + path: cudapages/cub/cub/cub/BlockStore.mdx + - page: CacheModifiedInputIterator + path: cudapages/cub/cub/cub/CacheModifiedInputIterator.mdx + - page: CacheModifiedOutputIterator + path: cudapages/cub/cub/cub/CacheModifiedOutputIterator.mdx + - page: CachingDeviceAllocator + path: cudapages/cub/cub/cub/CachingDeviceAllocator.mdx + - page: CastOp + path: cudapages/cub/cub/cub/CastOp.mdx + - page: ChainedPolicy + path: cudapages/cub/cub/cub/ChainedPolicy.mdx + - page: DeviceAdjacentDifference + path: cudapages/cub/cub/cub/DeviceAdjacentDifference.mdx + - page: DeviceCopy + path: cudapages/cub/cub/cub/DeviceCopy.mdx + - page: DeviceFind + path: cudapages/cub/cub/cub/DeviceFind.mdx + - page: DeviceFor + path: cudapages/cub/cub/cub/DeviceFor.mdx + - page: DeviceHistogram + path: cudapages/cub/cub/cub/DeviceHistogram.mdx + - page: DeviceMemcpy + path: cudapages/cub/cub/cub/DeviceMemcpy.mdx + - page: DeviceMerge + path: cudapages/cub/cub/cub/DeviceMerge.mdx + - page: DeviceMergeSort + path: cudapages/cub/cub/cub/DeviceMergeSort.mdx + - page: DevicePartition + path: cudapages/cub/cub/cub/DevicePartition.mdx + - page: DeviceRadixSort + path: cudapages/cub/cub/cub/DeviceRadixSort.mdx + - page: DeviceReduce + path: cudapages/cub/cub/cub/DeviceReduce.mdx + - page: DeviceRleDispatch + path: cudapages/cub/cub/cub/DeviceRleDispatch.mdx + - page: DeviceRunLengthEncode + path: cudapages/cub/cub/cub/DeviceRunLengthEncode.mdx + - page: DeviceScan + path: cudapages/cub/cub/cub/DeviceScan.mdx + - page: DeviceSegmentedRadixSort + path: cudapages/cub/cub/cub/DeviceSegmentedRadixSort.mdx + - page: DeviceSegmentedReduce + path: cudapages/cub/cub/cub/DeviceSegmentedReduce.mdx + - page: DeviceSegmentedScan + path: cudapages/cub/cub/cub/DeviceSegmentedScan.mdx + - page: DeviceSegmentedSort + path: cudapages/cub/cub/cub/DeviceSegmentedSort.mdx + - page: DeviceSelect + path: cudapages/cub/cub/cub/DeviceSelect.mdx + - page: DeviceTopK + path: cudapages/cub/cub/cub/DeviceTopK.mdx + - page: DeviceTransform + path: cudapages/cub/cub/cub/DeviceTransform.mdx + - page: DispatchAdjacentDifference + path: cudapages/cub/cub/cub/DispatchAdjacentDifference.mdx + - page: DispatchHistogram + path: cudapages/cub/cub/cub/DispatchHistogram.mdx + - page: DispatchMergeSort + path: cudapages/cub/cub/cub/DispatchMergeSort.mdx + - page: DispatchRadixSort + path: cudapages/cub/cub/cub/DispatchRadixSort.mdx + - page: DispatchReduce + path: cudapages/cub/cub/cub/DispatchReduce.mdx + - page: DispatchReduceByKey + path: cudapages/cub/cub/cub/DispatchReduceByKey.mdx + - page: DispatchScan + path: cudapages/cub/cub/cub/DispatchScan.mdx + - page: DispatchScanByKey + path: cudapages/cub/cub/cub/DispatchScanByKey.mdx + - page: DispatchSegmentedRadixSort + path: cudapages/cub/cub/cub/DispatchSegmentedRadixSort.mdx + - page: DispatchSegmentedReduce + path: cudapages/cub/cub/cub/DispatchSegmentedReduce.mdx + - page: DispatchSegmentedSort + path: cudapages/cub/cub/cub/DispatchSegmentedSort.mdx + - page: DispatchSelectIf + path: cudapages/cub/cub/cub/DispatchSelectIf.mdx + - page: DispatchThreeWayPartitionIf + path: cudapages/cub/cub/cub/DispatchThreeWayPartitionIf.mdx + - page: DispatchUniqueByKey + path: cudapages/cub/cub/cub/DispatchUniqueByKey.mdx + - page: GridEvenShare + path: cudapages/cub/cub/cub/GridEvenShare.mdx + - page: GridQueue + path: cudapages/cub/cub/cub/GridQueue.mdx + - page: InequalityWrapper + path: cudapages/cub/cub/cub/InequalityWrapper.mdx + - page: PtxVersionCacheTag + path: cudapages/cub/cub/cub/PtxVersionCacheTag.mdx + - page: RadixSortTwiddle + path: cudapages/cub/cub/cub/RadixSortTwiddle.mdx + - page: ReduceByKeyOp + path: cudapages/cub/cub/cub/ReduceByKeyOp.mdx + - page: ReduceByKeyScanTileState + path: cudapages/cub/cub/cub/ReduceByKeyScanTileState.mdx + - page: ReduceByKeyScanTileState_ValueT_KeyT_false + path: cudapages/cub/cub/cub/ReduceByKeyScanTileState_ValueT_KeyT_false.mdx + - page: ReduceBySegmentOp + path: cudapages/cub/cub/cub/ReduceBySegmentOp.mdx + - page: ScanTileState + path: cudapages/cub/cub/cub/ScanTileState.mdx + - page: ScanTileState_T_false + path: cudapages/cub/cub/cub/ScanTileState_T_false.mdx + - page: ShiftDigitExtractor + path: cudapages/cub/cub/cub/ShiftDigitExtractor.mdx + - page: SmVersionCacheTag + path: cudapages/cub/cub/cub/SmVersionCacheTag.mdx + - page: SwizzleScanOp + path: cudapages/cub/cub/cub/SwizzleScanOp.mdx + - page: TilePrefixCallbackOp + path: cudapages/cub/cub/cub/TilePrefixCallbackOp.mdx + - page: WarpExchange + path: cudapages/cub/cub/cub/WarpExchange.mdx + - page: WarpLoad + path: cudapages/cub/cub/cub/WarpLoad.mdx + - page: WarpMergeSort + path: cudapages/cub/cub/cub/WarpMergeSort.mdx + - page: WarpReduce + path: cudapages/cub/cub/cub/WarpReduce.mdx + - page: WarpScan + path: cudapages/cub/cub/cub/WarpScan.mdx + - page: WarpStore + path: cudapages/cub/cub/cub/WarpStore.mdx + - section: Thrust + contents: + - page: allocator_delete + path: cudapages/thrust/thrust/thrust/allocator_delete.mdx + - page: array_allocator_delete + path: cudapages/thrust/thrust/thrust/array_allocator_delete.mdx + - page: bidirectional_device_iterator_tag + path: cudapages/thrust/thrust/thrust/bidirectional_device_iterator_tag.mdx + - page: bidirectional_traversal_tag + path: cudapages/thrust/thrust/thrust/bidirectional_traversal_tag.mdx + - page: compile_time_value + path: cudapages/thrust/thrust/thrust/compile_time_value.mdx + - page: complex + path: cudapages/thrust/thrust/thrust/complex.mdx + - page: constant_iterator + path: cudapages/thrust/thrust/thrust/constant_iterator.mdx + - page: counting_iterator + path: cudapages/thrust/thrust/thrust/counting_iterator.mdx + - page: device_allocator + path: cudapages/thrust/thrust/thrust/device_allocator.mdx + - page: device_execution_policy + path: cudapages/thrust/thrust/thrust/device_execution_policy.mdx + - page: device_malloc_allocator + path: cudapages/thrust/thrust/thrust/device_malloc_allocator.mdx + - page: device_new_allocator + path: cudapages/thrust/thrust/thrust/device_new_allocator.mdx + - page: device_ptr + path: cudapages/thrust/thrust/thrust/device_ptr.mdx + - page: device_ptr_memory_resource + path: cudapages/thrust/thrust/thrust/device_ptr_memory_resource.mdx + - page: device_reference + path: cudapages/thrust/thrust/thrust/device_reference.mdx + - page: device_vector + path: cudapages/thrust/thrust/thrust/device_vector.mdx + - page: discard_block_engine + path: cudapages/thrust/thrust/thrust/discard_block_engine.mdx + - page: discard_iterator + path: cudapages/thrust/thrust/thrust/discard_iterator.mdx + - page: error_category + path: cudapages/thrust/thrust/thrust/error_category.mdx + - page: error_code + path: cudapages/thrust/thrust/thrust/error_code.mdx + - page: error_condition + path: cudapages/thrust/thrust/thrust/error_condition.mdx + - page: forward_device_iterator_tag + path: cudapages/thrust/thrust/thrust/forward_device_iterator_tag.mdx + - page: forward_traversal_tag + path: cudapages/thrust/thrust/thrust/forward_traversal_tag.mdx + - page: host_execution_policy + path: cudapages/thrust/thrust/thrust/host_execution_policy.mdx + - page: host_vector + path: cudapages/thrust/thrust/thrust/host_vector.mdx + - page: incrementable_traversal_tag + path: cudapages/thrust/thrust/thrust/incrementable_traversal_tag.mdx + - page: input_device_iterator_tag + path: cudapages/thrust/thrust/thrust/input_device_iterator_tag.mdx + - page: is_error_code_enum + path: cudapages/thrust/thrust/thrust/is_error_code_enum.mdx + - page: is_error_condition_enum + path: cudapages/thrust/thrust/thrust/is_error_condition_enum.mdx + - page: iterator_adaptor + path: cudapages/thrust/thrust/thrust/iterator_adaptor.mdx + - page: iterator_core_access + path: cudapages/thrust/thrust/thrust/iterator_core_access.mdx + - page: iterator_difference + path: cudapages/thrust/thrust/thrust/iterator_difference.mdx + - page: iterator_facade + path: cudapages/thrust/thrust/thrust/iterator_facade.mdx + - page: iterator_pointer + path: cudapages/thrust/thrust/thrust/iterator_pointer.mdx + - page: iterator_reference + path: cudapages/thrust/thrust/thrust/iterator_reference.mdx + - page: iterator_system + path: cudapages/thrust/thrust/thrust/iterator_system.mdx + - page: iterator_system_const_void_ptr + path: cudapages/thrust/thrust/thrust/iterator_system_const_void_ptr.mdx + - page: iterator_system_cudaconstant_iterator_T_Index + path: cudapages/thrust/thrust/thrust/iterator_system_cudaconstant_iterator_T_Index.mdx + - page: iterator_system_cudacounting_iterator_Start + path: cudapages/thrust/thrust/thrust/iterator_system_cudacounting_iterator_Start.mdx + - page: iterator_system_cudadiscard_iterator + path: cudapages/thrust/thrust/thrust/iterator_system_cudadiscard_iterator.mdx + - page: iterator_system_cudapermutation_iterator_Iter_Offset + path: cudapages/thrust/thrust/thrust/iterator_system_cudapermutation_iterator_Iter_Offset.mdx + - page: iterator_system_cudashuffle_iterator_IndexType_Bijection + path: cudapages/thrust/thrust/thrust/iterator_system_cudashuffle_iterator_IndexType_Bijection.mdx + - page: iterator_system_cudastdreverse_iterator_Iter + path: cudapages/thrust/thrust/thrust/iterator_system_cudastdreverse_iterator_Iter.mdx + - page: iterator_system_cudastrided_iterator_Iter_Stride + path: cudapages/thrust/thrust/thrust/iterator_system_cudastrided_iterator_Iter_Stride.mdx + - page: iterator_system_cudatabulate_output_iterator_Fn_Index + path: cudapages/thrust/thrust/thrust/iterator_system_cudatabulate_output_iterator_Fn_Index.mdx + - page: iterator_system_cudatransform_input_output_iterator_InputFn_OutputFn_Iter + path: cudapages/thrust/thrust/thrust/iterator_system_cudatransform_input_output_iterator_InputFn_OutputFn_Iter.mdx + - page: iterator_system_cudatransform_iterator_Fn_Iter + path: cudapages/thrust/thrust/thrust/iterator_system_cudatransform_iterator_Fn_Iter.mdx + - page: iterator_system_cudatransform_output_iterator_Fn_Iter + path: cudapages/thrust/thrust/thrust/iterator_system_cudatransform_output_iterator_Fn_Iter.mdx + - page: iterator_system_cudazip_iterator_Iterators + path: cudapages/thrust/thrust/thrust/iterator_system_cudazip_iterator_Iterators.mdx + - page: iterator_system_cudazip_transform_iterator_Fn_Iterators + path: cudapages/thrust/thrust/thrust/iterator_system_cudazip_transform_iterator_Fn_Iterators.mdx + - page: iterator_system_void_ptr + path: cudapages/thrust/thrust/thrust/iterator_system_void_ptr.mdx + - page: iterator_traversal + path: cudapages/thrust/thrust/thrust/iterator_traversal.mdx + - page: iterator_traversal_cudaconstant_iterator_T_Index + path: cudapages/thrust/thrust/thrust/iterator_traversal_cudaconstant_iterator_T_Index.mdx + - page: iterator_traversal_cudacounting_iterator_Start + path: cudapages/thrust/thrust/thrust/iterator_traversal_cudacounting_iterator_Start.mdx + - page: iterator_traversal_cudadiscard_iterator + path: cudapages/thrust/thrust/thrust/iterator_traversal_cudadiscard_iterator.mdx + - page: iterator_traversal_cudapermutation_iterator_Iter_Offset + path: cudapages/thrust/thrust/thrust/iterator_traversal_cudapermutation_iterator_Iter_Offset.mdx + - page: iterator_traversal_cudashuffle_iterator_IndexType_Bijection + path: cudapages/thrust/thrust/thrust/iterator_traversal_cudashuffle_iterator_IndexType_Bijection.mdx + - page: iterator_traversal_cudastdreverse_iterator_Iter + path: cudapages/thrust/thrust/thrust/iterator_traversal_cudastdreverse_iterator_Iter.mdx + - page: iterator_traversal_cudastrided_iterator_Iter_Stride + path: cudapages/thrust/thrust/thrust/iterator_traversal_cudastrided_iterator_Iter_Stride.mdx + - page: iterator_traversal_cudatabulate_output_iterator_Fn_Index + path: cudapages/thrust/thrust/thrust/iterator_traversal_cudatabulate_output_iterator_Fn_Index.mdx + - page: iterator_traversal_cudatransform_input_output_iterator_InputFn_OutputFn_Iter + path: cudapages/thrust/thrust/thrust/iterator_traversal_cudatransform_input_output_iterator_InputFn_OutputFn_Iter.mdx + - page: iterator_traversal_cudatransform_iterator_Fn_Iter + path: cudapages/thrust/thrust/thrust/iterator_traversal_cudatransform_iterator_Fn_Iter.mdx + - page: iterator_traversal_cudatransform_output_iterator_Fn_Iter + path: cudapages/thrust/thrust/thrust/iterator_traversal_cudatransform_output_iterator_Fn_Iter.mdx + - page: iterator_traversal_cudazip_iterator_Iterators + path: cudapages/thrust/thrust/thrust/iterator_traversal_cudazip_iterator_Iterators.mdx + - page: iterator_traversal_cudazip_transform_iterator_Fn_Iterators + path: cudapages/thrust/thrust/thrust/iterator_traversal_cudazip_transform_iterator_Fn_Iterators.mdx + - page: iterator_value + path: cudapages/thrust/thrust/thrust/iterator_value.mdx + - page: linear_congruential_engine + path: cudapages/thrust/thrust/thrust/linear_congruential_engine.mdx + - page: linear_feedback_shift_engine + path: cudapages/thrust/thrust/thrust/linear_feedback_shift_engine.mdx + - page: no_traversal_tag + path: cudapages/thrust/thrust/thrust/no_traversal_tag.mdx + - page: normal_distribution + path: cudapages/thrust/thrust/thrust/normal_distribution.mdx + - page: offset_iterator + path: cudapages/thrust/thrust/thrust/offset_iterator.mdx + - page: output_device_iterator_tag + path: cudapages/thrust/thrust/thrust/output_device_iterator_tag.mdx + - page: per_device_allocator + path: cudapages/thrust/thrust/thrust/per_device_allocator.mdx + - page: permutation_iterator + path: cudapages/thrust/thrust/thrust/permutation_iterator.mdx + - page: pointer + path: cudapages/thrust/thrust/thrust/pointer.mdx + - page: proclaim_contiguous_iterator + path: cudapages/thrust/thrust/thrust/proclaim_contiguous_iterator.mdx + - page: project1st + path: cudapages/thrust/thrust/thrust/project1st.mdx + - page: project1st_void_void + path: cudapages/thrust/thrust/thrust/project1st_void_void.mdx + - page: project2nd + path: cudapages/thrust/thrust/thrust/project2nd.mdx + - page: project2nd_void_void + path: cudapages/thrust/thrust/thrust/project2nd_void_void.mdx + - page: random_access_device_iterator_tag + path: cudapages/thrust/thrust/thrust/random_access_device_iterator_tag.mdx + - page: random_access_traversal_tag + path: cudapages/thrust/thrust/thrust/random_access_traversal_tag.mdx + - page: runtime_value + path: cudapages/thrust/thrust/thrust/runtime_value.mdx + - page: shuffle_iterator + path: cudapages/thrust/thrust/thrust/shuffle_iterator.mdx + - page: single_pass_traversal_tag + path: cudapages/thrust/thrust/thrust/single_pass_traversal_tag.mdx + - page: square + path: cudapages/thrust/thrust/thrust/square.mdx + - page: square_void + path: cudapages/thrust/thrust/thrust/square_void.mdx + - page: strided_iterator + path: cudapages/thrust/thrust/thrust/strided_iterator.mdx + - page: subtract_with_carry_engine + path: cudapages/thrust/thrust/thrust/subtract_with_carry_engine.mdx + - page: system_error + path: cudapages/thrust/thrust/thrust/system_error.mdx + - page: tabulate_output_iterator + path: cudapages/thrust/thrust/thrust/tabulate_output_iterator.mdx + - page: tagged_deleter + path: cudapages/thrust/thrust/thrust/tagged_deleter.mdx + - page: transform_input_output_iterator + path: cudapages/thrust/thrust/thrust/transform_input_output_iterator.mdx + - page: transform_iterator + path: cudapages/thrust/thrust/thrust/transform_iterator.mdx + - page: transform_output_iterator + path: cudapages/thrust/thrust/thrust/transform_output_iterator.mdx + - page: uniform_int_distribution + path: cudapages/thrust/thrust/thrust/uniform_int_distribution.mdx + - page: uniform_real_distribution + path: cudapages/thrust/thrust/thrust/uniform_real_distribution.mdx + - page: xor_combine_engine + path: cudapages/thrust/thrust/thrust/xor_combine_engine.mdx + - page: zip_function + path: cudapages/thrust/thrust/thrust/zip_function.mdx + - page: zip_iterator + path: cudapages/thrust/thrust/thrust/zip_iterator.mdx + - section: Thrust / mr + contents: + - page: allocator + path: cudapages/thrust/thrust/thrust/mr/allocator.mdx + - page: disjoint_synchronized_pool_resource + path: cudapages/thrust/thrust/thrust/mr/disjoint_synchronized_pool_resource.mdx + - page: disjoint_unsynchronized_pool_resource + path: cudapages/thrust/thrust/thrust/mr/disjoint_unsynchronized_pool_resource.mdx + - page: fancy_pointer_resource + path: cudapages/thrust/thrust/thrust/mr/fancy_pointer_resource.mdx + - page: memory_resource + path: cudapages/thrust/thrust/thrust/mr/memory_resource.mdx + - page: memory_resource_void_ptr + path: cudapages/thrust/thrust/thrust/mr/memory_resource_void_ptr.mdx + - page: new_delete_resource + path: cudapages/thrust/thrust/thrust/mr/new_delete_resource.mdx + - page: new_delete_resource_base + path: cudapages/thrust/thrust/thrust/mr/new_delete_resource_base.mdx + - page: polymorphic_adaptor_resource + path: cudapages/thrust/thrust/thrust/mr/polymorphic_adaptor_resource.mdx + - page: pool_options + path: cudapages/thrust/thrust/thrust/mr/pool_options.mdx + - page: stateless_resource_allocator + path: cudapages/thrust/thrust/thrust/mr/stateless_resource_allocator.mdx + - page: synchronized_pool_resource + path: cudapages/thrust/thrust/thrust/mr/synchronized_pool_resource.mdx + - page: unsynchronized_pool_resource + path: cudapages/thrust/thrust/thrust/mr/unsynchronized_pool_resource.mdx + - page: validator + path: cudapages/thrust/thrust/thrust/mr/validator.mdx + - page: validator2 + path: cudapages/thrust/thrust/thrust/mr/validator2.mdx + - page: validator2_T_T + path: cudapages/thrust/thrust/thrust/mr/validator2_T_T.mdx + - section: Thrust / random + contents: + - page: discard_block_engine + path: cudapages/thrust/thrust/thrust/random/discard_block_engine.mdx + - page: linear_congruential_engine + path: cudapages/thrust/thrust/thrust/random/linear_congruential_engine.mdx + - page: linear_feedback_shift_engine + path: cudapages/thrust/thrust/thrust/random/linear_feedback_shift_engine.mdx + - page: normal_distribution + path: cudapages/thrust/thrust/thrust/random/normal_distribution.mdx + - page: subtract_with_carry_engine + path: cudapages/thrust/thrust/thrust/random/subtract_with_carry_engine.mdx + - page: uniform_int_distribution + path: cudapages/thrust/thrust/thrust/random/uniform_int_distribution.mdx + - page: uniform_real_distribution + path: cudapages/thrust/thrust/thrust/random/uniform_real_distribution.mdx + - page: xor_combine_engine + path: cudapages/thrust/thrust/thrust/random/xor_combine_engine.mdx + - section: Thrust / system + contents: + - page: error_category + path: cudapages/thrust/thrust/thrust/system/error_category.mdx + - page: error_code + path: cudapages/thrust/thrust/thrust/system/error_code.mdx + - page: error_condition + path: cudapages/thrust/thrust/thrust/system/error_condition.mdx + - page: is_error_code_enum + path: cudapages/thrust/thrust/thrust/system/is_error_code_enum.mdx + - page: is_error_code_enum_cudaerrcerrc_t + path: cudapages/thrust/thrust/thrust/system/is_error_code_enum_cudaerrcerrc_t.mdx + - page: is_error_condition_enum + path: cudapages/thrust/thrust/thrust/system/is_error_condition_enum.mdx + - page: is_error_condition_enum_errcerrc_t + path: cudapages/thrust/thrust/thrust/system/is_error_condition_enum_errcerrc_t.mdx + - page: system_error + path: cudapages/thrust/thrust/thrust/system/system_error.mdx + - section: libcudacxx + contents: + - page: arch_traits_t + path: cudapages/cuda/cuda/cuda/arch_traits_t.mdx + - page: buffer + path: cudapages/cuda/cuda/cuda/buffer.mdx + - page: compute_capability + path: cudapages/cuda/cuda/cuda/compute_capability.mdx + - page: constant_iterator + path: cudapages/cuda/cuda/cuda/constant_iterator.mdx + - page: copy_configuration + path: cudapages/cuda/cuda/cuda/copy_configuration.mdx + - page: counting_iterator + path: cudapages/cuda/cuda/cuda/counting_iterator.mdx + - page: device_memory_pool + path: cudapages/cuda/cuda/cuda/device_memory_pool.mdx + - page: device_memory_pool_ref + path: cudapages/cuda/cuda/cuda/device_memory_pool_ref.mdx + - page: device_ref + path: cudapages/cuda/cuda/cuda/device_ref.mdx + - page: discard_iterator + path: cudapages/cuda/cuda/cuda/discard_iterator.mdx + - page: event + path: cudapages/cuda/cuda/cuda/event.mdx + - page: event_ref + path: cudapages/cuda/cuda/cuda/event_ref.mdx + - page: get_stream_t + path: cudapages/cuda/cuda/cuda/get_stream_t.mdx + - page: has_property + path: cudapages/cuda/cuda/cuda/has_property.mdx + - page: has_property_with + path: cudapages/cuda/cuda/cuda/has_property_with.mdx + - page: heterogeneous_iterator + path: cudapages/cuda/cuda/cuda/heterogeneous_iterator.mdx + - page: managed_memory_pool + path: cudapages/cuda/cuda/cuda/managed_memory_pool.mdx + - page: managed_memory_pool_ref + path: cudapages/cuda/cuda/cuda/managed_memory_pool_ref.mdx + - page: memory_pool_properties + path: cudapages/cuda/cuda/cuda/memory_pool_properties.mdx + - page: permutation_iterator + path: cudapages/cuda/cuda/cuda/permutation_iterator.mdx + - page: pinned_memory_pool + path: cudapages/cuda/cuda/cuda/pinned_memory_pool.mdx + - page: pinned_memory_pool_ref + path: cudapages/cuda/cuda/cuda/pinned_memory_pool_ref.mdx + - page: property_with_value + path: cudapages/cuda/cuda/cuda/property_with_value.mdx + - page: shuffle_iterator + path: cudapages/cuda/cuda/cuda/shuffle_iterator.mdx + - page: stream + path: cudapages/cuda/cuda/cuda/stream.mdx + - page: stream_ref + path: cudapages/cuda/cuda/cuda/stream_ref.mdx + - page: strided_iterator + path: cudapages/cuda/cuda/cuda/strided_iterator.mdx + - page: tabulate_output_iterator + path: cudapages/cuda/cuda/cuda/tabulate_output_iterator.mdx + - page: timed_event + path: cudapages/cuda/cuda/cuda/timed_event.mdx + - page: transform_input_output_iterator + path: cudapages/cuda/cuda/cuda/transform_input_output_iterator.mdx + - page: transform_iterator + path: cudapages/cuda/cuda/cuda/transform_iterator.mdx + - page: transform_output_iterator + path: cudapages/cuda/cuda/cuda/transform_output_iterator.mdx + - page: zip_function + path: cudapages/cuda/cuda/cuda/zip_function.mdx + - page: zip_iterator + path: cudapages/cuda/cuda/cuda/zip_iterator.mdx + - page: zip_transform_iterator + path: cudapages/cuda/cuda/cuda/zip_transform_iterator.mdx + - section: libcudacxx / device_attributes + contents: + - page: compute_capability_t + path: cudapages/cuda/cuda/cuda/device_attributes/compute_capability_t.mdx + - section: libcudacxx / mr + contents: + - page: basic_any_resource + path: cudapages/cuda/cuda/cuda/mr/basic_any_resource.mdx + - page: basic_resource_ref + path: cudapages/cuda/cuda/cuda/mr/basic_resource_ref.mdx + - page: device_accessible + path: cudapages/cuda/cuda/cuda/mr/device_accessible.mdx + - page: host_accessible + path: cudapages/cuda/cuda/cuda/mr/host_accessible.mdx + - page: legacy_managed_memory_resource + path: cudapages/cuda/cuda/cuda/mr/legacy_managed_memory_resource.mdx + - page: legacy_pinned_memory_resource + path: cudapages/cuda/cuda/cuda/mr/legacy_pinned_memory_resource.mdx + - page: properties_list + path: cudapages/cuda/cuda/cuda/mr/properties_list.mdx + - page: resource + path: cudapages/cuda/cuda/cuda/mr/resource.mdx + - page: resource_with + path: cudapages/cuda/cuda/cuda/mr/resource_with.mdx + - page: shared_resource + path: cudapages/cuda/cuda/cuda/mr/shared_resource.mdx + - page: synchronous_resource + path: cudapages/cuda/cuda/cuda/mr/synchronous_resource.mdx + - page: synchronous_resource_adapter + path: cudapages/cuda/cuda/cuda/mr/synchronous_resource_adapter.mdx + - page: synchronous_resource_with + path: cudapages/cuda/cuda/cuda/mr/synchronous_resource_with.mdx + - section: libcudacxx / std + contents: + - page: pointer_traits + path: cudapages/cuda/cuda/cuda/std/pointer_traits.mdx navbar-links: - type: minimal diff --git a/fern/docs/pages/nominal-data-model.mdx b/fern/docs/pages/nominal-data-model.mdx new file mode 100644 index 0000000..bfea6df --- /dev/null +++ b/fern/docs/pages/nominal-data-model.mdx @@ -0,0 +1,150 @@ +--- +title: Set up your data in Nominal +sidebar-title: Get started +description: Example recipes of using Nominal's data model for specific applications. +slug: nominal-data-model +--- + +Learn how to set up your data for performant data ingestion in Nominal. + +## Assess your testing phase + +When setting up your data in Nominal, the recommended structure depends on the type and phase of testing you are performing: + + + + If you need a historical record for long-lasting hardware systems, use this approach. + + Use cases: + * Test each complete drone's performance with a new software version. + * Test different maneuvers of an aircraft under edge cases to verify handling difficult situations. + + Get started with [System-level testing](#system-level-testing). + + + Ensure software behaves correctly on physical hardware. + + Use cases: + * Automated development workflow where software versions are continuously tested on hardware to prevent regressions. + * Running 1,000+ simulations simultaneously where similar scenarios are run against different software versions. + + Get started with [Large-scale simulation testing](#large-scale-simulation-testing). + + + +## Structure your data + +Add data to Nominal depending on the type of testing you are performing. + +### System-level testing + +This example shows how to test a drone by consolidating telemetry files from multiple sources into an asset: + + + + Use assets in Nominal as the organizing unit for ingesting, labeling, and analyzing time-synchronized data. + + + + 1. From the **Assets** page, create a new asset for the drone by clicking the **Create asset** button in the top right of the page. + 1. Add a name to identify your asset. In this example, we use `Drone S001` + 1. Add any other identifiers for filtering or sorting by these identifiers. For example: + + | Asset Name | Properties | Labels | + | ------------- | ---------------------------------------- | --------------------------------- | + | Drone Alpha | `serial_num: QC-001`, `firmware: v2.3.1` | `quadcopter`, `prototype` | + | Drone Bravo | `serial_num: QC-002`, `firmware: v2.3.1` | `quadcopter`, `prototype` | + | Drone Charlie | `serial_num: QC-003`, `firmware: v2.3.0` | `quadcopter`, `outdoor-certified` | + | Drone Delta | `serial_num: QC-004`, `firmware: v2.3.1` | `quadcopter`, `outdoor-certified` | + | Drone Echo | `serial_num: QC-005`, `firmware: v2.3.1` | `quadcopter`, `prototype` | + 1. Click **Create asset** to save your asset. + + + + + Upload files, attach videos, or stream live telemetry directly to this dataset for ongoing management. + + + + Centralize your drone data by creating a dataset that holds all your drone data. + + + In this example: + * The `dron-s001-data-drone.csv` file includes the data from the drone. + * The `dron-s001-data-remote-control.csv` file includes the data from the remote control. + + To upload both files to the same dataset without losing context where data is coming from: + 1. In the dataset, upload data from the drone by clicking **Upload files** and upload `dron-s001-data-drone.csv`. + 1. Make sure the timestamp information is correct. + 1. To identify the source of the data, add a tag to the file: + 1. Click **Add tags**. + 1. Create a `source` tag with the value `drone`. + + This allows you to filter the data by source. + 1. Click **Upload**. + 1. Repeat the process for the `dron-s001-data-remote-control.csv` file, and create a `source` tag with the value `remote`. + 1. Verify data is uploaded correctly by checking the following in your dataset: + 1. Click on the **Channels** tab, and verify the channels from your CSV file populated into the dataset. + 1. Click on the **Recent tags** tab, and verify the tag key and tag value matches expectations. + + Now you can filter by source to switch between data from the remote control and the drone. + + + + 1. Open your asset page and click on the **Data sources** tab. + 1. Add data from the drone to the asset: + 1. Search existing dataset, select `ASSET_NAME dataset` dataset. + 1. Use the `drone` refname to identify the source of data. + + Specifying a ref name identifies the source when there's overlapping signal names in a dataset. For example, if both your `remote` and `drone` computers log a channel called `altitude`, the ref name creates fully-qualified channel names like `remote.altitude` and `drone.altitude`. + 1. Add tag filter `drone`to filter the data from the dataset to just data from the drone. + 1. Add data from the remote control to the asset: + 1. Search existing dataset, select `ASSET_NAME dataset` dataset. + 1. Use refname `remote` to identify the source of data. + 1. Add tag filter `remote` to filter the data from the dataset to just data from the remote control. + + Now all files that you add to the asset uses these ref names and filters automatically. + + + + + + Now that your data is set up, you can set up livestreaming for your data or upload one-off data to your asset: + + + + + Programmatic setups allow scripts to send data to Nominal from any source (examples: NAS devices, DAQ-connected computers, or cloud lambdas). Whether streaming or batch processing, these scripts automatically tag data with metadata and asset identifiers based on your organization's logic. + + This approach is most effective when data formats and locations remain stable, allowing ingest scripts to persist even as test configurations evolve. + + + Use manual uploads with the Nominal app for one-off loads, exploratory testing, or early-stage prototypes where data formats are still evolving: + + 1. Upload files to the asset's dataset with tags. + 1. Make sure the correct data source tag applies to all files before uploading. + + + + + + + +### Additional examples + +See the following for additional examples: + + + + +For monitoring live data from operational assets to Nominal, it's recommended to setup a single dataset where all hardware devices stream to with unique asset tags. Follow the same steps under **Offline field testing** (Setting up the asset) to create separate Nominal Assets with its appropriate asset tag filter. + + + + + +Since multiple simulation runs can run in parallel (often in relative time), we'll want to ingest data from each simulation with a unique `sim_id` that can be used to filter the run with only data from the appropriate simulation run. + + + diff --git a/fern/docs/pages/steps-toc-test.mdx b/fern/docs/pages/steps-toc-test.mdx new file mode 100644 index 0000000..1e48ad3 --- /dev/null +++ b/fern/docs/pages/steps-toc-test.mdx @@ -0,0 +1,92 @@ +--- +title: Steps TOC Depth Test +subtitle: Testing that Steps respect heading rank in the table of contents +slug: steps-toc-test +--- + +This page tests that the TOC sidebar correctly nests `` based on the heading level used inside them, rather than always hardcoding depth 3. + +## Section A: H2 headings inside Steps + +The steps below use `##` headings. With the fix, they should appear at **depth 2** in the TOC (same level as "Section A" above). Before the fix, they'd incorrectly appear at depth 3. + + + +## Install dependencies + +Run the following command: + +```bash +npm install my-package +``` + +## Configure the client + +Create a configuration file: + +```typescript +const client = new Client({ apiKey: "your-key" }); +``` + +## Make your first request + +Send a test request: + +```typescript +const response = await client.get("/health"); +console.log(response.status); +``` + + + +## Section B: H3 headings inside Steps (default behavior) + +These steps use `###` headings — the original default. They should appear at **depth 3** in the TOC, nesting under "Section B". + + + +### Set up authentication + +Configure your API keys in the dashboard. + +### Create a webhook + +Register a webhook endpoint to receive events. + +### Test the integration + +Verify everything works end-to-end. + + + +## Section C: Explicit Step components (no headings) + +These use the explicit `` syntax without heading-based migration. TOC depth should default to 3 (unchanged behavior). + + + + Sign up at the dashboard. + + + Navigate to Settings and create a key. + + + Use the SDK to make your first call. + + + +## Section D: Steps without toc (control group) + +These steps do **not** have `toc={true}`, so they should NOT appear in the TOC at all. + + + +### Hidden step 1 + +This should not be in the TOC. + +### Hidden step 2 + +Neither should this. + + diff --git a/fern/library-docs/langchain-core-docs/_navigation.yml b/fern/library-docs/langchain-core-docs/_navigation.yml new file mode 100644 index 0000000..9617342 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/_navigation.yml @@ -0,0 +1,1305 @@ +# AUTO-GENERATED by `fern docs md generate` — DO NOT EDIT +- type: section + title: _api + slug: langchain-core/langchain_core/_api + children: + - type: section + title: beta_decorator + slug: langchain-core/langchain_core/_api/beta_decorator + children: + - type: page + title: beta_decorator + slug: langchain-core/langchain_core/_api/beta_decorator + pageId: langchain-core/langchain_core/_api/beta_decorator.mdx + - type: section + title: deprecation + slug: langchain-core/langchain_core/_api/deprecation + children: + - type: page + title: deprecation + slug: langchain-core/langchain_core/_api/deprecation + pageId: langchain-core/langchain_core/_api/deprecation.mdx + - type: section + title: internal + slug: langchain-core/langchain_core/_api/internal + children: + - type: page + title: internal + slug: langchain-core/langchain_core/_api/internal + pageId: langchain-core/langchain_core/_api/internal.mdx + - type: section + title: path + slug: langchain-core/langchain_core/_api/path + children: + - type: page + title: path + slug: langchain-core/langchain_core/_api/path + pageId: langchain-core/langchain_core/_api/path.mdx +- type: section + title: _import_utils + slug: langchain-core/langchain_core/_import_utils + children: + - type: page + title: _import_utils + slug: langchain-core/langchain_core/_import_utils + pageId: langchain-core/langchain_core/_import_utils.mdx +- type: section + title: _security + slug: langchain-core/langchain_core/_security + children: + - type: section + title: _ssrf_protection + slug: langchain-core/langchain_core/_security/_ssrf_protection + children: + - type: page + title: _ssrf_protection + slug: langchain-core/langchain_core/_security/_ssrf_protection + pageId: langchain-core/langchain_core/_security/_ssrf_protection.mdx +- type: section + title: agents + slug: langchain-core/langchain_core/agents + children: + - type: page + title: agents + slug: langchain-core/langchain_core/agents + pageId: langchain-core/langchain_core/agents.mdx +- type: section + title: caches + slug: langchain-core/langchain_core/caches + children: + - type: page + title: caches + slug: langchain-core/langchain_core/caches + pageId: langchain-core/langchain_core/caches.mdx +- type: section + title: callbacks + slug: langchain-core/langchain_core/callbacks + children: + - type: section + title: base + slug: langchain-core/langchain_core/callbacks/base + children: + - type: page + title: base + slug: langchain-core/langchain_core/callbacks/base + pageId: langchain-core/langchain_core/callbacks/base.mdx + - type: section + title: file + slug: langchain-core/langchain_core/callbacks/file + children: + - type: page + title: file + slug: langchain-core/langchain_core/callbacks/file + pageId: langchain-core/langchain_core/callbacks/file.mdx + - type: section + title: manager + slug: langchain-core/langchain_core/callbacks/manager + children: + - type: page + title: manager + slug: langchain-core/langchain_core/callbacks/manager + pageId: langchain-core/langchain_core/callbacks/manager.mdx + - type: section + title: stdout + slug: langchain-core/langchain_core/callbacks/stdout + children: + - type: page + title: stdout + slug: langchain-core/langchain_core/callbacks/stdout + pageId: langchain-core/langchain_core/callbacks/stdout.mdx + - type: section + title: streaming_stdout + slug: langchain-core/langchain_core/callbacks/streaming_stdout + children: + - type: page + title: streaming_stdout + slug: langchain-core/langchain_core/callbacks/streaming_stdout + pageId: langchain-core/langchain_core/callbacks/streaming_stdout.mdx + - type: section + title: usage + slug: langchain-core/langchain_core/callbacks/usage + children: + - type: page + title: usage + slug: langchain-core/langchain_core/callbacks/usage + pageId: langchain-core/langchain_core/callbacks/usage.mdx +- type: section + title: chat_history + slug: langchain-core/langchain_core/chat_history + children: + - type: page + title: chat_history + slug: langchain-core/langchain_core/chat_history + pageId: langchain-core/langchain_core/chat_history.mdx +- type: section + title: chat_loaders + slug: langchain-core/langchain_core/chat_loaders + children: + - type: page + title: chat_loaders + slug: langchain-core/langchain_core/chat_loaders + pageId: langchain-core/langchain_core/chat_loaders.mdx +- type: section + title: chat_sessions + slug: langchain-core/langchain_core/chat_sessions + children: + - type: page + title: chat_sessions + slug: langchain-core/langchain_core/chat_sessions + pageId: langchain-core/langchain_core/chat_sessions.mdx +- type: section + title: document_loaders + slug: langchain-core/langchain_core/document_loaders + children: + - type: section + title: base + slug: langchain-core/langchain_core/document_loaders/base + children: + - type: page + title: base + slug: langchain-core/langchain_core/document_loaders/base + pageId: langchain-core/langchain_core/document_loaders/base.mdx + - type: section + title: blob_loaders + slug: langchain-core/langchain_core/document_loaders/blob_loaders + children: + - type: page + title: blob_loaders + slug: langchain-core/langchain_core/document_loaders/blob_loaders + pageId: langchain-core/langchain_core/document_loaders/blob_loaders.mdx + - type: section + title: langsmith + slug: langchain-core/langchain_core/document_loaders/langsmith + children: + - type: page + title: langsmith + slug: langchain-core/langchain_core/document_loaders/langsmith + pageId: langchain-core/langchain_core/document_loaders/langsmith.mdx +- type: section + title: documents + slug: langchain-core/langchain_core/documents + children: + - type: section + title: base + slug: langchain-core/langchain_core/documents/base + children: + - type: page + title: base + slug: langchain-core/langchain_core/documents/base + pageId: langchain-core/langchain_core/documents/base.mdx + - type: section + title: compressor + slug: langchain-core/langchain_core/documents/compressor + children: + - type: page + title: compressor + slug: langchain-core/langchain_core/documents/compressor + pageId: langchain-core/langchain_core/documents/compressor.mdx + - type: section + title: transformers + slug: langchain-core/langchain_core/documents/transformers + children: + - type: page + title: transformers + slug: langchain-core/langchain_core/documents/transformers + pageId: langchain-core/langchain_core/documents/transformers.mdx +- type: section + title: embeddings + slug: langchain-core/langchain_core/embeddings + children: + - type: section + title: embeddings + slug: langchain-core/langchain_core/embeddings/embeddings + children: + - type: page + title: embeddings + slug: langchain-core/langchain_core/embeddings/embeddings + pageId: langchain-core/langchain_core/embeddings/embeddings.mdx + - type: section + title: fake + slug: langchain-core/langchain_core/embeddings/fake + children: + - type: page + title: fake + slug: langchain-core/langchain_core/embeddings/fake + pageId: langchain-core/langchain_core/embeddings/fake.mdx +- type: section + title: env + slug: langchain-core/langchain_core/env + children: + - type: page + title: env + slug: langchain-core/langchain_core/env + pageId: langchain-core/langchain_core/env.mdx +- type: section + title: example_selectors + slug: langchain-core/langchain_core/example_selectors + children: + - type: section + title: base + slug: langchain-core/langchain_core/example_selectors/base + children: + - type: page + title: base + slug: langchain-core/langchain_core/example_selectors/base + pageId: langchain-core/langchain_core/example_selectors/base.mdx + - type: section + title: length_based + slug: langchain-core/langchain_core/example_selectors/length_based + children: + - type: page + title: length_based + slug: langchain-core/langchain_core/example_selectors/length_based + pageId: langchain-core/langchain_core/example_selectors/length_based.mdx + - type: section + title: semantic_similarity + slug: langchain-core/langchain_core/example_selectors/semantic_similarity + children: + - type: page + title: semantic_similarity + slug: langchain-core/langchain_core/example_selectors/semantic_similarity + pageId: langchain-core/langchain_core/example_selectors/semantic_similarity.mdx +- type: section + title: exceptions + slug: langchain-core/langchain_core/exceptions + children: + - type: page + title: exceptions + slug: langchain-core/langchain_core/exceptions + pageId: langchain-core/langchain_core/exceptions.mdx +- type: section + title: globals + slug: langchain-core/langchain_core/globals + children: + - type: page + title: globals + slug: langchain-core/langchain_core/globals + pageId: langchain-core/langchain_core/globals.mdx +- type: section + title: indexing + slug: langchain-core/langchain_core/indexing + children: + - type: section + title: api + slug: langchain-core/langchain_core/indexing/api + children: + - type: page + title: api + slug: langchain-core/langchain_core/indexing/api + pageId: langchain-core/langchain_core/indexing/api.mdx + - type: section + title: base + slug: langchain-core/langchain_core/indexing/base + children: + - type: page + title: base + slug: langchain-core/langchain_core/indexing/base + pageId: langchain-core/langchain_core/indexing/base.mdx + - type: section + title: in_memory + slug: langchain-core/langchain_core/indexing/in_memory + children: + - type: page + title: in_memory + slug: langchain-core/langchain_core/indexing/in_memory + pageId: langchain-core/langchain_core/indexing/in_memory.mdx +- type: section + title: language_models + slug: langchain-core/langchain_core/language_models + children: + - type: section + title: _utils + slug: langchain-core/langchain_core/language_models/_utils + children: + - type: page + title: _utils + slug: langchain-core/langchain_core/language_models/_utils + pageId: langchain-core/langchain_core/language_models/_utils.mdx + - type: section + title: base + slug: langchain-core/langchain_core/language_models/base + children: + - type: page + title: base + slug: langchain-core/langchain_core/language_models/base + pageId: langchain-core/langchain_core/language_models/base.mdx + - type: section + title: chat_models + slug: langchain-core/langchain_core/language_models/chat_models + children: + - type: page + title: chat_models + slug: langchain-core/langchain_core/language_models/chat_models + pageId: langchain-core/langchain_core/language_models/chat_models.mdx + - type: section + title: fake + slug: langchain-core/langchain_core/language_models/fake + children: + - type: page + title: fake + slug: langchain-core/langchain_core/language_models/fake + pageId: langchain-core/langchain_core/language_models/fake.mdx + - type: section + title: fake_chat_models + slug: langchain-core/langchain_core/language_models/fake_chat_models + children: + - type: page + title: fake_chat_models + slug: langchain-core/langchain_core/language_models/fake_chat_models + pageId: langchain-core/langchain_core/language_models/fake_chat_models.mdx + - type: section + title: llms + slug: langchain-core/langchain_core/language_models/llms + children: + - type: page + title: llms + slug: langchain-core/langchain_core/language_models/llms + pageId: langchain-core/langchain_core/language_models/llms.mdx + - type: section + title: model_profile + slug: langchain-core/langchain_core/language_models/model_profile + children: + - type: page + title: model_profile + slug: langchain-core/langchain_core/language_models/model_profile + pageId: langchain-core/langchain_core/language_models/model_profile.mdx +- type: section + title: load + slug: langchain-core/langchain_core/load + children: + - type: section + title: _validation + slug: langchain-core/langchain_core/load/_validation + children: + - type: page + title: _validation + slug: langchain-core/langchain_core/load/_validation + pageId: langchain-core/langchain_core/load/_validation.mdx + - type: section + title: dump + slug: langchain-core/langchain_core/load/dump + children: + - type: page + title: dump + slug: langchain-core/langchain_core/load/dump + pageId: langchain-core/langchain_core/load/dump.mdx + - type: section + title: load + slug: langchain-core/langchain_core/load/load + children: + - type: page + title: load + slug: langchain-core/langchain_core/load/load + pageId: langchain-core/langchain_core/load/load.mdx + - type: section + title: mapping + slug: langchain-core/langchain_core/load/mapping + children: + - type: page + title: mapping + slug: langchain-core/langchain_core/load/mapping + pageId: langchain-core/langchain_core/load/mapping.mdx + - type: section + title: serializable + slug: langchain-core/langchain_core/load/serializable + children: + - type: page + title: serializable + slug: langchain-core/langchain_core/load/serializable + pageId: langchain-core/langchain_core/load/serializable.mdx +- type: section + title: messages + slug: langchain-core/langchain_core/messages + children: + - type: section + title: ai + slug: langchain-core/langchain_core/messages/ai + children: + - type: page + title: ai + slug: langchain-core/langchain_core/messages/ai + pageId: langchain-core/langchain_core/messages/ai.mdx + - type: section + title: base + slug: langchain-core/langchain_core/messages/base + children: + - type: page + title: base + slug: langchain-core/langchain_core/messages/base + pageId: langchain-core/langchain_core/messages/base.mdx + - type: section + title: block_translators + slug: langchain-core/langchain_core/messages/block_translators + children: + - type: section + title: anthropic + slug: langchain-core/langchain_core/messages/block_translators/anthropic + children: + - type: page + title: anthropic + slug: langchain-core/langchain_core/messages/block_translators/anthropic + pageId: langchain-core/langchain_core/messages/block_translators/anthropic.mdx + - type: section + title: bedrock + slug: langchain-core/langchain_core/messages/block_translators/bedrock + children: + - type: page + title: bedrock + slug: langchain-core/langchain_core/messages/block_translators/bedrock + pageId: langchain-core/langchain_core/messages/block_translators/bedrock.mdx + - type: section + title: bedrock_converse + slug: langchain-core/langchain_core/messages/block_translators/bedrock_converse + children: + - type: page + title: bedrock_converse + slug: langchain-core/langchain_core/messages/block_translators/bedrock_converse + pageId: langchain-core/langchain_core/messages/block_translators/bedrock_converse.mdx + - type: section + title: google_genai + slug: langchain-core/langchain_core/messages/block_translators/google_genai + children: + - type: page + title: google_genai + slug: langchain-core/langchain_core/messages/block_translators/google_genai + pageId: langchain-core/langchain_core/messages/block_translators/google_genai.mdx + - type: section + title: google_vertexai + slug: langchain-core/langchain_core/messages/block_translators/google_vertexai + children: + - type: page + title: google_vertexai + slug: langchain-core/langchain_core/messages/block_translators/google_vertexai + pageId: langchain-core/langchain_core/messages/block_translators/google_vertexai.mdx + - type: section + title: groq + slug: langchain-core/langchain_core/messages/block_translators/groq + children: + - type: page + title: groq + slug: langchain-core/langchain_core/messages/block_translators/groq + pageId: langchain-core/langchain_core/messages/block_translators/groq.mdx + - type: section + title: langchain_v0 + slug: langchain-core/langchain_core/messages/block_translators/langchain_v0 + children: + - type: page + title: langchain_v0 + slug: langchain-core/langchain_core/messages/block_translators/langchain_v0 + pageId: langchain-core/langchain_core/messages/block_translators/langchain_v0.mdx + - type: section + title: openai + slug: langchain-core/langchain_core/messages/block_translators/openai + children: + - type: page + title: openai + slug: langchain-core/langchain_core/messages/block_translators/openai + pageId: langchain-core/langchain_core/messages/block_translators/openai.mdx + - type: section + title: chat + slug: langchain-core/langchain_core/messages/chat + children: + - type: page + title: chat + slug: langchain-core/langchain_core/messages/chat + pageId: langchain-core/langchain_core/messages/chat.mdx + - type: section + title: content + slug: langchain-core/langchain_core/messages/content + children: + - type: page + title: content + slug: langchain-core/langchain_core/messages/content + pageId: langchain-core/langchain_core/messages/content.mdx + - type: section + title: function + slug: langchain-core/langchain_core/messages/function + children: + - type: page + title: function + slug: langchain-core/langchain_core/messages/function + pageId: langchain-core/langchain_core/messages/function.mdx + - type: section + title: human + slug: langchain-core/langchain_core/messages/human + children: + - type: page + title: human + slug: langchain-core/langchain_core/messages/human + pageId: langchain-core/langchain_core/messages/human.mdx + - type: section + title: modifier + slug: langchain-core/langchain_core/messages/modifier + children: + - type: page + title: modifier + slug: langchain-core/langchain_core/messages/modifier + pageId: langchain-core/langchain_core/messages/modifier.mdx + - type: section + title: system + slug: langchain-core/langchain_core/messages/system + children: + - type: page + title: system + slug: langchain-core/langchain_core/messages/system + pageId: langchain-core/langchain_core/messages/system.mdx + - type: section + title: tool + slug: langchain-core/langchain_core/messages/tool + children: + - type: page + title: tool + slug: langchain-core/langchain_core/messages/tool + pageId: langchain-core/langchain_core/messages/tool.mdx + - type: section + title: utils + slug: langchain-core/langchain_core/messages/utils + children: + - type: page + title: utils + slug: langchain-core/langchain_core/messages/utils + pageId: langchain-core/langchain_core/messages/utils.mdx +- type: section + title: output_parsers + slug: langchain-core/langchain_core/output_parsers + children: + - type: section + title: base + slug: langchain-core/langchain_core/output_parsers/base + children: + - type: page + title: base + slug: langchain-core/langchain_core/output_parsers/base + pageId: langchain-core/langchain_core/output_parsers/base.mdx + - type: section + title: format_instructions + slug: langchain-core/langchain_core/output_parsers/format_instructions + children: + - type: page + title: format_instructions + slug: langchain-core/langchain_core/output_parsers/format_instructions + pageId: langchain-core/langchain_core/output_parsers/format_instructions.mdx + - type: section + title: json + slug: langchain-core/langchain_core/output_parsers/json + children: + - type: page + title: json + slug: langchain-core/langchain_core/output_parsers/json + pageId: langchain-core/langchain_core/output_parsers/json.mdx + - type: section + title: list + slug: langchain-core/langchain_core/output_parsers/list + children: + - type: page + title: list + slug: langchain-core/langchain_core/output_parsers/list + pageId: langchain-core/langchain_core/output_parsers/list.mdx + - type: section + title: openai_functions + slug: langchain-core/langchain_core/output_parsers/openai_functions + children: + - type: page + title: openai_functions + slug: langchain-core/langchain_core/output_parsers/openai_functions + pageId: langchain-core/langchain_core/output_parsers/openai_functions.mdx + - type: section + title: openai_tools + slug: langchain-core/langchain_core/output_parsers/openai_tools + children: + - type: page + title: openai_tools + slug: langchain-core/langchain_core/output_parsers/openai_tools + pageId: langchain-core/langchain_core/output_parsers/openai_tools.mdx + - type: section + title: pydantic + slug: langchain-core/langchain_core/output_parsers/pydantic + children: + - type: page + title: pydantic + slug: langchain-core/langchain_core/output_parsers/pydantic + pageId: langchain-core/langchain_core/output_parsers/pydantic.mdx + - type: section + title: string + slug: langchain-core/langchain_core/output_parsers/string + children: + - type: page + title: string + slug: langchain-core/langchain_core/output_parsers/string + pageId: langchain-core/langchain_core/output_parsers/string.mdx + - type: section + title: transform + slug: langchain-core/langchain_core/output_parsers/transform + children: + - type: page + title: transform + slug: langchain-core/langchain_core/output_parsers/transform + pageId: langchain-core/langchain_core/output_parsers/transform.mdx + - type: section + title: xml + slug: langchain-core/langchain_core/output_parsers/xml + children: + - type: page + title: xml + slug: langchain-core/langchain_core/output_parsers/xml + pageId: langchain-core/langchain_core/output_parsers/xml.mdx +- type: section + title: outputs + slug: langchain-core/langchain_core/outputs + children: + - type: section + title: chat_generation + slug: langchain-core/langchain_core/outputs/chat_generation + children: + - type: page + title: chat_generation + slug: langchain-core/langchain_core/outputs/chat_generation + pageId: langchain-core/langchain_core/outputs/chat_generation.mdx + - type: section + title: chat_result + slug: langchain-core/langchain_core/outputs/chat_result + children: + - type: page + title: chat_result + slug: langchain-core/langchain_core/outputs/chat_result + pageId: langchain-core/langchain_core/outputs/chat_result.mdx + - type: section + title: generation + slug: langchain-core/langchain_core/outputs/generation + children: + - type: page + title: generation + slug: langchain-core/langchain_core/outputs/generation + pageId: langchain-core/langchain_core/outputs/generation.mdx + - type: section + title: llm_result + slug: langchain-core/langchain_core/outputs/llm_result + children: + - type: page + title: llm_result + slug: langchain-core/langchain_core/outputs/llm_result + pageId: langchain-core/langchain_core/outputs/llm_result.mdx + - type: section + title: run_info + slug: langchain-core/langchain_core/outputs/run_info + children: + - type: page + title: run_info + slug: langchain-core/langchain_core/outputs/run_info + pageId: langchain-core/langchain_core/outputs/run_info.mdx +- type: section + title: prompt_values + slug: langchain-core/langchain_core/prompt_values + children: + - type: page + title: prompt_values + slug: langchain-core/langchain_core/prompt_values + pageId: langchain-core/langchain_core/prompt_values.mdx +- type: section + title: prompts + slug: langchain-core/langchain_core/prompts + children: + - type: section + title: base + slug: langchain-core/langchain_core/prompts/base + children: + - type: page + title: base + slug: langchain-core/langchain_core/prompts/base + pageId: langchain-core/langchain_core/prompts/base.mdx + - type: section + title: chat + slug: langchain-core/langchain_core/prompts/chat + children: + - type: page + title: chat + slug: langchain-core/langchain_core/prompts/chat + pageId: langchain-core/langchain_core/prompts/chat.mdx + - type: section + title: dict + slug: langchain-core/langchain_core/prompts/dict + children: + - type: page + title: dict + slug: langchain-core/langchain_core/prompts/dict + pageId: langchain-core/langchain_core/prompts/dict.mdx + - type: section + title: few_shot + slug: langchain-core/langchain_core/prompts/few_shot + children: + - type: page + title: few_shot + slug: langchain-core/langchain_core/prompts/few_shot + pageId: langchain-core/langchain_core/prompts/few_shot.mdx + - type: section + title: few_shot_with_templates + slug: langchain-core/langchain_core/prompts/few_shot_with_templates + children: + - type: page + title: few_shot_with_templates + slug: langchain-core/langchain_core/prompts/few_shot_with_templates + pageId: langchain-core/langchain_core/prompts/few_shot_with_templates.mdx + - type: section + title: image + slug: langchain-core/langchain_core/prompts/image + children: + - type: page + title: image + slug: langchain-core/langchain_core/prompts/image + pageId: langchain-core/langchain_core/prompts/image.mdx + - type: section + title: loading + slug: langchain-core/langchain_core/prompts/loading + children: + - type: page + title: loading + slug: langchain-core/langchain_core/prompts/loading + pageId: langchain-core/langchain_core/prompts/loading.mdx + - type: section + title: message + slug: langchain-core/langchain_core/prompts/message + children: + - type: page + title: message + slug: langchain-core/langchain_core/prompts/message + pageId: langchain-core/langchain_core/prompts/message.mdx + - type: section + title: prompt + slug: langchain-core/langchain_core/prompts/prompt + children: + - type: page + title: prompt + slug: langchain-core/langchain_core/prompts/prompt + pageId: langchain-core/langchain_core/prompts/prompt.mdx + - type: section + title: string + slug: langchain-core/langchain_core/prompts/string + children: + - type: page + title: string + slug: langchain-core/langchain_core/prompts/string + pageId: langchain-core/langchain_core/prompts/string.mdx + - type: section + title: structured + slug: langchain-core/langchain_core/prompts/structured + children: + - type: page + title: structured + slug: langchain-core/langchain_core/prompts/structured + pageId: langchain-core/langchain_core/prompts/structured.mdx +- type: section + title: rate_limiters + slug: langchain-core/langchain_core/rate_limiters + children: + - type: page + title: rate_limiters + slug: langchain-core/langchain_core/rate_limiters + pageId: langchain-core/langchain_core/rate_limiters.mdx +- type: section + title: retrievers + slug: langchain-core/langchain_core/retrievers + children: + - type: page + title: retrievers + slug: langchain-core/langchain_core/retrievers + pageId: langchain-core/langchain_core/retrievers.mdx +- type: section + title: runnables + slug: langchain-core/langchain_core/runnables + children: + - type: section + title: base + slug: langchain-core/langchain_core/runnables/base + children: + - type: page + title: base + slug: langchain-core/langchain_core/runnables/base + pageId: langchain-core/langchain_core/runnables/base.mdx + - type: section + title: branch + slug: langchain-core/langchain_core/runnables/branch + children: + - type: page + title: branch + slug: langchain-core/langchain_core/runnables/branch + pageId: langchain-core/langchain_core/runnables/branch.mdx + - type: section + title: config + slug: langchain-core/langchain_core/runnables/config + children: + - type: page + title: config + slug: langchain-core/langchain_core/runnables/config + pageId: langchain-core/langchain_core/runnables/config.mdx + - type: section + title: configurable + slug: langchain-core/langchain_core/runnables/configurable + children: + - type: page + title: configurable + slug: langchain-core/langchain_core/runnables/configurable + pageId: langchain-core/langchain_core/runnables/configurable.mdx + - type: section + title: fallbacks + slug: langchain-core/langchain_core/runnables/fallbacks + children: + - type: page + title: fallbacks + slug: langchain-core/langchain_core/runnables/fallbacks + pageId: langchain-core/langchain_core/runnables/fallbacks.mdx + - type: section + title: graph + slug: langchain-core/langchain_core/runnables/graph + children: + - type: page + title: graph + slug: langchain-core/langchain_core/runnables/graph + pageId: langchain-core/langchain_core/runnables/graph.mdx + - type: section + title: graph_ascii + slug: langchain-core/langchain_core/runnables/graph_ascii + children: + - type: page + title: graph_ascii + slug: langchain-core/langchain_core/runnables/graph_ascii + pageId: langchain-core/langchain_core/runnables/graph_ascii.mdx + - type: section + title: graph_mermaid + slug: langchain-core/langchain_core/runnables/graph_mermaid + children: + - type: page + title: graph_mermaid + slug: langchain-core/langchain_core/runnables/graph_mermaid + pageId: langchain-core/langchain_core/runnables/graph_mermaid.mdx + - type: section + title: graph_png + slug: langchain-core/langchain_core/runnables/graph_png + children: + - type: page + title: graph_png + slug: langchain-core/langchain_core/runnables/graph_png + pageId: langchain-core/langchain_core/runnables/graph_png.mdx + - type: section + title: history + slug: langchain-core/langchain_core/runnables/history + children: + - type: page + title: history + slug: langchain-core/langchain_core/runnables/history + pageId: langchain-core/langchain_core/runnables/history.mdx + - type: section + title: passthrough + slug: langchain-core/langchain_core/runnables/passthrough + children: + - type: page + title: passthrough + slug: langchain-core/langchain_core/runnables/passthrough + pageId: langchain-core/langchain_core/runnables/passthrough.mdx + - type: section + title: retry + slug: langchain-core/langchain_core/runnables/retry + children: + - type: page + title: retry + slug: langchain-core/langchain_core/runnables/retry + pageId: langchain-core/langchain_core/runnables/retry.mdx + - type: section + title: router + slug: langchain-core/langchain_core/runnables/router + children: + - type: page + title: router + slug: langchain-core/langchain_core/runnables/router + pageId: langchain-core/langchain_core/runnables/router.mdx + - type: section + title: schema + slug: langchain-core/langchain_core/runnables/schema + children: + - type: page + title: schema + slug: langchain-core/langchain_core/runnables/schema + pageId: langchain-core/langchain_core/runnables/schema.mdx + - type: section + title: utils + slug: langchain-core/langchain_core/runnables/utils + children: + - type: page + title: utils + slug: langchain-core/langchain_core/runnables/utils + pageId: langchain-core/langchain_core/runnables/utils.mdx +- type: section + title: stores + slug: langchain-core/langchain_core/stores + children: + - type: page + title: stores + slug: langchain-core/langchain_core/stores + pageId: langchain-core/langchain_core/stores.mdx +- type: section + title: structured_query + slug: langchain-core/langchain_core/structured_query + children: + - type: page + title: structured_query + slug: langchain-core/langchain_core/structured_query + pageId: langchain-core/langchain_core/structured_query.mdx +- type: section + title: sys_info + slug: langchain-core/langchain_core/sys_info + children: + - type: page + title: sys_info + slug: langchain-core/langchain_core/sys_info + pageId: langchain-core/langchain_core/sys_info.mdx +- type: section + title: tools + slug: langchain-core/langchain_core/tools + children: + - type: section + title: base + slug: langchain-core/langchain_core/tools/base + children: + - type: page + title: base + slug: langchain-core/langchain_core/tools/base + pageId: langchain-core/langchain_core/tools/base.mdx + - type: section + title: convert + slug: langchain-core/langchain_core/tools/convert + children: + - type: page + title: convert + slug: langchain-core/langchain_core/tools/convert + pageId: langchain-core/langchain_core/tools/convert.mdx + - type: section + title: render + slug: langchain-core/langchain_core/tools/render + children: + - type: page + title: render + slug: langchain-core/langchain_core/tools/render + pageId: langchain-core/langchain_core/tools/render.mdx + - type: section + title: retriever + slug: langchain-core/langchain_core/tools/retriever + children: + - type: page + title: retriever + slug: langchain-core/langchain_core/tools/retriever + pageId: langchain-core/langchain_core/tools/retriever.mdx + - type: section + title: simple + slug: langchain-core/langchain_core/tools/simple + children: + - type: page + title: simple + slug: langchain-core/langchain_core/tools/simple + pageId: langchain-core/langchain_core/tools/simple.mdx + - type: section + title: structured + slug: langchain-core/langchain_core/tools/structured + children: + - type: page + title: structured + slug: langchain-core/langchain_core/tools/structured + pageId: langchain-core/langchain_core/tools/structured.mdx +- type: section + title: tracers + slug: langchain-core/langchain_core/tracers + children: + - type: section + title: _compat + slug: langchain-core/langchain_core/tracers/_compat + children: + - type: page + title: _compat + slug: langchain-core/langchain_core/tracers/_compat + pageId: langchain-core/langchain_core/tracers/_compat.mdx + - type: section + title: _streaming + slug: langchain-core/langchain_core/tracers/_streaming + children: + - type: page + title: _streaming + slug: langchain-core/langchain_core/tracers/_streaming + pageId: langchain-core/langchain_core/tracers/_streaming.mdx + - type: section + title: base + slug: langchain-core/langchain_core/tracers/base + children: + - type: page + title: base + slug: langchain-core/langchain_core/tracers/base + pageId: langchain-core/langchain_core/tracers/base.mdx + - type: section + title: context + slug: langchain-core/langchain_core/tracers/context + children: + - type: page + title: context + slug: langchain-core/langchain_core/tracers/context + pageId: langchain-core/langchain_core/tracers/context.mdx + - type: section + title: core + slug: langchain-core/langchain_core/tracers/core + children: + - type: page + title: core + slug: langchain-core/langchain_core/tracers/core + pageId: langchain-core/langchain_core/tracers/core.mdx + - type: section + title: evaluation + slug: langchain-core/langchain_core/tracers/evaluation + children: + - type: page + title: evaluation + slug: langchain-core/langchain_core/tracers/evaluation + pageId: langchain-core/langchain_core/tracers/evaluation.mdx + - type: section + title: event_stream + slug: langchain-core/langchain_core/tracers/event_stream + children: + - type: page + title: event_stream + slug: langchain-core/langchain_core/tracers/event_stream + pageId: langchain-core/langchain_core/tracers/event_stream.mdx + - type: section + title: langchain + slug: langchain-core/langchain_core/tracers/langchain + children: + - type: page + title: langchain + slug: langchain-core/langchain_core/tracers/langchain + pageId: langchain-core/langchain_core/tracers/langchain.mdx + - type: section + title: log_stream + slug: langchain-core/langchain_core/tracers/log_stream + children: + - type: page + title: log_stream + slug: langchain-core/langchain_core/tracers/log_stream + pageId: langchain-core/langchain_core/tracers/log_stream.mdx + - type: section + title: memory_stream + slug: langchain-core/langchain_core/tracers/memory_stream + children: + - type: page + title: memory_stream + slug: langchain-core/langchain_core/tracers/memory_stream + pageId: langchain-core/langchain_core/tracers/memory_stream.mdx + - type: section + title: root_listeners + slug: langchain-core/langchain_core/tracers/root_listeners + children: + - type: page + title: root_listeners + slug: langchain-core/langchain_core/tracers/root_listeners + pageId: langchain-core/langchain_core/tracers/root_listeners.mdx + - type: section + title: run_collector + slug: langchain-core/langchain_core/tracers/run_collector + children: + - type: page + title: run_collector + slug: langchain-core/langchain_core/tracers/run_collector + pageId: langchain-core/langchain_core/tracers/run_collector.mdx + - type: section + title: schemas + slug: langchain-core/langchain_core/tracers/schemas + children: + - type: page + title: schemas + slug: langchain-core/langchain_core/tracers/schemas + pageId: langchain-core/langchain_core/tracers/schemas.mdx + - type: section + title: stdout + slug: langchain-core/langchain_core/tracers/stdout + children: + - type: page + title: stdout + slug: langchain-core/langchain_core/tracers/stdout + pageId: langchain-core/langchain_core/tracers/stdout.mdx +- type: section + title: utils + slug: langchain-core/langchain_core/utils + children: + - type: section + title: _merge + slug: langchain-core/langchain_core/utils/_merge + children: + - type: page + title: _merge + slug: langchain-core/langchain_core/utils/_merge + pageId: langchain-core/langchain_core/utils/_merge.mdx + - type: section + title: aiter + slug: langchain-core/langchain_core/utils/aiter + children: + - type: page + title: aiter + slug: langchain-core/langchain_core/utils/aiter + pageId: langchain-core/langchain_core/utils/aiter.mdx + - type: section + title: env + slug: langchain-core/langchain_core/utils/env + children: + - type: page + title: env + slug: langchain-core/langchain_core/utils/env + pageId: langchain-core/langchain_core/utils/env.mdx + - type: section + title: formatting + slug: langchain-core/langchain_core/utils/formatting + children: + - type: page + title: formatting + slug: langchain-core/langchain_core/utils/formatting + pageId: langchain-core/langchain_core/utils/formatting.mdx + - type: section + title: function_calling + slug: langchain-core/langchain_core/utils/function_calling + children: + - type: page + title: function_calling + slug: langchain-core/langchain_core/utils/function_calling + pageId: langchain-core/langchain_core/utils/function_calling.mdx + - type: section + title: html + slug: langchain-core/langchain_core/utils/html + children: + - type: page + title: html + slug: langchain-core/langchain_core/utils/html + pageId: langchain-core/langchain_core/utils/html.mdx + - type: section + title: image + slug: langchain-core/langchain_core/utils/image + children: + - type: page + title: image + slug: langchain-core/langchain_core/utils/image + pageId: langchain-core/langchain_core/utils/image.mdx + - type: section + title: input + slug: langchain-core/langchain_core/utils/input + children: + - type: page + title: input + slug: langchain-core/langchain_core/utils/input + pageId: langchain-core/langchain_core/utils/input.mdx + - type: section + title: interactive_env + slug: langchain-core/langchain_core/utils/interactive_env + children: + - type: page + title: interactive_env + slug: langchain-core/langchain_core/utils/interactive_env + pageId: langchain-core/langchain_core/utils/interactive_env.mdx + - type: section + title: iter + slug: langchain-core/langchain_core/utils/iter + children: + - type: page + title: iter + slug: langchain-core/langchain_core/utils/iter + pageId: langchain-core/langchain_core/utils/iter.mdx + - type: section + title: json + slug: langchain-core/langchain_core/utils/json + children: + - type: page + title: json + slug: langchain-core/langchain_core/utils/json + pageId: langchain-core/langchain_core/utils/json.mdx + - type: section + title: json_schema + slug: langchain-core/langchain_core/utils/json_schema + children: + - type: page + title: json_schema + slug: langchain-core/langchain_core/utils/json_schema + pageId: langchain-core/langchain_core/utils/json_schema.mdx + - type: section + title: mustache + slug: langchain-core/langchain_core/utils/mustache + children: + - type: page + title: mustache + slug: langchain-core/langchain_core/utils/mustache + pageId: langchain-core/langchain_core/utils/mustache.mdx + - type: section + title: pydantic + slug: langchain-core/langchain_core/utils/pydantic + children: + - type: page + title: pydantic + slug: langchain-core/langchain_core/utils/pydantic + pageId: langchain-core/langchain_core/utils/pydantic.mdx + - type: section + title: strings + slug: langchain-core/langchain_core/utils/strings + children: + - type: page + title: strings + slug: langchain-core/langchain_core/utils/strings + pageId: langchain-core/langchain_core/utils/strings.mdx + - type: section + title: usage + slug: langchain-core/langchain_core/utils/usage + children: + - type: page + title: usage + slug: langchain-core/langchain_core/utils/usage + pageId: langchain-core/langchain_core/utils/usage.mdx + - type: section + title: utils + slug: langchain-core/langchain_core/utils/utils + children: + - type: page + title: utils + slug: langchain-core/langchain_core/utils/utils + pageId: langchain-core/langchain_core/utils/utils.mdx + - type: section + title: uuid + slug: langchain-core/langchain_core/utils/uuid + children: + - type: page + title: uuid + slug: langchain-core/langchain_core/utils/uuid + pageId: langchain-core/langchain_core/utils/uuid.mdx +- type: section + title: vectorstores + slug: langchain-core/langchain_core/vectorstores + children: + - type: section + title: base + slug: langchain-core/langchain_core/vectorstores/base + children: + - type: page + title: base + slug: langchain-core/langchain_core/vectorstores/base + pageId: langchain-core/langchain_core/vectorstores/base.mdx + - type: section + title: in_memory + slug: langchain-core/langchain_core/vectorstores/in_memory + children: + - type: page + title: in_memory + slug: langchain-core/langchain_core/vectorstores/in_memory + pageId: langchain-core/langchain_core/vectorstores/in_memory.mdx + - type: section + title: utils + slug: langchain-core/langchain_core/vectorstores/utils + children: + - type: page + title: utils + slug: langchain-core/langchain_core/vectorstores/utils + pageId: langchain-core/langchain_core/vectorstores/utils.mdx +- type: section + title: version + slug: langchain-core/langchain_core/version + children: + - type: page + title: version + slug: langchain-core/langchain_core/version + pageId: langchain-core/langchain_core/version.mdx diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core.mdx new file mode 100644 index 0000000..ffe1b40 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core.mdx @@ -0,0 +1,75 @@ +--- +layout: overview +slug: langchain-core/langchain_core +title: langchain_core +--- + +`langchain-core` defines the base abstractions for the LangChain ecosystem. + +The interfaces for core components like chat models, LLMs, vector stores, retrievers, +and more are defined here. The universal invocation protocol (Runnables) along with +a syntax for combining components are also defined here. + +**No third-party integrations are defined here.** The dependencies are kept purposefully +very lightweight. + +## Subpackages + +- **[`langchain_core._api`](/langchain-core/langchain_core/_api)** +- **[`langchain_core._security`](/langchain-core/langchain_core/_security)** +- **[`langchain_core.callbacks`](/langchain-core/langchain_core/callbacks)** +- **[`langchain_core.document_loaders`](/langchain-core/langchain_core/document_loaders)** +- **[`langchain_core.documents`](/langchain-core/langchain_core/documents)** +- **[`langchain_core.embeddings`](/langchain-core/langchain_core/embeddings)** +- **[`langchain_core.example_selectors`](/langchain-core/langchain_core/example_selectors)** +- **[`langchain_core.indexing`](/langchain-core/langchain_core/indexing)** +- **[`langchain_core.language_models`](/langchain-core/langchain_core/language_models)** +- **[`langchain_core.load`](/langchain-core/langchain_core/load)** +- **[`langchain_core.messages`](/langchain-core/langchain_core/messages)** +- **[`langchain_core.output_parsers`](/langchain-core/langchain_core/output_parsers)** +- **[`langchain_core.outputs`](/langchain-core/langchain_core/outputs)** +- **[`langchain_core.prompts`](/langchain-core/langchain_core/prompts)** +- **[`langchain_core.runnables`](/langchain-core/langchain_core/runnables)** +- **[`langchain_core.tools`](/langchain-core/langchain_core/tools)** +- **[`langchain_core.tracers`](/langchain-core/langchain_core/tracers)** +- **[`langchain_core.utils`](/langchain-core/langchain_core/utils)** +- **[`langchain_core.vectorstores`](/langchain-core/langchain_core/vectorstores)** + +## Submodules + +- **[`langchain_core._import_utils`](/langchain-core/langchain_core/_import_utils)** +- **[`langchain_core.agents`](/langchain-core/langchain_core/agents)** +- **[`langchain_core.caches`](/langchain-core/langchain_core/caches)** +- **[`langchain_core.chat_history`](/langchain-core/langchain_core/chat_history)** +- **[`langchain_core.chat_loaders`](/langchain-core/langchain_core/chat_loaders)** +- **[`langchain_core.chat_sessions`](/langchain-core/langchain_core/chat_sessions)** +- **[`langchain_core.env`](/langchain-core/langchain_core/env)** +- **[`langchain_core.exceptions`](/langchain-core/langchain_core/exceptions)** +- **[`langchain_core.globals`](/langchain-core/langchain_core/globals)** +- **[`langchain_core.prompt_values`](/langchain-core/langchain_core/prompt_values)** +- **[`langchain_core.rate_limiters`](/langchain-core/langchain_core/rate_limiters)** +- **[`langchain_core.retrievers`](/langchain-core/langchain_core/retrievers)** +- **[`langchain_core.stores`](/langchain-core/langchain_core/stores)** +- **[`langchain_core.structured_query`](/langchain-core/langchain_core/structured_query)** +- **[`langchain_core.sys_info`](/langchain-core/langchain_core/sys_info)** +- **[`langchain_core.version`](/langchain-core/langchain_core/version)** + +## Package Contents + +### Data + +[`__version__`](#langchain_core-__version__) + +### API + + + + + +```python +langchain_core.__version__ = VERSION +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_api.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_api.mdx new file mode 100644 index 0000000..e09269f --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_api.mdx @@ -0,0 +1,121 @@ +--- +layout: overview +slug: langchain-core/langchain_core/_api +title: langchain_core._api +--- + +Helper functions for managing the LangChain API. + +This module is only relevant for LangChain developers, not for users. + +!!! warning + + This module and its submodules are for internal use only. Do not use them in your + own code. We may change the API at any time with no warning. + +## Submodules + +- **[`langchain_core._api.beta_decorator`](/langchain-core/langchain_core/_api/beta_decorator)** +- **[`langchain_core._api.deprecation`](/langchain-core/langchain_core/_api/deprecation)** +- **[`langchain_core._api.internal`](/langchain-core/langchain_core/_api/internal)** +- **[`langchain_core._api.path`](/langchain-core/langchain_core/_api/path)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`__dir__`](#langchain_core-_api-__dir__) | Return a list of available attributes for this module. | +| [`__getattr__`](#langchain_core-_api-__getattr__) | Dynamically import and return an attribute from a submodule. | + +### Data + +[`__all__`](#langchain_core-_api-__all__) + +[`_dynamic_imports`](#langchain_core-_api-_dynamic_imports) + +### API + + + + + +```python +langchain_core._api.__dir__() -> list[str] +``` + + + + + + +Return a list of available attributes for this module. + +**Returns:** `list[str]` + +List of attribute names that can be imported from this module. + + + + + + + + +```python +langchain_core._api.__getattr__( + attr_name: str +) -> object +``` + + + + + + +Dynamically import and return an attribute from a submodule. + +This function enables lazy loading of API functions from submodules, reducing +initial import time and circular dependency issues. + +**Parameters:** + + +Name of the attribute to import. + + +**Returns:** `object` + +The imported attribute object. + +**Raises:** + +- `AttributeError`: If the attribute is not a valid dynamic import. + + + + + + + + +```python +langchain_core._api.__all__ = ('LangChainBetaWarning', 'LangChainDeprecationWarning', 'as_import_path', 'beta'... +``` + + + + + + + + + +```python +langchain_core._api._dynamic_imports = {'LangChainBetaWarning': 'beta_decorator', 'beta': 'beta_decorator', 'suppress_l... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_api/beta_decorator.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_api/beta_decorator.mdx new file mode 100644 index 0000000..76463f3 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_api/beta_decorator.mdx @@ -0,0 +1,208 @@ +--- +layout: overview +slug: langchain-core/langchain_core/_api/beta_decorator +title: langchain_core._api.beta_decorator +--- + +Helper functions for marking parts of the LangChain API as beta. + +This module was loosely adapted from matplotlib's [`_api/deprecation.py`](https://github.com/matplotlib/matplotlib/blob/main/lib/matplotlib/_api/deprecation.py) +module. + +!!! warning + + This module is for internal use only. Do not use it in your own code. We may change + the API at any time with no warning. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`LangChainBetaWarning`](#langchain_core-_api-beta_decorator-LangChainBetaWarning) | A class for issuing beta warnings for LangChain users. | + +### Functions + +| Name | Description | +|------|-------------| +| [`beta`](#langchain_core-_api-beta_decorator-beta) | Decorator to mark a function, a class, or a property as beta. | +| [`suppress_langchain_beta_warning`](#langchain_core-_api-beta_decorator-suppress_langchain_beta_warning) | Context manager to suppress `LangChainDeprecationWarning`. | +| [`surface_langchain_beta_warnings`](#langchain_core-_api-beta_decorator-surface_langchain_beta_warnings) | Unmute LangChain beta warnings. | +| [`warn_beta`](#langchain_core-_api-beta_decorator-warn_beta) | Display a standardized beta annotation. | + +### Data + +[`T`](#langchain_core-_api-beta_decorator-T) + +### API + + + + + +```python +class langchain_core._api.beta_decorator.LangChainBetaWarning() +``` + + + + + + +**Bases:** `DeprecationWarning` + +A class for issuing beta warnings for LangChain users. + + + + + + + + +```python +langchain_core._api.beta_decorator.beta( + message: str = '', + name: str = '', + obj_type: str = '', + addendum: str = '' +) -> collections.abc.Callable[[T], langchain_core._api.beta_decorator.T] +``` + + + + + + +Decorator to mark a function, a class, or a property as beta. + +When marking a classmethod, a staticmethod, or a property, the `@beta` decorator +should go *under* `@classmethod` and `@staticmethod` (i.e., `beta` should directly +decorate the underlying callable), but *over* `@property`. + +When marking a class `C` intended to be used as a base class in a multiple +inheritance hierarchy, `C` *must* define an `__init__` method (if `C` instead +inherited its `__init__` from its own base class, then `@beta` would mess up +`__init__` inheritance when installing its own (annotation-emitting) `C.__init__`). + +**Parameters:** + + +Override the default beta message. + +The %(since)s, %(name)s, %(alternative)s, %(obj_type)s, %(addendum)s, and +%(removal)s format specifiers will be replaced by the values of the +respective arguments passed to this function. + + + +The name of the beta object. + + + +The object type being beta. + + + +Additional text appended directly to the final message. + + +**Returns:** `Callable[[T], T]` + +A decorator which can be used to mark functions or classes as beta. + + + + + + + + +```python +langchain_core._api.beta_decorator.suppress_langchain_beta_warning() -> collections.abc.Generator[None, None, None] +``` + + + + + + +Context manager to suppress `LangChainDeprecationWarning`. + + + + + + + + +```python +langchain_core._api.beta_decorator.surface_langchain_beta_warnings() -> None +``` + + + + + + +Unmute LangChain beta warnings. + + + + + + + + +```python +langchain_core._api.beta_decorator.warn_beta( + message: str = '', + name: str = '', + obj_type: str = '', + addendum: str = '' +) -> None +``` + + + + + + +Display a standardized beta annotation. + +**Parameters:** + + +Override the default beta message. + +The %(name)s, %(obj_type)s, %(addendum)s format specifiers will be replaced +by the values of the respective arguments passed to this function. + + + +The name of the annotated object. + + + +The object type being annotated. + + + +Additional text appended directly to the final message. + + + + + + + + + +```python +langchain_core._api.beta_decorator.T = TypeVar('T', bound=(Callable[..., Any] | type)) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_api/deprecation.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_api/deprecation.mdx new file mode 100644 index 0000000..0e78181 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_api/deprecation.mdx @@ -0,0 +1,458 @@ +--- +layout: overview +slug: langchain-core/langchain_core/_api/deprecation +title: langchain_core._api.deprecation +--- + +Helper functions for deprecating parts of the LangChain API. + +This module was adapted from matplotlib's [`_api/deprecation.py`](https://github.com/matplotlib/matplotlib/blob/main/lib/matplotlib/_api/deprecation.py) +module. + +!!! warning + + This module is for internal use only. Do not use it in your own code. We may change + the API at any time with no warning. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`LangChainDeprecationWarning`](#langchain_core-_api-deprecation-LangChainDeprecationWarning) | A class for issuing deprecation warnings for LangChain users. | +| [`LangChainPendingDeprecationWarning`](#langchain_core-_api-deprecation-LangChainPendingDeprecationWarning) | A class for issuing deprecation warnings for LangChain users. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_build_deprecation_message`](#langchain_core-_api-deprecation-_build_deprecation_message) | Build a simple deprecation message for `__deprecated__` attribute. | +| [`_validate_deprecation_params`](#langchain_core-_api-deprecation-_validate_deprecation_params) | Validate the deprecation parameters. | +| [`deprecated`](#langchain_core-_api-deprecation-deprecated) | Decorator to mark a function, a class, or a property as deprecated. | +| [`rename_parameter`](#langchain_core-_api-deprecation-rename_parameter) | Decorator indicating that parameter *old* of *func* is renamed to *new*. | +| [`suppress_langchain_deprecation_warning`](#langchain_core-_api-deprecation-suppress_langchain_deprecation_warning) | Context manager to suppress `LangChainDeprecationWarning`. | +| [`surface_langchain_deprecation_warnings`](#langchain_core-_api-deprecation-surface_langchain_deprecation_warnings) | Unmute LangChain deprecation warnings. | +| [`warn_deprecated`](#langchain_core-_api-deprecation-warn_deprecated) | Display a standardized deprecation. | + +### Data + +[`T`](#langchain_core-_api-deprecation-T) + +[`_P`](#langchain_core-_api-deprecation-_P) + +[`_R`](#langchain_core-_api-deprecation-_R) + +### API + + + + + +```python +class langchain_core._api.deprecation.LangChainDeprecationWarning() +``` + + + + + + +**Bases:** `DeprecationWarning` + +A class for issuing deprecation warnings for LangChain users. + + + + + + + + +```python +class langchain_core._api.deprecation.LangChainPendingDeprecationWarning() +``` + + + + + + +**Bases:** `PendingDeprecationWarning` + +A class for issuing deprecation warnings for LangChain users. + + + + + + + + +```python +langchain_core._api.deprecation._build_deprecation_message( + alternative: str = '', + alternative_import: str = '' +) -> str +``` + + + + + + +Build a simple deprecation message for `__deprecated__` attribute. + +**Parameters:** + + +An alternative API name. + + + +A fully qualified import path for the alternative. + + +**Returns:** `str` + +A deprecation message string for IDE/type checker display. + + + + + + + + +```python +langchain_core._api.deprecation._validate_deprecation_params( + removal: str, + alternative: str, + alternative_import: str, + pending: bool +) -> None +``` + + + + + + +Validate the deprecation parameters. + + + + + + + + +```python +langchain_core._api.deprecation.deprecated( + since: str, + message: str = '', + name: str = '', + alternative: str = '', + alternative_import: str = '', + pending: bool = False, + obj_type: str = '', + addendum: str = '', + removal: str = '', + package: str = '' +) -> collections.abc.Callable[[T], langchain_core._api.deprecation.T] +``` + + + + + + +Decorator to mark a function, a class, or a property as deprecated. + +When deprecating a classmethod, a staticmethod, or a property, the `@deprecated` +decorator should go *under* `@classmethod` and `@staticmethod` (i.e., `deprecated` +should directly decorate the underlying callable), but *over* `@property`. + +When deprecating a class `C` intended to be used as a base class in a multiple +inheritance hierarchy, `C` *must* define an `__init__` method (if `C` instead +inherited its `__init__` from its own base class, then `@deprecated` would mess up +`__init__` inheritance when installing its own (deprecation-emitting) `C.__init__`). + +Parameters are the same as for `warn_deprecated`, except that *obj_type* defaults to +'class' if decorating a class, 'attribute' if decorating a property, and 'function' +otherwise. + +**Parameters:** + + +The release at which this API became deprecated. + + + +Override the default deprecation message. + +The `%(since)s`, `%(name)s`, `%(alternative)s`, `%(obj_type)s`, +`%(addendum)s`, and `%(removal)s` format specifiers will be replaced by the +values of the respective arguments passed to this function. + + + +The name of the deprecated object. + + + +An alternative API that the user may use in place of the deprecated +API. + +The deprecation warning will tell the user about this alternative if +provided. + + + +An alternative import that the user may use instead. + + + +If `True`, uses a `PendingDeprecationWarning` instead of a +`DeprecationWarning`. + +Cannot be used together with removal. + + + +The object type being deprecated. + + + +Additional text appended directly to the final message. + + + +The expected removal version. + +With the default (an empty string), a removal version is automatically +computed from since. Set to other Falsy values to not schedule a removal +date. + +Cannot be used together with pending. + + + +The package of the deprecated object. + + +**Returns:** `Callable[[T], T]` + +A decorator to mark a function or class as deprecated. + + + + + + + + +```python +langchain_core._api.deprecation.rename_parameter( + since: str, + removal: str, + old: str, + new: str +) -> collections.abc.Callable[[Callable[_P, _R]], collections.abc.Callable[langchain_core._api.deprecation._P, langchain_core._api.deprecation._R]] +``` + + + + + + +Decorator indicating that parameter *old* of *func* is renamed to *new*. + +The actual implementation of *func* should use *new*, not *old*. If *old* is passed +to *func*, a `DeprecationWarning` is emitted, and its value is used, even if *new* +is also passed by keyword. + +**Parameters:** + + +The version in which the parameter was renamed. + + + +The version in which the old parameter will be removed. + + + +The old parameter name. + + + +The new parameter name. + + +**Returns:** `Callable[[Callable[_P, _R]], Callable[_P, _R]]` + +A decorator indicating that a parameter was renamed. + + + + + + + + +```python +langchain_core._api.deprecation.suppress_langchain_deprecation_warning() -> collections.abc.Generator[None, None, None] +``` + + + + + + +Context manager to suppress `LangChainDeprecationWarning`. + + + + + + + + +```python +langchain_core._api.deprecation.surface_langchain_deprecation_warnings() -> None +``` + + + + + + +Unmute LangChain deprecation warnings. + + + + + + + + +```python +langchain_core._api.deprecation.warn_deprecated( + since: str, + message: str = '', + name: str = '', + alternative: str = '', + alternative_import: str = '', + pending: bool = False, + obj_type: str = '', + addendum: str = '', + removal: str = '', + package: str = '' +) -> None +``` + + + + + + +Display a standardized deprecation. + +**Parameters:** + + +The release at which this API became deprecated. + + + +Override the default deprecation message. + +The `%(since)s`, `%(name)s`, `%(alternative)s`, `%(obj_type)s`, +`%(addendum)s`, and `%(removal)s` format specifiers will be replaced by the +values of the respective arguments passed to this function. + + + +The name of the deprecated object. + + + +An alternative API that the user may use in place of the +deprecated API. + +The deprecation warning will tell the user about this alternative if +provided. + + + +An alternative import that the user may use instead. + + + +If `True`, uses a `PendingDeprecationWarning` instead of a +`DeprecationWarning`. + +Cannot be used together with removal. + + + +The object type being deprecated. + + + +Additional text appended directly to the final message. + + + +The expected removal version. + +With the default (an empty string), a removal version is automatically +computed from since. Set to other Falsy values to not schedule a removal +date. + +Cannot be used together with pending. + + + +The package of the deprecated object. + + + + + + + + + +```python +langchain_core._api.deprecation.T = TypeVar('T', bound=(type | Callable[..., Any] | Any)) +``` + + + + + + + + + +```python +langchain_core._api.deprecation._P = ParamSpec('_P') +``` + + + + + + + + + +```python +langchain_core._api.deprecation._R = TypeVar('_R') +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_api/internal.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_api/internal.mdx new file mode 100644 index 0000000..d9fb0aa --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_api/internal.mdx @@ -0,0 +1,35 @@ +--- +layout: overview +slug: langchain-core/langchain_core/_api/internal +title: langchain_core._api.internal +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`is_caller_internal`](#langchain_core-_api-internal-is_caller_internal) | Return whether the caller at `depth` of this function is internal. | + +### API + + + + + +```python +langchain_core._api.internal.is_caller_internal( + depth: int = 2 +) -> bool +``` + + + + + + +Return whether the caller at `depth` of this function is internal. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_api/path.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_api/path.mdx new file mode 100644 index 0000000..fc0ed44 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_api/path.mdx @@ -0,0 +1,135 @@ +--- +layout: overview +slug: langchain-core/langchain_core/_api/path +title: langchain_core._api.path +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`as_import_path`](#langchain_core-_api-path-as_import_path) | Path of the file as a LangChain import exclude langchain top namespace. | +| [`get_relative_path`](#langchain_core-_api-path-get_relative_path) | Get the path of the file as a relative path to the package directory. | + +### Data + +[`HERE`](#langchain_core-_api-path-HERE) + +[`PACKAGE_DIR`](#langchain_core-_api-path-PACKAGE_DIR) + +[`SEPARATOR`](#langchain_core-_api-path-SEPARATOR) + +### API + + + + + +```python +langchain_core._api.path.as_import_path( + file: pathlib.Path | str, + suffix: str | None = None, + relative_to: pathlib.Path = PACKAGE_DIR +) -> str +``` + + + + + + +Path of the file as a LangChain import exclude langchain top namespace. + +**Parameters:** + + +The file path to convert. + + + +An optional suffix to append to the import path. + + + +The base path to make the file path relative to. + + +**Returns:** `str` + +The import path as a string. + + + + + + + + +```python +langchain_core._api.path.get_relative_path( + file: pathlib.Path | str, + relative_to: pathlib.Path = PACKAGE_DIR +) -> str +``` + + + + + + +Get the path of the file as a relative path to the package directory. + +**Parameters:** + + +The file path to convert. + + + +The base path to make the file path relative to. + + +**Returns:** `str` + +The relative path as a string. + + + + + + + + +```python +langchain_core._api.path.HERE = Path(__file__).parent +``` + + + + + + + + + +```python +langchain_core._api.path.PACKAGE_DIR = HERE.parent +``` + + + + + + + + + +```python +langchain_core._api.path.SEPARATOR = os.sep +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_import_utils.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_import_utils.mdx new file mode 100644 index 0000000..85479cd --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_import_utils.mdx @@ -0,0 +1,65 @@ +--- +layout: overview +slug: langchain-core/langchain_core/_import_utils +title: langchain_core._import_utils +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`import_attr`](#langchain_core-_import_utils-import_attr) | Import an attribute from a module located in a package. | + +### API + + + + + +```python +langchain_core._import_utils.import_attr( + attr_name: str, + module_name: str | None, + package: str | None +) -> object +``` + + + + + + +Import an attribute from a module located in a package. + +This utility function is used in custom `__getattr__` methods within `__init__.py` +files to dynamically import attributes. + +**Parameters:** + + +The name of the attribute to import. + + + +The name of the module to import from. + +If `None`, the attribute is imported from the package itself. + + + +The name of the package where the module is located. + + +**Returns:** `object` + +The imported attribute. + +**Raises:** + +- `ImportError`: If the module cannot be found. +- `AttributeError`: If the attribute does not exist in the module or package. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_security.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_security.mdx new file mode 100644 index 0000000..29ef499 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_security.mdx @@ -0,0 +1,9 @@ +--- +layout: overview +slug: langchain-core/langchain_core/_security +title: langchain_core._security +--- + +## Submodules + +- **[`langchain_core._security._ssrf_protection`](/langchain-core/langchain_core/_security/_ssrf_protection)** diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_security/_ssrf_protection.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_security/_ssrf_protection.mdx new file mode 100644 index 0000000..1d45def --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/_security/_ssrf_protection.mdx @@ -0,0 +1,499 @@ +--- +layout: overview +slug: langchain-core/langchain_core/_security/_ssrf_protection +title: langchain_core._security._ssrf_protection +--- + +SSRF Protection for validating URLs against Server-Side Request Forgery attacks. + +This module provides utilities to validate user-provided URLs and prevent SSRF attacks +by blocking requests to: +- Private IP ranges (RFC 1918, loopback, link-local) +- Cloud metadata endpoints (AWS, GCP, Azure, etc.) +- Localhost addresses +- Invalid URL schemes + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_validate_url_ssrf_https_only`](#langchain_core-_security-_ssrf_protection-_validate_url_ssrf_https_only) | Validate URL for SSRF protection (HTTPS only, strict mode). | +| [`_validate_url_ssrf_relaxed`](#langchain_core-_security-_ssrf_protection-_validate_url_ssrf_relaxed) | Validate URL for SSRF protection (relaxed mode - allows private IPs). | +| [`_validate_url_ssrf_strict`](#langchain_core-_security-_ssrf_protection-_validate_url_ssrf_strict) | Validate URL for SSRF protection (strict mode). | +| [`is_cloud_metadata`](#langchain_core-_security-_ssrf_protection-is_cloud_metadata) | Check if hostname or IP is a cloud metadata endpoint. | +| [`is_localhost`](#langchain_core-_security-_ssrf_protection-is_localhost) | Check if hostname or IP is localhost. | +| [`is_private_ip`](#langchain_core-_security-_ssrf_protection-is_private_ip) | Check if an IP address is in a private range. | +| [`is_safe_url`](#langchain_core-_security-_ssrf_protection-is_safe_url) | Check if a URL is safe (non-throwing version of validate_safe_url). | +| [`validate_safe_url`](#langchain_core-_security-_ssrf_protection-validate_safe_url) | Validate a URL for SSRF protection. | + +### Data + +[`CLOUD_METADATA_HOSTNAMES`](#langchain_core-_security-_ssrf_protection-CLOUD_METADATA_HOSTNAMES) + +[`CLOUD_METADATA_IPS`](#langchain_core-_security-_ssrf_protection-CLOUD_METADATA_IPS) + +[`LOCALHOST_NAMES`](#langchain_core-_security-_ssrf_protection-LOCALHOST_NAMES) + +[`PRIVATE_IP_RANGES`](#langchain_core-_security-_ssrf_protection-PRIVATE_IP_RANGES) + +[`SSRFProtectedHttpsUrl`](#langchain_core-_security-_ssrf_protection-SSRFProtectedHttpsUrl) + +[`SSRFProtectedHttpsUrlStr`](#langchain_core-_security-_ssrf_protection-SSRFProtectedHttpsUrlStr) + +[`SSRFProtectedUrl`](#langchain_core-_security-_ssrf_protection-SSRFProtectedUrl) + +[`SSRFProtectedUrlRelaxed`](#langchain_core-_security-_ssrf_protection-SSRFProtectedUrlRelaxed) + +### API + + + + + +```python +langchain_core._security._ssrf_protection._validate_url_ssrf_https_only( + v: typing.Any +) -> typing.Any +``` + + + + + + +Validate URL for SSRF protection (HTTPS only, strict mode). + + + + + + + + +```python +langchain_core._security._ssrf_protection._validate_url_ssrf_relaxed( + v: typing.Any +) -> typing.Any +``` + + + + + + +Validate URL for SSRF protection (relaxed mode - allows private IPs). + + + + + + + + +```python +langchain_core._security._ssrf_protection._validate_url_ssrf_strict( + v: typing.Any +) -> typing.Any +``` + + + + + + +Validate URL for SSRF protection (strict mode). + + + + + + + + +```python +langchain_core._security._ssrf_protection.is_cloud_metadata( + hostname: str, + ip_str: str | None = None +) -> bool +``` + + + + + + +Check if hostname or IP is a cloud metadata endpoint. + +**Parameters:** + + +Hostname to check + + + +Optional IP address to check + + +**Returns:** `bool` + +True if hostname or IP is a known cloud metadata endpoint + + + + + + + + +```python +langchain_core._security._ssrf_protection.is_localhost( + hostname: str, + ip_str: str | None = None +) -> bool +``` + + + + + + +Check if hostname or IP is localhost. + +**Parameters:** + + +Hostname to check + + + +Optional IP address to check + + +**Returns:** `bool` + +True if hostname or IP is localhost + + + + + + + + +```python +langchain_core._security._ssrf_protection.is_private_ip( + ip_str: str +) -> bool +``` + + + + + + +Check if an IP address is in a private range. + +**Parameters:** + + +IP address as a string (e.g., "192.168.1.1") + + +**Returns:** `bool` + +True if IP is in a private range, False otherwise + + + + + + + + +```python +langchain_core._security._ssrf_protection.is_safe_url( + url: str | pydantic.AnyHttpUrl, + allow_private: bool = False, + allow_http: bool = True +) -> bool +``` + + + + + + +Check if a URL is safe (non-throwing version of validate_safe_url). + +**Parameters:** + + +The URL to check + + + +If True, allows private IPs and localhost + + + +If True, allows both HTTP and HTTPS + + +**Returns:** `bool` + +True if URL is safe, False otherwise + +**Examples:** + + + +```python +>>> is_safe_url("https://example.com") +True +``` + + + + + +```python +>>> is_safe_url("http://127.0.0.1:8080") +False +``` + + + + + +```python +>>> is_safe_url("http://localhost:8080", allow_private=True) +True +``` + + + + + + + + + + +```python +langchain_core._security._ssrf_protection.validate_safe_url( + url: str | pydantic.AnyHttpUrl, + allow_private: bool = False, + allow_http: bool = True +) -> str +``` + + + + + + +Validate a URL for SSRF protection. + +This function validates URLs to prevent Server-Side Request Forgery (SSRF) attacks +by blocking requests to private networks and cloud metadata endpoints. + +**Parameters:** + + +The URL to validate (string or Pydantic HttpUrl) + + + +If True, allows private IPs and localhost (for development). + Cloud metadata endpoints are ALWAYS blocked. + + + +If True, allows both HTTP and HTTPS. If False, only HTTPS. + + +**Returns:** `str` + +The validated URL as a string + +**Raises:** + +- `ValueError`: If URL is invalid or potentially dangerous + +**Examples:** + + + +```python +>>> validate_safe_url("https://hooks.slack.com/services/xxx") +'https://hooks.slack.com/services/xxx' +``` + + + + + +```python +>>> validate_safe_url("http://127.0.0.1:8080") +ValueError: Localhost URLs are not allowed +``` + + + + + +```python +>>> validate_safe_url("http://192.168.1.1") +ValueError: URL resolves to private IP: 192.168.1.1 +``` + + + + + +```python +>>> validate_safe_url("http://169.254.169.254/latest/meta-data/") +ValueError: URL resolves to cloud metadata IP: 169.254.169.254 +``` + + + + + +```python +>>> validate_safe_url("http://localhost:8080", allow_private=True) +'http://localhost:8080' +``` + + + + + + + + + + +```python +langchain_core._security._ssrf_protection.CLOUD_METADATA_HOSTNAMES = ['metadata.google.internal', 'metadata', 'instance-data'] +``` + + + + + + + + + +```python +langchain_core._security._ssrf_protection.CLOUD_METADATA_IPS = ['169.254.169.254', '169.254.170.2', '100.100.100.200'] +``` + + + + + + + + + +```python +langchain_core._security._ssrf_protection.LOCALHOST_NAMES = ['localhost', 'localhost.localdomain'] +``` + + + + + + + + + +```python +langchain_core._security._ssrf_protection.PRIVATE_IP_RANGES = [ipaddress.ip_network('10.0.0.0/8'), ipaddress.ip_network('172.16.0.0/12'), ipad... +``` + + + + + + + + + +```python +langchain_core._security._ssrf_protection.SSRFProtectedHttpsUrl = Annotated[HttpUrl, BeforeValidator(_validate_url_ssrf_https_only)] +``` + + + + + + +A Pydantic HttpUrl with SSRF protection that only allows HTTPS. + +This blocks private IPs, localhost, cloud metadata endpoints, and HTTP URLs. + + + + + + + +```python +langchain_core._security._ssrf_protection.SSRFProtectedHttpsUrlStr = Annotated[str, BeforeValidator(_validate_url_ssrf_https_only)] +``` + + + + + + +A string type with SSRF protection that only allows HTTPS URLs. + +Same as SSRFProtectedHttpsUrl but returns a string instead of HttpUrl. +Useful for FastAPI query parameters where you need a string URL. + + + + + + + +```python +langchain_core._security._ssrf_protection.SSRFProtectedUrl = Annotated[HttpUrl, BeforeValidator(_validate_url_ssrf_strict)] +``` + + + + + + +A Pydantic HttpUrl type with built-in SSRF protection. + +This blocks private IPs, localhost, and cloud metadata endpoints. + + + + + + + +```python +langchain_core._security._ssrf_protection.SSRFProtectedUrlRelaxed = Annotated[HttpUrl, BeforeValidator(_validate_url_ssrf_relaxed)] +``` + + + + + + +A Pydantic HttpUrl with relaxed SSRF protection (allows private IPs). + +Use this for development/testing webhooks where localhost/private IPs are needed. +Cloud metadata endpoints are still blocked. + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/agents.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/agents.mdx new file mode 100644 index 0000000..db746f2 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/agents.mdx @@ -0,0 +1,415 @@ +--- +layout: overview +slug: langchain-core/langchain_core/agents +title: langchain_core.agents +--- + +Schema definitions for representing agent actions, observations, and return values. + +!!! warning + + The schema definitions are provided for backwards compatibility. + +!!! warning + + New agents should be built using the + [`langchain` library](https://pypi.org/project/langchain/), which provides a + simpler and more flexible way to define agents. + + See docs on [building agents](https://docs.langchain.com/oss/python/langchain/agents). + +Agents use language models to choose a sequence of actions to take. + +A basic agent works in the following manner: + +1. Given a prompt an agent uses an LLM to request an action to take + (e.g., a tool to run). +2. The agent executes the action (e.g., runs the tool), and receives an observation. +3. The agent returns the observation to the LLM, which can then be used to generate + the next action. +4. When the agent reaches a stopping condition, it returns a final return value. + +The schemas for the agents themselves are defined in `langchain.agents.agent`. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AgentAction`](#langchain_core-agents-AgentAction) | Represents a request to execute an action by an agent. | +| [`AgentActionMessageLog`](#langchain_core-agents-AgentActionMessageLog) | Representation of an action to be executed by an agent. | +| [`AgentFinish`](#langchain_core-agents-AgentFinish) | Final return value of an `ActionAgent`. | +| [`AgentStep`](#langchain_core-agents-AgentStep) | Result of running an `AgentAction`. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_convert_agent_action_to_messages`](#langchain_core-agents-_convert_agent_action_to_messages) | Convert an agent action to a message. | +| [`_convert_agent_observation_to_messages`](#langchain_core-agents-_convert_agent_observation_to_messages) | Convert an agent action to a message. | +| [`_create_function_message`](#langchain_core-agents-_create_function_message) | Convert agent action and observation into a function message. | + +### API + + + + + +```python +class langchain_core.agents.AgentAction( + tool: str, + tool_input: str | dict, + log: str, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [Serializable](/langchain-core/langchain_core/load/serializable#langchain_core-load-serializable-Serializable) + +Represents a request to execute an action by an agent. + +The action consists of the name of the tool to execute and the input to pass +to the tool. The log is used to pass along extra information about the action. + + + +Additional information to log about the action. + +This log can be used in a few ways. First, it can be used to audit what exactly the +LLM predicted to lead to this `(tool, tool_input)`. + +Second, it can be used in future iterations to show the LLMs prior thoughts. This is +useful when `(tool, tool_input)` does not contain full information about the LLM +prediction (for example, any `thought` before the tool/tool_input). + + + +Return the messages that correspond to this action. + + + +The name of the `Tool` to execute. + + + +The input to pass in to the `Tool`. + + + + + + + + +```python +langchain_core.agents.AgentAction.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "schema", "agent"]` + + + + + + + +```python +langchain_core.agents.AgentAction.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +`AgentAction` is serializable. + +**Returns:** `bool` + +`True` + + + + + + + + + +```python +class langchain_core.agents.AgentActionMessageLog() +``` + + + + + + +**Bases:** [AgentAction](#langchain_core-agents-AgentAction) + +Representation of an action to be executed by an agent. + +This is similar to `AgentAction`, but includes a message log consisting of +chat messages. + +This is useful when working with `ChatModels`, and is used to reconstruct +conversation history from the agent's perspective. + + + +Similar to log, this can be used to pass along extra information about what exact +messages were predicted by the LLM before parsing out the `(tool, tool_input)`. + +This is again useful if `(tool, tool_input)` cannot be used to fully recreate the +LLM prediction, and you need that LLM prediction (for future agent iteration). + +Compared to `log`, this is useful when the underlying LLM is a chat model (and +therefore returns messages rather than a string). + + + + + + + + + + +```python +class langchain_core.agents.AgentFinish( + return_values: dict, + log: str, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [Serializable](/langchain-core/langchain_core/load/serializable#langchain_core-load-serializable-Serializable) + +Final return value of an `ActionAgent`. + +Agents return an `AgentFinish` when they have reached a stopping condition. + + + +Additional information to log about the return value. + +This is used to pass along the full LLM prediction, not just the parsed out +return value. + +For example, if the full LLM prediction was `Final Answer: 2` you may want to just +return `2` as a return value, but pass along the full string as a `log` (for +debugging or observability purposes). + + + +Messages that correspond to this observation. + + + +Dictionary of return values. + + + + + + + + +```python +langchain_core.agents.AgentFinish.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "schema", "agent"]` + + + + + + + +```python +langchain_core.agents.AgentFinish.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `True` as this class is serializable. + + + + + + + + + +```python +class langchain_core.agents.AgentStep() +``` + + + + + + +**Bases:** [Serializable](/langchain-core/langchain_core/load/serializable#langchain_core-load-serializable-Serializable) + +Result of running an `AgentAction`. + + + +The `AgentAction` that was executed. + + + +Messages that correspond to this observation. + + + +The result of the `AgentAction`. + + + + + + + +```python +langchain_core.agents._convert_agent_action_to_messages( + agent_action: langchain_core.agents.AgentAction +) -> collections.abc.Sequence[langchain_core.messages.BaseMessage] +``` + + + + + + +Convert an agent action to a message. + +This code is used to reconstruct the original AI message from the agent action. + +**Parameters:** + + +Agent action to convert. + + +**Returns:** `Sequence[BaseMessage]` + +`AIMessage` that corresponds to the original tool invocation. + + + + + + + + +```python +langchain_core.agents._convert_agent_observation_to_messages( + agent_action: langchain_core.agents.AgentAction, + observation: typing.Any +) -> collections.abc.Sequence[langchain_core.messages.BaseMessage] +``` + + + + + + +Convert an agent action to a message. + +This code is used to reconstruct the original AI message from the agent action. + +**Parameters:** + + +Agent action to convert. + + + +Observation to convert to a message. + + +**Returns:** `Sequence[BaseMessage]` + +`AIMessage` that corresponds to the original tool invocation. + + + + + + + + +```python +langchain_core.agents._create_function_message( + agent_action: langchain_core.agents.AgentAction, + observation: typing.Any +) -> langchain_core.messages.FunctionMessage +``` + + + + + + +Convert agent action and observation into a function message. + +**Parameters:** + + +the tool invocation request from the agent. + + + +the result of the tool invocation. + + +**Returns:** `FunctionMessage` + +`FunctionMessage` that corresponds to the original tool invocation. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/caches.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/caches.mdx new file mode 100644 index 0000000..519b020 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/caches.mdx @@ -0,0 +1,541 @@ +--- +layout: overview +slug: langchain-core/langchain_core/caches +title: langchain_core.caches +--- + +Optional caching layer for language models. + +Distinct from provider-based [prompt caching](https://docs.langchain.com/oss/python/langchain/models#prompt-caching). + +!!! warning "Beta feature" + + This is a beta feature. Please be wary of deploying experimental code to production + unless you've taken appropriate precautions. + +A cache is useful for two reasons: + +1. It can save you money by reducing the number of API calls you make to the LLM + provider if you're often requesting the same completion multiple times. +2. It can speed up your application by reducing the number of API calls you make to the + LLM provider. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseCache`](#langchain_core-caches-BaseCache) | Interface for a caching layer for LLMs and Chat models. | +| [`InMemoryCache`](#langchain_core-caches-InMemoryCache) | Cache that stores things in memory. | + +### Data + +[`RETURN_VAL_TYPE`](#langchain_core-caches-RETURN_VAL_TYPE) + +### API + + + + + +```python +class langchain_core.caches.BaseCache() +``` + + + + + + +Abstract + +Interface for a caching layer for LLMs and Chat models. + +The cache interface consists of the following methods: + +- lookup: Look up a value based on a prompt and `llm_string`. +- update: Update the cache based on a prompt and `llm_string`. +- clear: Clear the cache. + +In addition, the cache interface provides an async version of each method. + +The default implementation of the async methods is to run the synchronous +method in an executor. It's recommended to override the async methods +and provide async implementations to avoid unnecessary overhead. + + + + + + +```python +langchain_core.caches.BaseCache.aclear( + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Async clear cache that can take additional keyword arguments. + + + + + + + +```python +langchain_core.caches.BaseCache.alookup( + prompt: str, + llm_string: str +) -> langchain_core.caches.RETURN_VAL_TYPE | None +``` + + + + + + +async + +Async look up based on `prompt` and `llm_string`. + +A cache implementation is expected to generate a key from the 2-tuple +of `prompt` and `llm_string` (e.g., by concatenating them with a delimiter). + +**Parameters:** + + +A string representation of the prompt. + +In the case of a chat model, the prompt is a non-trivial +serialization of the prompt into the language model. + + + +A string representation of the LLM configuration. + +This is used to capture the invocation parameters of the LLM +(e.g., model name, temperature, stop tokens, max tokens, etc.). + +These invocation parameters are serialized into a string +representation. + + +**Returns:** `RETURN_VAL_TYPE | None` + +On a cache miss, return `None`. On a cache hit, return the cached value. +The cached value is a list of `Generation` (or subclasses). + + + + + + + +```python +langchain_core.caches.BaseCache.aupdate( + prompt: str, + llm_string: str, + return_val: langchain_core.caches.RETURN_VAL_TYPE +) -> None +``` + + + + + + +async + +Async update cache based on `prompt` and `llm_string`. + +The prompt and llm_string are used to generate a key for the cache. +The key should match that of the look up method. + +**Parameters:** + + +A string representation of the prompt. + +In the case of a chat model, the prompt is a non-trivial +serialization of the prompt into the language model. + + + +A string representation of the LLM configuration. + +This is used to capture the invocation parameters of the LLM +(e.g., model name, temperature, stop tokens, max tokens, etc.). + +These invocation parameters are serialized into a string +representation. + + + +The value to be cached. The value is a list of `Generation` +(or subclasses). + + + + + + + + +```python +langchain_core.caches.BaseCache.clear( + kwargs: typing.Any = {} +) -> None +``` + + + + + + +abstract + +Clear cache that can take additional keyword arguments. + + + + + + + +```python +langchain_core.caches.BaseCache.lookup( + prompt: str, + llm_string: str +) -> langchain_core.caches.RETURN_VAL_TYPE | None +``` + + + + + + +abstract + +Look up based on `prompt` and `llm_string`. + +A cache implementation is expected to generate a key from the 2-tuple +of `prompt` and `llm_string` (e.g., by concatenating them with a delimiter). + +**Parameters:** + + +A string representation of the prompt. + +In the case of a chat model, the prompt is a non-trivial +serialization of the prompt into the language model. + + + +A string representation of the LLM configuration. + +This is used to capture the invocation parameters of the LLM +(e.g., model name, temperature, stop tokens, max tokens, etc.). + +These invocation parameters are serialized into a string representation. + + +**Returns:** `RETURN_VAL_TYPE | None` + +On a cache miss, return `None`. On a cache hit, return the cached value. +The cached value is a list of `Generation` (or subclasses). + + + + + + + +```python +langchain_core.caches.BaseCache.update( + prompt: str, + llm_string: str, + return_val: langchain_core.caches.RETURN_VAL_TYPE +) -> None +``` + + + + + + +abstract + +Update cache based on `prompt` and `llm_string`. + +The `prompt` and `llm_string` are used to generate a key for the cache. The key +should match that of the lookup method. + +**Parameters:** + + +A string representation of the prompt. + +In the case of a chat model, the prompt is a non-trivial +serialization of the prompt into the language model. + + + +A string representation of the LLM configuration. + +This is used to capture the invocation parameters of the LLM +(e.g., model name, temperature, stop tokens, max tokens, etc.). + +These invocation parameters are serialized into a string +representation. + + + +The value to be cached. + +The value is a list of `Generation` (or subclasses). + + + + + + + + + + +```python +class langchain_core.caches.InMemoryCache( + maxsize: int | None = None +) +``` + + + + + + +**Bases:** [BaseCache](#langchain_core-caches-BaseCache) + +Cache that stores things in memory. + + + + + + + + +```python +langchain_core.caches.InMemoryCache.aclear( + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Async clear cache. + + + + + + + +```python +langchain_core.caches.InMemoryCache.alookup( + prompt: str, + llm_string: str +) -> langchain_core.caches.RETURN_VAL_TYPE | None +``` + + + + + + +async + +Async look up based on `prompt` and `llm_string`. + +**Parameters:** + + +A string representation of the prompt. + +In the case of a chat model, the prompt is a non-trivial +serialization of the prompt into the language model. + + + +A string representation of the LLM configuration. + + +**Returns:** `RETURN_VAL_TYPE | None` + +On a cache miss, return `None`. On a cache hit, return the cached value. + + + + + + + +```python +langchain_core.caches.InMemoryCache.aupdate( + prompt: str, + llm_string: str, + return_val: langchain_core.caches.RETURN_VAL_TYPE +) -> None +``` + + + + + + +async + +Async update cache based on `prompt` and `llm_string`. + +**Parameters:** + + +A string representation of the prompt. + +In the case of a chat model, the prompt is a non-trivial +serialization of the prompt into the language model. + + + +A string representation of the LLM configuration. + + + +The value to be cached. The value is a list of `Generation` +(or subclasses). + + + + + + + + +```python +langchain_core.caches.InMemoryCache.clear( + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Clear cache. + + + + + + + +```python +langchain_core.caches.InMemoryCache.lookup( + prompt: str, + llm_string: str +) -> langchain_core.caches.RETURN_VAL_TYPE | None +``` + + + + + + +Look up based on `prompt` and `llm_string`. + +**Parameters:** + + +A string representation of the prompt. + +In the case of a chat model, the prompt is a non-trivial +serialization of the prompt into the language model. + + + +A string representation of the LLM configuration. + + +**Returns:** `RETURN_VAL_TYPE | None` + +On a cache miss, return `None`. On a cache hit, return the cached value. + + + + + + + +```python +langchain_core.caches.InMemoryCache.update( + prompt: str, + llm_string: str, + return_val: langchain_core.caches.RETURN_VAL_TYPE +) -> None +``` + + + + + + +Update cache based on `prompt` and `llm_string`. + +**Parameters:** + + +A string representation of the prompt. + +In the case of a chat model, the prompt is a non-trivial +serialization of the prompt into the language model. + + + +A string representation of the LLM configuration. + + + +The value to be cached. + +The value is a list of `Generation` (or subclasses). + + + + + + + + + + +```python +langchain_core.caches.RETURN_VAL_TYPE = Sequence[Generation] +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks.mdx new file mode 100644 index 0000000..ae93621 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks.mdx @@ -0,0 +1,91 @@ +--- +layout: overview +slug: langchain-core/langchain_core/callbacks +title: langchain_core.callbacks +--- + +Callback handlers allow listening to events in LangChain. + +## Submodules + +- **[`langchain_core.callbacks.base`](/langchain-core/langchain_core/callbacks/base)** +- **[`langchain_core.callbacks.file`](/langchain-core/langchain_core/callbacks/file)** +- **[`langchain_core.callbacks.manager`](/langchain-core/langchain_core/callbacks/manager)** +- **[`langchain_core.callbacks.stdout`](/langchain-core/langchain_core/callbacks/stdout)** +- **[`langchain_core.callbacks.streaming_stdout`](/langchain-core/langchain_core/callbacks/streaming_stdout)** +- **[`langchain_core.callbacks.usage`](/langchain-core/langchain_core/callbacks/usage)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`__dir__`](#langchain_core-callbacks-__dir__) | - | +| [`__getattr__`](#langchain_core-callbacks-__getattr__) | - | + +### Data + +[`__all__`](#langchain_core-callbacks-__all__) + +[`_dynamic_imports`](#langchain_core-callbacks-_dynamic_imports) + +### API + + + + + +```python +langchain_core.callbacks.__dir__() -> list[str] +``` + + + + + + + + + + + + + +```python +langchain_core.callbacks.__getattr__( + attr_name: str +) -> object +``` + + + + + + + + + + + + + +```python +langchain_core.callbacks.__all__ = ('AsyncCallbackHandler', 'AsyncCallbackManager', 'AsyncCallbackManagerForChainGr... +``` + + + + + + + + + +```python +langchain_core.callbacks._dynamic_imports = {'AsyncCallbackHandler': 'base', 'BaseCallbackHandler': 'base', 'BaseCallbackMan... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/base.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/base.mdx new file mode 100644 index 0000000..5e15543 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/base.mdx @@ -0,0 +1,2445 @@ +--- +layout: overview +slug: langchain-core/langchain_core/callbacks/base +title: langchain_core.callbacks.base +--- + +Base callback handler for LangChain. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AsyncCallbackHandler`](#langchain_core-callbacks-base-AsyncCallbackHandler) | Base async callback handler. | +| [`BaseCallbackHandler`](#langchain_core-callbacks-base-BaseCallbackHandler) | Base callback handler. | +| [`BaseCallbackManager`](#langchain_core-callbacks-base-BaseCallbackManager) | Base callback manager. | +| [`CallbackManagerMixin`](#langchain_core-callbacks-base-CallbackManagerMixin) | Mixin for callback manager. | +| [`ChainManagerMixin`](#langchain_core-callbacks-base-ChainManagerMixin) | Mixin for chain callbacks. | +| [`LLMManagerMixin`](#langchain_core-callbacks-base-LLMManagerMixin) | Mixin for LLM callbacks. | +| [`RetrieverManagerMixin`](#langchain_core-callbacks-base-RetrieverManagerMixin) | Mixin for `Retriever` callbacks. | +| [`RunManagerMixin`](#langchain_core-callbacks-base-RunManagerMixin) | Mixin for run manager. | +| [`ToolManagerMixin`](#langchain_core-callbacks-base-ToolManagerMixin) | Mixin for tool callbacks. | + +### Data + +[`Callbacks`](#langchain_core-callbacks-base-Callbacks) + +[`_LOGGER`](#langchain_core-callbacks-base-_LOGGER) + +### API + + + + + +```python +class langchain_core.callbacks.base.AsyncCallbackHandler() +``` + + + + + + +**Bases:** [BaseCallbackHandler](#langchain_core-callbacks-base-BaseCallbackHandler) + +Base async callback handler. + + + + + + +```python +langchain_core.callbacks.base.AsyncCallbackHandler.on_agent_action( + action: langchain_core.agents.AgentAction, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run on agent action. + +**Parameters:** + + +The agent action. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.AsyncCallbackHandler.on_agent_finish( + finish: langchain_core.agents.AgentFinish, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run on the agent end. + +**Parameters:** + + +The agent finish. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.AsyncCallbackHandler.on_chain_end( + outputs: dict[str, typing.Any], + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when a chain ends running. + +**Parameters:** + + +The outputs of the chain. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.AsyncCallbackHandler.on_chain_error( + error: BaseException, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when chain errors. + +**Parameters:** + + +The error that occurred. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.AsyncCallbackHandler.on_chain_start( + serialized: dict[str, typing.Any], + inputs: dict[str, typing.Any], + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when a chain starts running. + +**Parameters:** + + +The serialized chain. + + + +The inputs. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +The metadata. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.AsyncCallbackHandler.on_chat_model_start( + serialized: dict[str, typing.Any], + messages: list[list[langchain_core.messages.BaseMessage]], + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +async + +Run when a chat model starts running. + +!!! warning + + This method is called for chat models. If you're implementing a handler for + a non-chat model, you should use `on_llm_start` instead. + +!!! note + + When overriding this method, the signature **must** include the two + required positional arguments ``serialized`` and ``messages``. Avoid + using ``*args`` in your override — doing so causes an ``IndexError`` + in the fallback path when the callback system converts ``messages`` + to prompt strings for ``on_llm_start``. Always declare the + signature explicitly: + + .. code-block:: python + + async def on_chat_model_start( + self, + serialized: dict[str, Any], + messages: list[list[BaseMessage]], + **kwargs: Any, + ) -> None: + raise NotImplementedError # triggers fallback to on_llm_start + +**Parameters:** + + +The serialized chat model. + + + +The messages. Must be a list of message lists — this is a +required positional argument and must be present in any override. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +The metadata. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.AsyncCallbackHandler.on_custom_event( + name: str, + data: typing.Any, + run_id: uuid.UUID, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Override to define a handler for custom events. + +**Parameters:** + + +The name of the custom event. + + + +The data for the custom event. + +Format will match the format specified by the user. + + + +The ID of the run. + + + +The tags associated with the custom event (includes inherited tags). + + + +The metadata associated with the custom event (includes inherited +metadata). + + + + + + + + +```python +langchain_core.callbacks.base.AsyncCallbackHandler.on_llm_end( + response: langchain_core.outputs.LLMResult, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when the model ends running. + +**Parameters:** + + +The response which was generated. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.AsyncCallbackHandler.on_llm_error( + error: BaseException, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when LLM errors. + +**Parameters:** + + +The error that occurred. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +Additional keyword arguments. + +- response (LLMResult): The response which was generated before + the error occurred. + + + + + + + + +```python +langchain_core.callbacks.base.AsyncCallbackHandler.on_llm_new_token( + token: str, + chunk: langchain_core.outputs.GenerationChunk | langchain_core.outputs.ChatGenerationChunk | None = None, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run on new output token. Only available when streaming is enabled. + +For both chat models and non-chat models (legacy text completion LLMs). + +**Parameters:** + + +The new token. + + + +The new generated chunk, containing content and other information. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.AsyncCallbackHandler.on_llm_start( + serialized: dict[str, typing.Any], + prompts: list[str], + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when the model starts running. + +!!! warning + + This method is called for non-chat models (regular text completion LLMs). If + you're implementing a handler for a chat model, you should use + `on_chat_model_start` instead. + +**Parameters:** + + +The serialized LLM. + + + +The prompts. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +The metadata. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.AsyncCallbackHandler.on_retriever_end( + documents: collections.abc.Sequence[langchain_core.documents.Document], + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run on the retriever end. + +**Parameters:** + + +The documents retrieved. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.AsyncCallbackHandler.on_retriever_error( + error: BaseException, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run on retriever error. + +**Parameters:** + + +The error that occurred. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.AsyncCallbackHandler.on_retriever_start( + serialized: dict[str, typing.Any], + query: str, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run on the retriever start. + +**Parameters:** + + +The serialized retriever. + + + +The query. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +The metadata. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.AsyncCallbackHandler.on_retry( + retry_state: tenacity.RetryCallState, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +async + +Run on a retry event. + +**Parameters:** + + +The retry state. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.AsyncCallbackHandler.on_text( + text: str, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run on an arbitrary text. + +**Parameters:** + + +The text. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.AsyncCallbackHandler.on_tool_end( + output: typing.Any, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when the tool ends running. + +**Parameters:** + + +The output of the tool. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.AsyncCallbackHandler.on_tool_error( + error: BaseException, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when tool errors. + +**Parameters:** + + +The error that occurred. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.AsyncCallbackHandler.on_tool_start( + serialized: dict[str, typing.Any], + input_str: str, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + inputs: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when the tool starts running. + +**Parameters:** + + +The serialized tool. + + + +The input string. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +The metadata. + + + +The inputs. + + + +Additional keyword arguments. + + + + + + + + + + +```python +class langchain_core.callbacks.base.BaseCallbackHandler() +``` + + + + + + +**Bases:** [LLMManagerMixin](#langchain_core-callbacks-base-LLMManagerMixin), [ChainManagerMixin](#langchain_core-callbacks-base-ChainManagerMixin), [ToolManagerMixin](#langchain_core-callbacks-base-ToolManagerMixin), [RetrieverManagerMixin](#langchain_core-callbacks-base-RetrieverManagerMixin), [CallbackManagerMixin](#langchain_core-callbacks-base-CallbackManagerMixin), [RunManagerMixin](#langchain_core-callbacks-base-RunManagerMixin) + +Base callback handler. + + + +Whether to ignore agent callbacks. + + + +Whether to ignore chain callbacks. + + + +Whether to ignore chat model callbacks. + + + +Ignore custom event. + + + +Whether to ignore LLM callbacks. + + + +Whether to ignore retriever callbacks. + + + +Whether to ignore retry callbacks. + + + +Whether to raise an error if an exception occurs. + + + +Whether to run the callback inline. + + + + + + + +```python +class langchain_core.callbacks.base.BaseCallbackManager( + handlers: list[langchain_core.callbacks.base.BaseCallbackHandler], + inheritable_handlers: list[langchain_core.callbacks.base.BaseCallbackHandler] | None = None, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + inheritable_tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + inheritable_metadata: dict[str, typing.Any] | None = None +) +``` + + + + + + +**Bases:** [CallbackManagerMixin](#langchain_core-callbacks-base-CallbackManagerMixin) + +Base callback manager. + + + + + + + + + + + + +Whether the callback manager is async. + + + + + + + + + + + +```python +langchain_core.callbacks.base.BaseCallbackManager.add_handler( + handler: langchain_core.callbacks.base.BaseCallbackHandler, + inherit: bool = True +) -> None +``` + + + + + + +Add a handler to the callback manager. + +**Parameters:** + + +The handler to add. + + + +Whether to inherit the handler. + + + + + + + + +```python +langchain_core.callbacks.base.BaseCallbackManager.add_metadata( + metadata: dict[str, typing.Any], + inherit: bool = True +) -> None +``` + + + + + + +Add metadata to the callback manager. + +**Parameters:** + + +The metadata to add. + + + +Whether to inherit the metadata. + + + + + + + + +```python +langchain_core.callbacks.base.BaseCallbackManager.add_tags( + tags: list[str], + inherit: bool = True +) -> None +``` + + + + + + +Add tags to the callback manager. + +**Parameters:** + + +The tags to add. + + + +Whether to inherit the tags. + + + + + + + + +```python +langchain_core.callbacks.base.BaseCallbackManager.copy() -> typing_extensions.Self +``` + + + + + + +Return a copy of the callback manager. + + + + + + + +```python +langchain_core.callbacks.base.BaseCallbackManager.merge( + other: langchain_core.callbacks.base.BaseCallbackManager +) -> typing_extensions.Self +``` + + + + + + +Merge the callback manager with another callback manager. + +May be overwritten in subclasses. + +Primarily used internally within `merge_configs`. + +**Returns:** `Self` + +The merged callback manager of the same type as the current object. + + + + + + + +```python +langchain_core.callbacks.base.BaseCallbackManager.remove_handler( + handler: langchain_core.callbacks.base.BaseCallbackHandler +) -> None +``` + + + + + + +Remove a handler from the callback manager. + +**Parameters:** + + +The handler to remove. + + + + + + + + +```python +langchain_core.callbacks.base.BaseCallbackManager.remove_metadata( + keys: list[str] +) -> None +``` + + + + + + +Remove metadata from the callback manager. + +**Parameters:** + + +The keys to remove. + + + + + + + + +```python +langchain_core.callbacks.base.BaseCallbackManager.remove_tags( + tags: list[str] +) -> None +``` + + + + + + +Remove tags from the callback manager. + +**Parameters:** + + +The tags to remove. + + + + + + + + +```python +langchain_core.callbacks.base.BaseCallbackManager.set_handler( + handler: langchain_core.callbacks.base.BaseCallbackHandler, + inherit: bool = True +) -> None +``` + + + + + + +Set handler as the only handler on the callback manager. + +**Parameters:** + + +The handler to set. + + + +Whether to inherit the handler. + + + + + + + + +```python +langchain_core.callbacks.base.BaseCallbackManager.set_handlers( + handlers: list[langchain_core.callbacks.base.BaseCallbackHandler], + inherit: bool = True +) -> None +``` + + + + + + +Set handlers as the only handlers on the callback manager. + +**Parameters:** + + +The handlers to set. + + + +Whether to inherit the handlers. + + + + + + + + + + +```python +class langchain_core.callbacks.base.CallbackManagerMixin() +``` + + + + + + +Mixin for callback manager. + + + + + + +```python +langchain_core.callbacks.base.CallbackManagerMixin.on_chain_start( + serialized: dict[str, typing.Any], + inputs: dict[str, typing.Any], + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run when a chain starts running. + +**Parameters:** + + +The serialized chain. + + + +The inputs. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +The metadata. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.CallbackManagerMixin.on_chat_model_start( + serialized: dict[str, typing.Any], + messages: list[list[langchain_core.messages.BaseMessage]], + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run when a chat model starts running. + +!!! warning + + This method is called for chat models. If you're implementing a handler for + a non-chat model, you should use `on_llm_start` instead. + +!!! note + + When overriding this method, the signature **must** include the two + required positional arguments ``serialized`` and ``messages``. Avoid + using ``*args`` in your override — doing so causes an ``IndexError`` + in the fallback path when the callback system converts ``messages`` + to prompt strings for ``on_llm_start``. Always declare the + signature explicitly: + + .. code-block:: python + + def on_chat_model_start( + self, + serialized: dict[str, Any], + messages: list[list[BaseMessage]], + **kwargs: Any, + ) -> None: + raise NotImplementedError # triggers fallback to on_llm_start + +**Parameters:** + + +The serialized chat model. + + + +The messages. Must be a list of message lists — this is a +required positional argument and must be present in any override. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +The metadata. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.CallbackManagerMixin.on_llm_start( + serialized: dict[str, typing.Any], + prompts: list[str], + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run when LLM starts running. + +!!! warning + + This method is called for non-chat models (regular text completion LLMs). If + you're implementing a handler for a chat model, you should use + `on_chat_model_start` instead. + +**Parameters:** + + +The serialized LLM. + + + +The prompts. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +The metadata. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.CallbackManagerMixin.on_retriever_start( + serialized: dict[str, typing.Any], + query: str, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run when the `Retriever` starts running. + +**Parameters:** + + +The serialized `Retriever`. + + + +The query. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +The metadata. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.CallbackManagerMixin.on_tool_start( + serialized: dict[str, typing.Any], + input_str: str, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + inputs: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run when the tool starts running. + +**Parameters:** + + +The serialized chain. + + + +The input string. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +The metadata. + + + +The inputs. + + + +Additional keyword arguments. + + + + + + + + + + +```python +class langchain_core.callbacks.base.ChainManagerMixin() +``` + + + + + + +Mixin for chain callbacks. + + + + + + +```python +langchain_core.callbacks.base.ChainManagerMixin.on_agent_action( + action: langchain_core.agents.AgentAction, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run on agent action. + +**Parameters:** + + +The agent action. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.ChainManagerMixin.on_agent_finish( + finish: langchain_core.agents.AgentFinish, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run on the agent end. + +**Parameters:** + + +The agent finish. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.ChainManagerMixin.on_chain_end( + outputs: dict[str, typing.Any], + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run when chain ends running. + +**Parameters:** + + +The outputs of the chain. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.ChainManagerMixin.on_chain_error( + error: BaseException, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run when chain errors. + +**Parameters:** + + +The error that occurred. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +Additional keyword arguments. + + + + + + + + + + +```python +class langchain_core.callbacks.base.LLMManagerMixin() +``` + + + + + + +Mixin for LLM callbacks. + + + + + + +```python +langchain_core.callbacks.base.LLMManagerMixin.on_llm_end( + response: langchain_core.outputs.LLMResult, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run when LLM ends running. + +**Parameters:** + + +The response which was generated. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.LLMManagerMixin.on_llm_error( + error: BaseException, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run when LLM errors. + +**Parameters:** + + +The error that occurred. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.LLMManagerMixin.on_llm_new_token( + token: str, + chunk: langchain_core.outputs.GenerationChunk | langchain_core.outputs.ChatGenerationChunk | None = None, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run on new output token. + +Only available when streaming is enabled. + +For both chat models and non-chat models (legacy text completion LLMs). + +**Parameters:** + + +The new token. + + + +The new generated chunk, containing content and other information. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +The tags. + + + +Additional keyword arguments. + + + + + + + + + + +```python +class langchain_core.callbacks.base.RetrieverManagerMixin() +``` + + + + + + +Mixin for `Retriever` callbacks. + + + + + + +```python +langchain_core.callbacks.base.RetrieverManagerMixin.on_retriever_end( + documents: collections.abc.Sequence[langchain_core.documents.Document], + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run when `Retriever` ends running. + +**Parameters:** + + +The documents retrieved. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.RetrieverManagerMixin.on_retriever_error( + error: BaseException, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run when `Retriever` errors. + +**Parameters:** + + +The error that occurred. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +Additional keyword arguments. + + + + + + + + + + +```python +class langchain_core.callbacks.base.RunManagerMixin() +``` + + + + + + +Mixin for run manager. + + + + + + +```python +langchain_core.callbacks.base.RunManagerMixin.on_custom_event( + name: str, + data: typing.Any, + run_id: uuid.UUID, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Override to define a handler for a custom event. + +**Parameters:** + + +The name of the custom event. + + + +The data for the custom event. + +Format will match the format specified by the user. + + + +The ID of the run. + + + +The tags associated with the custom event (includes inherited tags). + + + +The metadata associated with the custom event (includes inherited +metadata). + + + + + + + + +```python +langchain_core.callbacks.base.RunManagerMixin.on_retry( + retry_state: tenacity.RetryCallState, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run on a retry event. + +**Parameters:** + + +The retry state. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.RunManagerMixin.on_text( + text: str, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run on an arbitrary text. + +**Parameters:** + + +The text. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +Additional keyword arguments. + + + + + + + + + + +```python +class langchain_core.callbacks.base.ToolManagerMixin() +``` + + + + + + +Mixin for tool callbacks. + + + + + + +```python +langchain_core.callbacks.base.ToolManagerMixin.on_tool_end( + output: typing.Any, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run when the tool ends running. + +**Parameters:** + + +The output of the tool. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.base.ToolManagerMixin.on_tool_error( + error: BaseException, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run when tool errors. + +**Parameters:** + + +The error that occurred. + + + +The ID of the current run. + + + +The ID of the parent run. + + + +Additional keyword arguments. + + + + + + + + + + +```python +langchain_core.callbacks.base.Callbacks = list[BaseCallbackHandler] | BaseCallbackManager | None +``` + + + + + + + + + +```python +langchain_core.callbacks.base._LOGGER = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/file.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/file.mdx new file mode 100644 index 0000000..42bc8db --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/file.mdx @@ -0,0 +1,477 @@ +--- +layout: overview +slug: langchain-core/langchain_core/callbacks/file +title: langchain_core.callbacks.file +--- + +Callback handler that writes to a file. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`FileCallbackHandler`](#langchain_core-callbacks-file-FileCallbackHandler) | Callback handler that writes to a file. | + +### Data + +[`_GLOBAL_DEPRECATION_WARNED`](#langchain_core-callbacks-file-_GLOBAL_DEPRECATION_WARNED) + +### API + + + + + +```python +class langchain_core.callbacks.file.FileCallbackHandler( + filename: str, + mode: str = 'a', + color: str | None = None +) +``` + + + + + + +**Bases:** [BaseCallbackHandler](/langchain-core/langchain_core/callbacks/base#langchain_core-callbacks-base-BaseCallbackHandler) + +Callback handler that writes to a file. + +This handler supports both context manager usage (recommended) and direct +instantiation (deprecated) for backwards compatibility. + +!!! note + + When not used as a context manager, a deprecation warning will be issued on + first use. The file will be opened immediately in `__init__` and closed in + `__del__` or when `close()` is called explicitly. + +**Parameters:** + + +The file path to write to. + + + +The file open mode. Defaults to `'a'` (append). + + + +Default color for text output. + + +**Examples:** + + + +```python +Using as a context manager (recommended): + +```python +with FileCallbackHandler("output.txt") as handler: + # Use handler with your chain/agent + chain.invoke(inputs, config={"callbacks": [handler]}) +``` + +Direct instantiation (deprecated): + +```python +handler = FileCallbackHandler("output.txt") +# File remains open until handler is garbage collected +try: + chain.invoke(inputs, config={"callbacks": [handler]}) +finally: + handler.close() # Explicit cleanup recommended +``` + + + + + + + + + + +```python +langchain_core.callbacks.file.FileCallbackHandler.__del__() -> None +``` + + + + + + +Destructor to cleanup when done. + + + + + + + +```python +langchain_core.callbacks.file.FileCallbackHandler.__enter__() -> typing_extensions.Self +``` + + + + + + +Enter the context manager. + +!!! note + + The file is already opened in `__init__`, so this just marks that the + handler is being used as a context manager. + +**Returns:** `Self` + +The `FileCallbackHandler` instance. + + + + + + + +```python +langchain_core.callbacks.file.FileCallbackHandler.__exit__( + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object +) -> None +``` + + + + + + +Exit the context manager and close the file. + +**Parameters:** + + +Exception type if an exception occurred. + + + +Exception value if an exception occurred. + + + +Exception traceback if an exception occurred. + + + + + + + + +```python +langchain_core.callbacks.file.FileCallbackHandler._write( + text: str, + color: str | None = None, + end: str = '' +) -> None +``` + + + + + + +Write text to the file with deprecation warning if needed. + +**Parameters:** + + +The text to write to the file. + + + +Optional color for the text. Defaults to `self.color`. + + + +String appended after the text. + + + +Optional file to write to. Defaults to `self.file`. + + +**Raises:** + +- `RuntimeError`: If the file is closed or not available. + + + + + + + +```python +langchain_core.callbacks.file.FileCallbackHandler.close() -> None +``` + + + + + + +Close the file if it's open. + +This method is safe to call multiple times and will only close +the file if it's currently open. + + + + + + + +```python +langchain_core.callbacks.file.FileCallbackHandler.on_agent_action( + action: langchain_core.agents.AgentAction, + color: str | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Handle agent action by writing the action log. + +**Parameters:** + + +The agent action containing the log to write. + + + +Color override for this specific output. + +If `None`, uses `self.color`. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.file.FileCallbackHandler.on_agent_finish( + finish: langchain_core.agents.AgentFinish, + color: str | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Handle agent finish by writing the finish log. + +**Parameters:** + + +The agent finish object containing the log to write. + + + +Color override for this specific output. + +If `None`, uses `self.color`. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.file.FileCallbackHandler.on_chain_end( + outputs: dict[str, typing.Any], + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Print that we finished a chain. + +**Parameters:** + + +The outputs of the chain. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.file.FileCallbackHandler.on_chain_start( + serialized: dict[str, typing.Any], + inputs: dict[str, typing.Any], + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Print that we are entering a chain. + +**Parameters:** + + +The serialized chain information. + + + +The inputs to the chain. + + + +Additional keyword arguments that may contain `'name'`. + + + + + + + + +```python +langchain_core.callbacks.file.FileCallbackHandler.on_text( + text: str, + color: str | None = None, + end: str = '', + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Handle text output. + +**Parameters:** + + +The text to write. + + + +Color override for this specific output. + +If `None`, uses `self.color`. + + + +String appended after the text. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.file.FileCallbackHandler.on_tool_end( + output: str, + color: str | None = None, + observation_prefix: str | None = None, + llm_prefix: str | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Handle tool end by writing the output with optional prefixes. + +**Parameters:** + + +The tool output to write. + + + +Color override for this specific output. + +If `None`, uses `self.color`. + + + +Optional prefix to write before the output. + + + +Optional prefix to write after the output. + + + +Additional keyword arguments. + + + + + + + + + + +```python +langchain_core.callbacks.file._GLOBAL_DEPRECATION_WARNED = False +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/manager.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/manager.mdx new file mode 100644 index 0000000..a161715 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/manager.mdx @@ -0,0 +1,2935 @@ +--- +layout: overview +slug: langchain-core/langchain_core/callbacks/manager +title: langchain_core.callbacks.manager +--- + +Run managers. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AsyncCallbackManager`](#langchain_core-callbacks-manager-AsyncCallbackManager) | Async callback manager that handles callbacks from LangChain. | +| [`AsyncCallbackManagerForChainGroup`](#langchain_core-callbacks-manager-AsyncCallbackManagerForChainGroup) | Async callback manager for the chain group. | +| [`AsyncCallbackManagerForChainRun`](#langchain_core-callbacks-manager-AsyncCallbackManagerForChainRun) | Async callback manager for chain run. | +| [`AsyncCallbackManagerForLLMRun`](#langchain_core-callbacks-manager-AsyncCallbackManagerForLLMRun) | Async callback manager for LLM run. | +| [`AsyncCallbackManagerForRetrieverRun`](#langchain_core-callbacks-manager-AsyncCallbackManagerForRetrieverRun) | Async callback manager for retriever run. | +| [`AsyncCallbackManagerForToolRun`](#langchain_core-callbacks-manager-AsyncCallbackManagerForToolRun) | Async callback manager for tool run. | +| [`AsyncParentRunManager`](#langchain_core-callbacks-manager-AsyncParentRunManager) | Async parent run manager. | +| [`AsyncRunManager`](#langchain_core-callbacks-manager-AsyncRunManager) | Async run manager. | +| [`BaseRunManager`](#langchain_core-callbacks-manager-BaseRunManager) | Base class for run manager (a bound callback manager). | +| [`CallbackManager`](#langchain_core-callbacks-manager-CallbackManager) | Callback manager for LangChain. | +| [`CallbackManagerForChainGroup`](#langchain_core-callbacks-manager-CallbackManagerForChainGroup) | Callback manager for the chain group. | +| [`CallbackManagerForChainRun`](#langchain_core-callbacks-manager-CallbackManagerForChainRun) | Callback manager for chain run. | +| [`CallbackManagerForLLMRun`](#langchain_core-callbacks-manager-CallbackManagerForLLMRun) | Callback manager for LLM run. | +| [`CallbackManagerForRetrieverRun`](#langchain_core-callbacks-manager-CallbackManagerForRetrieverRun) | Callback manager for retriever run. | +| [`CallbackManagerForToolRun`](#langchain_core-callbacks-manager-CallbackManagerForToolRun) | Callback manager for tool run. | +| [`ParentRunManager`](#langchain_core-callbacks-manager-ParentRunManager) | Synchronous parent run manager. | +| [`RunManager`](#langchain_core-callbacks-manager-RunManager) | Synchronous run manager. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_ahandle_event_for_handler`](#langchain_core-callbacks-manager-_ahandle_event_for_handler) | - | +| [`_configure`](#langchain_core-callbacks-manager-_configure) | Configure the callback manager. | +| [`_executor`](#langchain_core-callbacks-manager-_executor) | - | +| [`_get_debug`](#langchain_core-callbacks-manager-_get_debug) | - | +| [`_run_coros`](#langchain_core-callbacks-manager-_run_coros) | - | +| [`adispatch_custom_event`](#langchain_core-callbacks-manager-adispatch_custom_event) | Dispatch an adhoc event to the handlers. | +| [`ahandle_event`](#langchain_core-callbacks-manager-ahandle_event) | Async generic event handler for `AsyncCallbackManager`. | +| [`atrace_as_chain_group`](#langchain_core-callbacks-manager-atrace_as_chain_group) | Get an async callback manager for a chain group in a context manager. | +| [`dispatch_custom_event`](#langchain_core-callbacks-manager-dispatch_custom_event) | Dispatch an adhoc event. | +| [`handle_event`](#langchain_core-callbacks-manager-handle_event) | Generic event handler for `CallbackManager`. | +| [`shielded`](#langchain_core-callbacks-manager-shielded) | Makes so an awaitable method is always shielded from cancellation. | +| [`trace_as_chain_group`](#langchain_core-callbacks-manager-trace_as_chain_group) | Get a callback manager for a chain group in a context manager. | + +### Data + +[`Func`](#langchain_core-callbacks-manager-Func) + +[`T`](#langchain_core-callbacks-manager-T) + +[`logger`](#langchain_core-callbacks-manager-logger) + +### API + + + + + +```python +class langchain_core.callbacks.manager.AsyncCallbackManager() +``` + + + + + + +**Bases:** [BaseCallbackManager](/langchain-core/langchain_core/callbacks/base#langchain_core-callbacks-base-BaseCallbackManager) + +Async callback manager that handles callbacks from LangChain. + + + +Return whether the handler is async. + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManager.configure( + inheritable_callbacks: langchain_core.callbacks.base.Callbacks = None, + local_callbacks: langchain_core.callbacks.base.Callbacks = None, + verbose: bool = False, + inheritable_tags: list[str] | None = None, + local_tags: list[str] | None = None, + inheritable_metadata: dict[str, typing.Any] | None = None, + local_metadata: dict[str, typing.Any] | None = None +) -> langchain_core.callbacks.manager.AsyncCallbackManager +``` + + + + + + +classmethod + +Configure the async callback manager. + +**Parameters:** + + +The inheritable callbacks. + + + +The local callbacks. + + + +Whether to enable verbose mode. + + + +The inheritable tags. + + + +The local tags. + + + +The inheritable metadata. + + + +The local metadata. + + +**Returns:** `AsyncCallbackManager` + +The configured async callback manager. + + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManager.on_chain_start( + serialized: dict[str, typing.Any] | None, + inputs: dict[str, typing.Any] | typing.Any, + run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> langchain_core.callbacks.manager.AsyncCallbackManagerForChainRun +``` + + + + + + +async + +Async run when chain starts running. + +**Parameters:** + + +The serialized chain. + + + +The inputs to the chain. + + + +The ID of the run. + + + +Additional keyword arguments. + + +**Returns:** `AsyncCallbackManagerForChainRun` + +The async callback manager for the chain run. + + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManager.on_chat_model_start( + serialized: dict[str, typing.Any], + messages: list[list[langchain_core.messages.BaseMessage]], + run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> list[langchain_core.callbacks.manager.AsyncCallbackManagerForLLMRun] +``` + + + + + + +async + +Async run when LLM starts running. + +**Parameters:** + + +The serialized LLM. + + + +The list of messages. + + + +The ID of the run. + + + +Additional keyword arguments. + + +**Returns:** `list[AsyncCallbackManagerForLLMRun]` + +The list of async callback managers, one for each LLM run corresponding to + + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManager.on_custom_event( + name: str, + data: typing.Any, + run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Dispatch an adhoc event to the handlers (async version). + +This event should NOT be used in any internal LangChain code. The event is meant +specifically for users of the library to dispatch custom events that are +tailored to their application. + +**Parameters:** + + +The name of the adhoc event. + + + +The data for the adhoc event. + + + +The ID of the run. + + +**Raises:** + +- `ValueError`: If additional keyword arguments are passed. + + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManager.on_llm_start( + serialized: dict[str, typing.Any], + prompts: list[str], + run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> list[langchain_core.callbacks.manager.AsyncCallbackManagerForLLMRun] +``` + + + + + + +async + +Run when LLM starts running. + +**Parameters:** + + +The serialized LLM. + + + +The list of prompts. + + + +The ID of the run. + + + +Additional keyword arguments. + + +**Returns:** `list[AsyncCallbackManagerForLLMRun]` + +The list of async callback managers, one for each LLM run corresponding to + + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManager.on_retriever_start( + serialized: dict[str, typing.Any] | None, + query: str, + run_id: uuid.UUID | None = None, + parent_run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> langchain_core.callbacks.manager.AsyncCallbackManagerForRetrieverRun +``` + + + + + + +async + +Run when the retriever starts running. + +**Parameters:** + + +The serialized retriever. + + + +The query. + + + +The ID of the run. + + + +The ID of the parent run. + + + +Additional keyword arguments. + + +**Returns:** `AsyncCallbackManagerForRetrieverRun` + +The async callback manager for the retriever run. + + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManager.on_tool_start( + serialized: dict[str, typing.Any] | None, + input_str: str, + run_id: uuid.UUID | None = None, + parent_run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> langchain_core.callbacks.manager.AsyncCallbackManagerForToolRun +``` + + + + + + +async + +Run when the tool starts running. + +**Parameters:** + + +The serialized tool. + + + +The input to the tool. + + + +The ID of the run. + + + +The ID of the parent run. + + + +Additional keyword arguments. + + +**Returns:** `AsyncCallbackManagerForToolRun` + +The async callback manager for the tool run. + + + + + + + + + +```python +class langchain_core.callbacks.manager.AsyncCallbackManagerForChainGroup( + handlers: list[langchain_core.callbacks.base.BaseCallbackHandler], + inheritable_handlers: list[langchain_core.callbacks.base.BaseCallbackHandler] | None = None, + parent_run_id: uuid.UUID | None = None, + parent_run_manager: langchain_core.callbacks.manager.AsyncCallbackManagerForChainRun, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [AsyncCallbackManager](#langchain_core-callbacks-manager-AsyncCallbackManager) + +Async callback manager for the chain group. + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManagerForChainGroup.copy() -> langchain_core.callbacks.manager.AsyncCallbackManagerForChainGroup +``` + + + + + + +Return a copy the async callback manager. + + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManagerForChainGroup.merge( + other: langchain_core.callbacks.base.BaseCallbackManager +) -> langchain_core.callbacks.manager.AsyncCallbackManagerForChainGroup +``` + + + + + + +Merge the group callback manager with another callback manager. + +Overwrites the merge method in the base class to ensure that the parent run +manager is preserved. Keeps the `parent_run_manager` from the current object. + +**Returns:** `AsyncCallbackManagerForChainGroup` + +A copy of the current `AsyncCallbackManagerForChainGroup` with the handlers, +tags, etc. of the other callback manager merged in. + + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManagerForChainGroup.on_chain_end( + outputs: dict[str, typing.Any] | typing.Any, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when traced chain group ends. + +**Parameters:** + + +The outputs of the chain. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManagerForChainGroup.on_chain_error( + error: BaseException, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when chain errors. + +**Parameters:** + + +The error. + + + +Additional keyword arguments. + + + + + + + + + + +```python +class langchain_core.callbacks.manager.AsyncCallbackManagerForChainRun() +``` + + + + + + +**Bases:** [AsyncParentRunManager](#langchain_core-callbacks-manager-AsyncParentRunManager), [ChainManagerMixin](/langchain-core/langchain_core/callbacks/base#langchain_core-callbacks-base-ChainManagerMixin) + +Async callback manager for chain run. + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManagerForChainRun.get_sync() -> langchain_core.callbacks.manager.CallbackManagerForChainRun +``` + + + + + + +Get the equivalent sync `RunManager`. + +**Returns:** `CallbackManagerForChainRun` + +The sync `RunManager`. + + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManagerForChainRun.on_agent_action( + action: langchain_core.agents.AgentAction, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when agent action is received. + +**Parameters:** + + +The agent action. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManagerForChainRun.on_agent_finish( + finish: langchain_core.agents.AgentFinish, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when agent finish is received. + +**Parameters:** + + +The agent finish. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManagerForChainRun.on_chain_end( + outputs: dict[str, typing.Any] | typing.Any, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when a chain ends running. + +**Parameters:** + + +The outputs of the chain. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManagerForChainRun.on_chain_error( + error: BaseException, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when chain errors. + +**Parameters:** + + +The error. + + + +Additional keyword arguments. + + + + + + + + + + +```python +class langchain_core.callbacks.manager.AsyncCallbackManagerForLLMRun() +``` + + + + + + +**Bases:** [AsyncRunManager](#langchain_core-callbacks-manager-AsyncRunManager), [LLMManagerMixin](/langchain-core/langchain_core/callbacks/base#langchain_core-callbacks-base-LLMManagerMixin) + +Async callback manager for LLM run. + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManagerForLLMRun.get_sync() -> langchain_core.callbacks.manager.CallbackManagerForLLMRun +``` + + + + + + +Get the equivalent sync `RunManager`. + +**Returns:** `CallbackManagerForLLMRun` + +The sync `RunManager`. + + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManagerForLLMRun.on_llm_end( + response: langchain_core.outputs.LLMResult, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when LLM ends running. + +**Parameters:** + + +The LLM result. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManagerForLLMRun.on_llm_error( + error: BaseException, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when LLM errors. + +**Parameters:** + + +The error. + + + +Additional keyword arguments. + +- response (LLMResult): The response which was generated before + the error occurred. + + + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManagerForLLMRun.on_llm_new_token( + token: str, + chunk: langchain_core.outputs.GenerationChunk | langchain_core.outputs.ChatGenerationChunk | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when LLM generates a new token. + +**Parameters:** + + +The new token. + + + +The chunk. + + + +Additional keyword arguments. + + + + + + + + + + +```python +class langchain_core.callbacks.manager.AsyncCallbackManagerForRetrieverRun() +``` + + + + + + +**Bases:** [AsyncParentRunManager](#langchain_core-callbacks-manager-AsyncParentRunManager), [RetrieverManagerMixin](/langchain-core/langchain_core/callbacks/base#langchain_core-callbacks-base-RetrieverManagerMixin) + +Async callback manager for retriever run. + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManagerForRetrieverRun.get_sync() -> langchain_core.callbacks.manager.CallbackManagerForRetrieverRun +``` + + + + + + +Get the equivalent sync `RunManager`. + +**Returns:** `CallbackManagerForRetrieverRun` + +The sync `RunManager`. + + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManagerForRetrieverRun.on_retriever_end( + documents: collections.abc.Sequence[langchain_core.documents.Document], + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when the retriever ends running. + +**Parameters:** + + +The retrieved documents. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManagerForRetrieverRun.on_retriever_error( + error: BaseException, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when retriever errors. + +**Parameters:** + + +The error. + + + +Additional keyword arguments. + + + + + + + + + + +```python +class langchain_core.callbacks.manager.AsyncCallbackManagerForToolRun() +``` + + + + + + +**Bases:** [AsyncParentRunManager](#langchain_core-callbacks-manager-AsyncParentRunManager), [ToolManagerMixin](/langchain-core/langchain_core/callbacks/base#langchain_core-callbacks-base-ToolManagerMixin) + +Async callback manager for tool run. + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManagerForToolRun.get_sync() -> langchain_core.callbacks.manager.CallbackManagerForToolRun +``` + + + + + + +Get the equivalent sync `RunManager`. + +**Returns:** `CallbackManagerForToolRun` + +The sync `RunManager`. + + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManagerForToolRun.on_tool_end( + output: typing.Any, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Async run when the tool ends running. + +**Parameters:** + + +The output of the tool. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.manager.AsyncCallbackManagerForToolRun.on_tool_error( + error: BaseException, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when tool errors. + +**Parameters:** + + +The error. + + + +Additional keyword arguments. + + + + + + + + + + +```python +class langchain_core.callbacks.manager.AsyncParentRunManager() +``` + + + + + + +**Bases:** [AsyncRunManager](#langchain_core-callbacks-manager-AsyncRunManager) + +Async parent run manager. + + + + + + +```python +langchain_core.callbacks.manager.AsyncParentRunManager.get_child( + tag: str | None = None +) -> langchain_core.callbacks.manager.AsyncCallbackManager +``` + + + + + + +Get a child callback manager. + +**Parameters:** + + +The tag for the child callback manager. + + +**Returns:** `AsyncCallbackManager` + +The child callback manager. + + + + + + + + + +```python +class langchain_core.callbacks.manager.AsyncRunManager() +``` + + + + + + +Abstract + +**Bases:** [BaseRunManager](#langchain_core-callbacks-manager-BaseRunManager) + +Async run manager. + + + + + + +```python +langchain_core.callbacks.manager.AsyncRunManager.get_sync() -> langchain_core.callbacks.manager.RunManager +``` + + + + + + +abstract + +Get the equivalent sync `RunManager`. + +**Returns:** `RunManager` + +The sync `RunManager`. + + + + + + + +```python +langchain_core.callbacks.manager.AsyncRunManager.on_retry( + retry_state: tenacity.RetryCallState, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Async run when a retry is received. + +**Parameters:** + + +The retry state. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.manager.AsyncRunManager.on_text( + text: str, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when a text is received. + +**Parameters:** + + +The received text. + + + +Additional keyword arguments. + + + + + + + + + + +```python +class langchain_core.callbacks.manager.BaseRunManager( + run_id: uuid.UUID, + handlers: list[langchain_core.callbacks.base.BaseCallbackHandler], + inheritable_handlers: list[langchain_core.callbacks.base.BaseCallbackHandler], + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + inheritable_tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + inheritable_metadata: dict[str, typing.Any] | None = None +) +``` + + + + + + +**Bases:** [RunManagerMixin](/langchain-core/langchain_core/callbacks/base#langchain_core-callbacks-base-RunManagerMixin) + +Base class for run manager (a bound callback manager). + + + + + + + + + + + + + + + + + +```python +langchain_core.callbacks.manager.BaseRunManager.get_noop_manager() -> typing_extensions.Self +``` + + + + + + +classmethod + +Return a manager that doesn't perform any operations. + +**Returns:** `Self` + +The noop manager. + + + + + + + + + +```python +class langchain_core.callbacks.manager.CallbackManager() +``` + + + + + + +**Bases:** [BaseCallbackManager](/langchain-core/langchain_core/callbacks/base#langchain_core-callbacks-base-BaseCallbackManager) + +Callback manager for LangChain. + + + + + + +```python +langchain_core.callbacks.manager.CallbackManager.configure( + inheritable_callbacks: langchain_core.callbacks.base.Callbacks = None, + local_callbacks: langchain_core.callbacks.base.Callbacks = None, + verbose: bool = False, + inheritable_tags: list[str] | None = None, + local_tags: list[str] | None = None, + inheritable_metadata: dict[str, typing.Any] | None = None, + local_metadata: dict[str, typing.Any] | None = None +) -> langchain_core.callbacks.manager.CallbackManager +``` + + + + + + +classmethod + +Configure the callback manager. + +**Parameters:** + + +The inheritable callbacks. + + + +The local callbacks. + + + +Whether to enable verbose mode. + + + +The inheritable tags. + + + +The local tags. + + + +The inheritable metadata. + + + +The local metadata. + + +**Returns:** `CallbackManager` + +The configured callback manager. + + + + + + + +```python +langchain_core.callbacks.manager.CallbackManager.on_chain_start( + serialized: dict[str, typing.Any] | None, + inputs: dict[str, typing.Any] | typing.Any, + run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> langchain_core.callbacks.manager.CallbackManagerForChainRun +``` + + + + + + +Run when chain starts running. + +**Parameters:** + + +The serialized chain. + + + +The inputs to the chain. + + + +The ID of the run. + + + +Additional keyword arguments. + + +**Returns:** `CallbackManagerForChainRun` + +The callback manager for the chain run. + + + + + + + +```python +langchain_core.callbacks.manager.CallbackManager.on_chat_model_start( + serialized: dict[str, typing.Any], + messages: list[list[langchain_core.messages.BaseMessage]], + run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> list[langchain_core.callbacks.manager.CallbackManagerForLLMRun] +``` + + + + + + +Run when chat model starts running. + +**Parameters:** + + +The serialized LLM. + + + +The list of messages. + + + +The ID of the run. + + + +Additional keyword arguments. + + +**Returns:** `list[CallbackManagerForLLMRun]` + +A callback manager for each list of messages as an LLM run. + + + + + + + +```python +langchain_core.callbacks.manager.CallbackManager.on_custom_event( + name: str, + data: typing.Any, + run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Dispatch an adhoc event to the handlers (async version). + +This event should NOT be used in any internal LangChain code. The event is meant +specifically for users of the library to dispatch custom events that are +tailored to their application. + +**Parameters:** + + +The name of the adhoc event. + + + +The data for the adhoc event. + + + +The ID of the run. + + +**Raises:** + +- `ValueError`: If additional keyword arguments are passed. + + + + + + + +```python +langchain_core.callbacks.manager.CallbackManager.on_llm_start( + serialized: dict[str, typing.Any], + prompts: list[str], + run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> list[langchain_core.callbacks.manager.CallbackManagerForLLMRun] +``` + + + + + + +Run when LLM starts running. + +**Parameters:** + + +The serialized LLM. + + + +The list of prompts. + + + +The ID of the run. + + + +Additional keyword arguments. + + +**Returns:** `list[CallbackManagerForLLMRun]` + +A callback manager for each prompt as an LLM run. + + + + + + + +```python +langchain_core.callbacks.manager.CallbackManager.on_retriever_start( + serialized: dict[str, typing.Any] | None, + query: str, + run_id: uuid.UUID | None = None, + parent_run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> langchain_core.callbacks.manager.CallbackManagerForRetrieverRun +``` + + + + + + +Run when the retriever starts running. + +**Parameters:** + + +The serialized retriever. + + + +The query. + + + +The ID of the run. + + + +The ID of the parent run. + + + +Additional keyword arguments. + + +**Returns:** `CallbackManagerForRetrieverRun` + +The callback manager for the retriever run. + + + + + + + +```python +langchain_core.callbacks.manager.CallbackManager.on_tool_start( + serialized: dict[str, typing.Any] | None, + input_str: str, + run_id: uuid.UUID | None = None, + parent_run_id: uuid.UUID | None = None, + inputs: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.callbacks.manager.CallbackManagerForToolRun +``` + + + + + + +Run when tool starts running. + +**Parameters:** + + +Serialized representation of the tool. + + + +The input to the tool as a string. + +Non-string inputs are cast to strings. + + + +ID for the run. + + + +The ID of the parent run. + + + +The original input to the tool if provided. + +Recommended for usage instead of input_str when the original input is +needed. + +If provided, the inputs are expected to be formatted as a dict. The keys +will correspond to the named-arguments in the tool. + + + +The keyword arguments to pass to the event handler + + +**Returns:** `CallbackManagerForToolRun` + +The callback manager for the tool run. + + + + + + + + + +```python +class langchain_core.callbacks.manager.CallbackManagerForChainGroup( + handlers: list[langchain_core.callbacks.base.BaseCallbackHandler], + inheritable_handlers: list[langchain_core.callbacks.base.BaseCallbackHandler] | None = None, + parent_run_id: uuid.UUID | None = None, + parent_run_manager: langchain_core.callbacks.manager.CallbackManagerForChainRun, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [CallbackManager](#langchain_core-callbacks-manager-CallbackManager) + +Callback manager for the chain group. + + + + + + +```python +langchain_core.callbacks.manager.CallbackManagerForChainGroup.copy() -> langchain_core.callbacks.manager.CallbackManagerForChainGroup +``` + + + + + + + + + + + + +```python +langchain_core.callbacks.manager.CallbackManagerForChainGroup.merge( + other: langchain_core.callbacks.base.BaseCallbackManager +) -> langchain_core.callbacks.manager.CallbackManagerForChainGroup +``` + + + + + + +Merge the group callback manager with another callback manager. + +Overwrites the merge method in the base class to ensure that the parent run +manager is preserved. Keeps the `parent_run_manager` from the current object. + +**Returns:** `CallbackManagerForChainGroup` + +A copy of the current object with the handlers, tags, and other attributes + + + + + + + +```python +langchain_core.callbacks.manager.CallbackManagerForChainGroup.on_chain_end( + outputs: dict[str, typing.Any] | typing.Any, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when traced chain group ends. + +**Parameters:** + + +The outputs of the chain. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.manager.CallbackManagerForChainGroup.on_chain_error( + error: BaseException, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when chain errors. + +**Parameters:** + + +The error. + + + +Additional keyword arguments. + + + + + + + + + + +```python +class langchain_core.callbacks.manager.CallbackManagerForChainRun() +``` + + + + + + +**Bases:** [ParentRunManager](#langchain_core-callbacks-manager-ParentRunManager), [ChainManagerMixin](/langchain-core/langchain_core/callbacks/base#langchain_core-callbacks-base-ChainManagerMixin) + +Callback manager for chain run. + + + + + + +```python +langchain_core.callbacks.manager.CallbackManagerForChainRun.on_agent_action( + action: langchain_core.agents.AgentAction, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when agent action is received. + +**Parameters:** + + +The agent action. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.manager.CallbackManagerForChainRun.on_agent_finish( + finish: langchain_core.agents.AgentFinish, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when agent finish is received. + +**Parameters:** + + +The agent finish. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.manager.CallbackManagerForChainRun.on_chain_end( + outputs: dict[str, typing.Any] | typing.Any, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when chain ends running. + +**Parameters:** + + +The outputs of the chain. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.manager.CallbackManagerForChainRun.on_chain_error( + error: BaseException, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when chain errors. + +**Parameters:** + + +The error. + + + +Additional keyword arguments. + + + + + + + + + + +```python +class langchain_core.callbacks.manager.CallbackManagerForLLMRun() +``` + + + + + + +**Bases:** [RunManager](#langchain_core-callbacks-manager-RunManager), [LLMManagerMixin](/langchain-core/langchain_core/callbacks/base#langchain_core-callbacks-base-LLMManagerMixin) + +Callback manager for LLM run. + + + + + + +```python +langchain_core.callbacks.manager.CallbackManagerForLLMRun.on_llm_end( + response: langchain_core.outputs.LLMResult, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when LLM ends running. + +**Parameters:** + + +The LLM result. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.manager.CallbackManagerForLLMRun.on_llm_error( + error: BaseException, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when LLM errors. + +**Parameters:** + + +The error. + + + +Additional keyword arguments. + +- response (LLMResult): The response which was generated before + the error occurred. + + + + + + + + +```python +langchain_core.callbacks.manager.CallbackManagerForLLMRun.on_llm_new_token( + token: str, + chunk: langchain_core.outputs.GenerationChunk | langchain_core.outputs.ChatGenerationChunk | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when LLM generates a new token. + +**Parameters:** + + +The new token. + + + +The chunk. + + + +Additional keyword arguments. + + + + + + + + + + +```python +class langchain_core.callbacks.manager.CallbackManagerForRetrieverRun() +``` + + + + + + +**Bases:** [ParentRunManager](#langchain_core-callbacks-manager-ParentRunManager), [RetrieverManagerMixin](/langchain-core/langchain_core/callbacks/base#langchain_core-callbacks-base-RetrieverManagerMixin) + +Callback manager for retriever run. + + + + + + +```python +langchain_core.callbacks.manager.CallbackManagerForRetrieverRun.on_retriever_end( + documents: collections.abc.Sequence[langchain_core.documents.Document], + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when retriever ends running. + +**Parameters:** + + +The retrieved documents. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.manager.CallbackManagerForRetrieverRun.on_retriever_error( + error: BaseException, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when retriever errors. + +**Parameters:** + + +The error. + + + +Additional keyword arguments. + + + + + + + + + + +```python +class langchain_core.callbacks.manager.CallbackManagerForToolRun() +``` + + + + + + +**Bases:** [ParentRunManager](#langchain_core-callbacks-manager-ParentRunManager), [ToolManagerMixin](/langchain-core/langchain_core/callbacks/base#langchain_core-callbacks-base-ToolManagerMixin) + +Callback manager for tool run. + + + + + + +```python +langchain_core.callbacks.manager.CallbackManagerForToolRun.on_tool_end( + output: typing.Any, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when the tool ends running. + +**Parameters:** + + +The output of the tool. + + + +The keyword arguments to pass to the event handler + + + + + + + + +```python +langchain_core.callbacks.manager.CallbackManagerForToolRun.on_tool_error( + error: BaseException, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when tool errors. + +**Parameters:** + + +The error. + + + +Additional keyword arguments. + + + + + + + + + + +```python +class langchain_core.callbacks.manager.ParentRunManager() +``` + + + + + + +**Bases:** [RunManager](#langchain_core-callbacks-manager-RunManager) + +Synchronous parent run manager. + + + + + + +```python +langchain_core.callbacks.manager.ParentRunManager.get_child( + tag: str | None = None +) -> langchain_core.callbacks.manager.CallbackManager +``` + + + + + + +Get a child callback manager. + +**Parameters:** + + +The tag for the child callback manager. + + +**Returns:** `CallbackManager` + +The child callback manager. + + + + + + + + + +```python +class langchain_core.callbacks.manager.RunManager() +``` + + + + + + +**Bases:** [BaseRunManager](#langchain_core-callbacks-manager-BaseRunManager) + +Synchronous run manager. + + + + + + +```python +langchain_core.callbacks.manager.RunManager.on_retry( + retry_state: tenacity.RetryCallState, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when a retry is received. + +**Parameters:** + + +The retry state. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.manager.RunManager.on_text( + text: str, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when a text is received. + +**Parameters:** + + +The received text. + + + +Additional keyword arguments. + + + + + + + + + + +```python +langchain_core.callbacks.manager._ahandle_event_for_handler( + handler: langchain_core.callbacks.base.BaseCallbackHandler, + event_name: str, + ignore_condition_name: str | None, + args: typing.Any = (), + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + + + + + + + + +```python +langchain_core.callbacks.manager._configure( + callback_manager_cls: type[langchain_core.callbacks.manager.T], + inheritable_callbacks: langchain_core.callbacks.base.Callbacks = None, + local_callbacks: langchain_core.callbacks.base.Callbacks = None, + inheritable_tags: list[str] | None = None, + local_tags: list[str] | None = None, + inheritable_metadata: dict[str, typing.Any] | None = None, + local_metadata: dict[str, typing.Any] | None = None, + verbose: bool = False +) -> langchain_core.callbacks.manager.T +``` + + + + + + +Configure the callback manager. + +**Parameters:** + + +The callback manager class. + + + +The inheritable callbacks. + + + +The local callbacks. + + + +The inheritable tags. + + + +The local tags. + + + +The inheritable metadata. + + + +The local metadata. + + + +Whether to enable verbose mode. + + +**Returns:** `T` + +The configured callback manager. + +**Raises:** + +- `RuntimeError`: If `LANGCHAIN_TRACING` is set but `LANGCHAIN_TRACING_V2` is not. + + + + + + + + +```python +langchain_core.callbacks.manager._executor() -> concurrent.futures.ThreadPoolExecutor +``` + + + + + + + + + + + + + +```python +langchain_core.callbacks.manager._get_debug() -> bool +``` + + + + + + + + + + + + + +```python +langchain_core.callbacks.manager._run_coros( + coros: list[collections.abc.Coroutine[typing.Any, typing.Any, typing.Any]] +) -> None +``` + + + + + + + + + + + + + +```python +langchain_core.callbacks.manager.adispatch_custom_event( + name: str, + data: typing.Any, + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> None +``` + + + + + + +async + +Dispatch an adhoc event to the handlers. + +Example: Use with astream events + + ```python + from langchain_core.callbacks import ( + AsyncCallbackHandler, + adispatch_custom_event + ) + from langchain_core.runnable import RunnableLambda + + class CustomCallbackManager(AsyncCallbackHandler): + async def on_custom_event( + self, + name: str, + data: Any, + *, + run_id: UUID, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + print(f"Received custom event: {name} with data: {data}") + + callback = CustomCallbackManager() + + async def foo(inputs): + await adispatch_custom_event("event_type_1", {"bar": "buzz}) + await adispatch_custom_event("event_type_2", 5) + return inputs + + foo_ = RunnableLambda(foo) + + async for event in foo_.ainvoke_stream( + {"a": "1"}, + version="v2", + config={"callbacks": [CustomCallbackManager()]} + ): + print(event) + ``` + +!!! warning + + If using python 3.10 and async, you MUST specify the `config` parameter or the + function will raise an error. This is due to a limitation in asyncio for python + 3.10 that prevents LangChain from automatically propagating the config object on + the user's behalf. + +**Parameters:** + + +The name of the adhoc event. + + + +The data for the adhoc event. + +Free form data. Ideally should be JSON serializable to avoid serialization +issues downstream, but this is not enforced. + + + +Optional config object. + +Mirrors the async API but not strictly needed. + + +**Raises:** + +- `RuntimeError`: If there is no parent run ID available to associate the event +with. + + + + + + + + +```python +langchain_core.callbacks.manager.ahandle_event( + handlers: list[langchain_core.callbacks.base.BaseCallbackHandler], + event_name: str, + ignore_condition_name: str | None, + args: typing.Any = (), + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Async generic event handler for `AsyncCallbackManager`. + +**Parameters:** + + +The list of handlers that will handle the event. + + + +The name of the event (e.g., `'on_llm_start'`). + + + +Name of the attribute defined on handler that if `True` +will cause the handler to be skipped for the given event. + + + +The arguments to pass to the event handler. + + + +The keyword arguments to pass to the event handler. + + + + + + + + + +```python +langchain_core.callbacks.manager.atrace_as_chain_group( + group_name: str, + callback_manager: langchain_core.callbacks.manager.AsyncCallbackManager | None = None, + inputs: dict[str, typing.Any] | None = None, + project_name: str | None = None, + example_id: str | uuid.UUID | None = None, + run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None +) -> collections.abc.AsyncGenerator[langchain_core.callbacks.manager.AsyncCallbackManagerForChainGroup, None] +``` + + + + + + +async + +Get an async callback manager for a chain group in a context manager. + +Useful for grouping different async calls together as a single run even if they +aren't composed in a single chain. + +!!! note + + Must have `LANGCHAIN_TRACING_V2` env var set to true to see the trace in + LangSmith. + +**Parameters:** + + +The name of the chain group. + + + +The async callback manager to use, which manages tracing and +other callback behavior. + + + +The inputs to the chain group. + + + +The name of the project. + + + +The ID of the example. + + + +The ID of the run. + + + +The inheritable tags to apply to all runs. + + + +The metadata to apply to all runs. + + + + + + + + + +```python +langchain_core.callbacks.manager.dispatch_custom_event( + name: str, + data: typing.Any, + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> None +``` + + + + + + +Dispatch an adhoc event. + +**Parameters:** + + +The name of the adhoc event. + + + +The data for the adhoc event. + +Free form data. Ideally should be JSON serializable to avoid serialization +issues downstream, but this is not enforced. + + + +Optional config object. + +Mirrors the async API but not strictly needed. + + +**Raises:** + +- `RuntimeError`: If there is no parent run ID available to associate the event +with. + + + + + + + + +```python +langchain_core.callbacks.manager.handle_event( + handlers: list[langchain_core.callbacks.base.BaseCallbackHandler], + event_name: str, + ignore_condition_name: str | None, + args: typing.Any = (), + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Generic event handler for `CallbackManager`. + +**Parameters:** + + +The list of handlers that will handle the event. + + + +The name of the event (e.g., `'on_llm_start'`). + + + +Name of the attribute defined on handler that if `True` +will cause the handler to be skipped for the given event. + + + +The arguments to pass to the event handler. + + + +The keyword arguments to pass to the event handler + + + + + + + + + +```python +langchain_core.callbacks.manager.shielded( + func: langchain_core.callbacks.manager.Func +) -> langchain_core.callbacks.manager.Func +``` + + + + + + +Makes so an awaitable method is always shielded from cancellation. + +**Parameters:** + + +The function to shield. + + +**Returns:** `Func` + +The shielded function + + + + + + + + +```python +langchain_core.callbacks.manager.trace_as_chain_group( + group_name: str, + callback_manager: langchain_core.callbacks.manager.CallbackManager | None = None, + inputs: dict[str, typing.Any] | None = None, + project_name: str | None = None, + example_id: str | uuid.UUID | None = None, + run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None +) -> collections.abc.Generator[langchain_core.callbacks.manager.CallbackManagerForChainGroup, None, None] +``` + + + + + + +Get a callback manager for a chain group in a context manager. + +Useful for grouping different calls together as a single run even if they aren't +composed in a single chain. + +!!! note + + Must have `LANGCHAIN_TRACING_V2` env var set to true to see the trace in + LangSmith. + +**Parameters:** + + +The name of the chain group. + + + +The callback manager to use. + + + +The inputs to the chain group. + + + +The name of the project. + + + +The ID of the example. + + + +The ID of the run. + + + +The inheritable tags to apply to all runs. + + + +The metadata to apply to all runs. + + + + + + + + + +```python +langchain_core.callbacks.manager.Func = TypeVar('Func', bound=Callable) +``` + + + + + + + + + +```python +langchain_core.callbacks.manager.T = TypeVar('T', CallbackManager, AsyncCallbackManager) +``` + + + + + + + + + +```python +langchain_core.callbacks.manager.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/stdout.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/stdout.mdx new file mode 100644 index 0000000..d10b6df --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/stdout.mdx @@ -0,0 +1,259 @@ +--- +layout: overview +slug: langchain-core/langchain_core/callbacks/stdout +title: langchain_core.callbacks.stdout +--- + +Callback handler that prints to std out. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`StdOutCallbackHandler`](#langchain_core-callbacks-stdout-StdOutCallbackHandler) | Callback handler that prints to std out. | + +### API + + + + + +```python +class langchain_core.callbacks.stdout.StdOutCallbackHandler( + color: str | None = None +) +``` + + + + + + +**Bases:** [BaseCallbackHandler](/langchain-core/langchain_core/callbacks/base#langchain_core-callbacks-base-BaseCallbackHandler) + +Callback handler that prints to std out. + + + + + + +```python +langchain_core.callbacks.stdout.StdOutCallbackHandler.on_agent_action( + action: langchain_core.agents.AgentAction, + color: str | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run on agent action. + +**Parameters:** + + +The agent action. + + + +The color to use for the text. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.stdout.StdOutCallbackHandler.on_agent_finish( + finish: langchain_core.agents.AgentFinish, + color: str | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run on the agent end. + +**Parameters:** + + +The agent finish. + + + +The color to use for the text. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.stdout.StdOutCallbackHandler.on_chain_end( + outputs: dict[str, typing.Any], + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Print out that we finished a chain. + +**Parameters:** + + +The outputs of the chain. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.stdout.StdOutCallbackHandler.on_chain_start( + serialized: dict[str, typing.Any], + inputs: dict[str, typing.Any], + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Print out that we are entering a chain. + +**Parameters:** + + +The serialized chain. + + + +The inputs to the chain. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.stdout.StdOutCallbackHandler.on_text( + text: str, + color: str | None = None, + end: str = '', + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when the agent ends. + +**Parameters:** + + +The text to print. + + + +The color to use for the text. + + + +The end character to use. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.stdout.StdOutCallbackHandler.on_tool_end( + output: typing.Any, + color: str | None = None, + observation_prefix: str | None = None, + llm_prefix: str | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +If not the final action, print out observation. + +**Parameters:** + + +The output to print. + + + +The color to use for the text. + + + +The observation prefix. + + + +The LLM prefix. + + + +Additional keyword arguments. + + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/streaming_stdout.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/streaming_stdout.mdx new file mode 100644 index 0000000..05f61f3 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/streaming_stdout.mdx @@ -0,0 +1,479 @@ +--- +layout: overview +slug: langchain-core/langchain_core/callbacks/streaming_stdout +title: langchain_core.callbacks.streaming_stdout +--- + +Callback Handler streams to stdout on new llm token. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`StreamingStdOutCallbackHandler`](#langchain_core-callbacks-streaming_stdout-StreamingStdOutCallbackHandler) | Callback handler for streaming. | + +### API + + + + + +```python +class langchain_core.callbacks.streaming_stdout.StreamingStdOutCallbackHandler() +``` + + + + + + +**Bases:** [BaseCallbackHandler](/langchain-core/langchain_core/callbacks/base#langchain_core-callbacks-base-BaseCallbackHandler) + +Callback handler for streaming. + +!!! warning "Only works with LLMs that support streaming." + + + + + + +```python +langchain_core.callbacks.streaming_stdout.StreamingStdOutCallbackHandler.on_agent_action( + action: langchain_core.agents.AgentAction, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run on agent action. + +**Parameters:** + + +The agent action. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.streaming_stdout.StreamingStdOutCallbackHandler.on_agent_finish( + finish: langchain_core.agents.AgentFinish, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run on the agent end. + +**Parameters:** + + +The agent finish. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.streaming_stdout.StreamingStdOutCallbackHandler.on_chain_end( + outputs: dict[str, typing.Any], + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when a chain ends running. + +**Parameters:** + + +The outputs of the chain. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.streaming_stdout.StreamingStdOutCallbackHandler.on_chain_error( + error: BaseException, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when chain errors. + +**Parameters:** + + +The error that occurred. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.streaming_stdout.StreamingStdOutCallbackHandler.on_chain_start( + serialized: dict[str, typing.Any], + inputs: dict[str, typing.Any], + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when a chain starts running. + +**Parameters:** + + +The serialized chain. + + + +The inputs to the chain. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.streaming_stdout.StreamingStdOutCallbackHandler.on_chat_model_start( + serialized: dict[str, typing.Any], + messages: list[list[langchain_core.messages.BaseMessage]], + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when LLM starts running. + +**Parameters:** + + +The serialized LLM. + + + +The messages to run. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.streaming_stdout.StreamingStdOutCallbackHandler.on_llm_end( + response: langchain_core.outputs.LLMResult, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when LLM ends running. + +**Parameters:** + + +The response from the LLM. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.streaming_stdout.StreamingStdOutCallbackHandler.on_llm_error( + error: BaseException, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when LLM errors. + +**Parameters:** + + +The error that occurred. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.streaming_stdout.StreamingStdOutCallbackHandler.on_llm_new_token( + token: str, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run on new LLM token. Only available when streaming is enabled. + +**Parameters:** + + +The new token. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.streaming_stdout.StreamingStdOutCallbackHandler.on_llm_start( + serialized: dict[str, typing.Any], + prompts: list[str], + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when LLM starts running. + +**Parameters:** + + +The serialized LLM. + + + +The prompts to run. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.streaming_stdout.StreamingStdOutCallbackHandler.on_text( + text: str, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run on an arbitrary text. + +**Parameters:** + + +The text to print. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.streaming_stdout.StreamingStdOutCallbackHandler.on_tool_end( + output: typing.Any, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when tool ends running. + +**Parameters:** + + +The output of the tool. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.streaming_stdout.StreamingStdOutCallbackHandler.on_tool_error( + error: BaseException, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when tool errors. + +**Parameters:** + + +The error that occurred. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.callbacks.streaming_stdout.StreamingStdOutCallbackHandler.on_tool_start( + serialized: dict[str, typing.Any], + input_str: str, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Run when the tool starts running. + +**Parameters:** + + +The serialized tool. + + + +The input string. + + + +Additional keyword arguments. + + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/usage.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/usage.mdx new file mode 100644 index 0000000..f15308d --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/callbacks/usage.mdx @@ -0,0 +1,116 @@ +--- +layout: overview +slug: langchain-core/langchain_core/callbacks/usage +title: langchain_core.callbacks.usage +--- + +Callback Handler that tracks `AIMessage.usage_metadata`. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`UsageMetadataCallbackHandler`](#langchain_core-callbacks-usage-UsageMetadataCallbackHandler) | Callback Handler that tracks `AIMessage.usage_metadata`. | + +### Functions + +| Name | Description | +|------|-------------| +| [`get_usage_metadata_callback`](#langchain_core-callbacks-usage-get_usage_metadata_callback) | Get usage metadata callback. | + +### API + + + + + +```python +class langchain_core.callbacks.usage.UsageMetadataCallbackHandler() +``` + + + + + + +**Bases:** [BaseCallbackHandler](/langchain-core/langchain_core/callbacks/base#langchain_core-callbacks-base-BaseCallbackHandler) + +Callback Handler that tracks `AIMessage.usage_metadata`. + +!!! version-added "Added in `langchain-core` 0.3.49" + + + + + + + + + + + +```python +langchain_core.callbacks.usage.UsageMetadataCallbackHandler.__repr__() -> str +``` + + + + + + + + + + + + +```python +langchain_core.callbacks.usage.UsageMetadataCallbackHandler.on_llm_end( + response: langchain_core.outputs.LLMResult, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Collect token usage. + + + + + + + + + +```python +langchain_core.callbacks.usage.get_usage_metadata_callback( + name: str = 'usage_metadata_callback' +) -> collections.abc.Generator[langchain_core.callbacks.usage.UsageMetadataCallbackHandler, None, None] +``` + + + + + + +Get usage metadata callback. + +Get context manager for tracking usage metadata across chat model calls using +[`AIMessage.usage_metadata`][langchain.messages.AIMessage.usage_metadata]. + +!!! version-added "Added in `langchain-core` 0.3.49" + +**Parameters:** + + +The name of the context variable. + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/chat_history.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/chat_history.mdx new file mode 100644 index 0000000..866b225 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/chat_history.mdx @@ -0,0 +1,442 @@ +--- +layout: overview +slug: langchain-core/langchain_core/chat_history +title: langchain_core.chat_history +--- + +Chat message history stores a history of the message interactions in a chat. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseChatMessageHistory`](#langchain_core-chat_history-BaseChatMessageHistory) | Abstract base class for storing chat message history. | +| [`InMemoryChatMessageHistory`](#langchain_core-chat_history-InMemoryChatMessageHistory) | In memory implementation of chat message history. | + +### API + + + + + +```python +class langchain_core.chat_history.BaseChatMessageHistory() +``` + + + + + + +Abstract + +Abstract base class for storing chat message history. + +Implementations guidelines: + +Implementations are expected to over-ride all or some of the following methods: + +* `add_messages`: sync variant for bulk addition of messages +* `aadd_messages`: async variant for bulk addition of messages +* `messages`: sync variant for getting messages +* `aget_messages`: async variant for getting messages +* `clear`: sync variant for clearing messages +* `aclear`: async variant for clearing messages + +`add_messages` contains a default implementation that calls `add_message` +for each message in the sequence. This is provided for backwards compatibility +with existing implementations which only had `add_message`. + +Async variants all have default implementations that call the sync variants. +Implementers can choose to override the async implementations to provide +truly async implementations. + +Usage guidelines: + +When used for updating history, users should favor usage of `add_messages` +over `add_message` or other variants like `add_user_message` and `add_ai_message` +to avoid unnecessary round-trips to the underlying persistence layer. + + + +A property or attribute that returns a list of messages. + +In general, getting the messages may involve IO to the underlying persistence +layer, so this operation is expected to incur some latency. + + + + + +```python +langchain_core.chat_history.BaseChatMessageHistory.__str__() -> str +``` + + + + + + +Return a string representation of the chat history. + + + + + + + +```python +langchain_core.chat_history.BaseChatMessageHistory.aadd_messages( + messages: collections.abc.Sequence[langchain_core.messages.BaseMessage] +) -> None +``` + + + + + + +async + +Async add a list of messages. + +**Parameters:** + + +A sequence of `BaseMessage` objects to store. + + + + + + + + +```python +langchain_core.chat_history.BaseChatMessageHistory.aclear() -> None +``` + + + + + + +async + +Async remove all messages from the store. + + + + + + + +```python +langchain_core.chat_history.BaseChatMessageHistory.add_ai_message( + message: langchain_core.messages.AIMessage | str +) -> None +``` + + + + + + +Convenience method for adding an `AIMessage` string to the store. + +!!! note + + This is a convenience method. Code should favor the bulk `add_messages` + interface instead to save on round-trips to the persistence layer. + +This method may be deprecated in a future release. + +**Parameters:** + + +The `AIMessage` to add. + + + + + + + + +```python +langchain_core.chat_history.BaseChatMessageHistory.add_message( + message: langchain_core.messages.BaseMessage +) -> None +``` + + + + + + +Add a Message object to the store. + +**Parameters:** + + +A `BaseMessage` object to store. + + +**Raises:** + +- `NotImplementedError`: If the sub-class has not implemented an efficient +`add_messages` method. + + + + + + + +```python +langchain_core.chat_history.BaseChatMessageHistory.add_messages( + messages: collections.abc.Sequence[langchain_core.messages.BaseMessage] +) -> None +``` + + + + + + +Add a list of messages. + +Implementations should over-ride this method to handle bulk addition of messages +in an efficient manner to avoid unnecessary round-trips to the underlying store. + +**Parameters:** + + +A sequence of `BaseMessage` objects to store. + + + + + + + + +```python +langchain_core.chat_history.BaseChatMessageHistory.add_user_message( + message: langchain_core.messages.HumanMessage | str +) -> None +``` + + + + + + +Convenience method for adding a human message string to the store. + +!!! note + + This is a convenience method. Code should favor the bulk `add_messages` + interface instead to save on round-trips to the persistence layer. + +This method may be deprecated in a future release. + +**Parameters:** + + +The `HumanMessage` to add to the store. + + + + + + + + +```python +langchain_core.chat_history.BaseChatMessageHistory.aget_messages() -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +async + +Async version of getting messages. + +Can over-ride this method to provide an efficient async implementation. + +In general, fetching messages may involve IO to the underlying persistence +layer. + +**Returns:** `list[BaseMessage]` + +The messages. + + + + + + + +```python +langchain_core.chat_history.BaseChatMessageHistory.clear() -> None +``` + + + + + + +abstract + +Remove all messages from the store. + + + + + + + + + +```python +class langchain_core.chat_history.InMemoryChatMessageHistory() +``` + + + + + + +**Bases:** [BaseChatMessageHistory](#langchain_core-chat_history-BaseChatMessageHistory), `BaseModel` + +In memory implementation of chat message history. + +Stores messages in a memory list. + + + +A list of messages stored in memory. + + + + + +```python +langchain_core.chat_history.InMemoryChatMessageHistory.aadd_messages( + messages: collections.abc.Sequence[langchain_core.messages.BaseMessage] +) -> None +``` + + + + + + +async + +Async add messages to the store. + +**Parameters:** + + +The messages to add. + + + + + + + + +```python +langchain_core.chat_history.InMemoryChatMessageHistory.aclear() -> None +``` + + + + + + +async + +Async clear all messages from the store. + + + + + + + +```python +langchain_core.chat_history.InMemoryChatMessageHistory.add_message( + message: langchain_core.messages.BaseMessage +) -> None +``` + + + + + + +Add a self-created message to the store. + +**Parameters:** + + +The message to add. + + + + + + + + +```python +langchain_core.chat_history.InMemoryChatMessageHistory.aget_messages() -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +async + +Async version of getting messages. + +Can over-ride this method to provide an efficient async implementation. + +In general, fetching messages may involve IO to the underlying persistence +layer. + +**Returns:** `list[BaseMessage]` + +List of messages. + + + + + + + +```python +langchain_core.chat_history.InMemoryChatMessageHistory.clear() -> None +``` + + + + + + +Clear all messages from the store. + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/chat_loaders.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/chat_loaders.mdx new file mode 100644 index 0000000..b3255c2 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/chat_loaders.mdx @@ -0,0 +1,81 @@ +--- +layout: overview +slug: langchain-core/langchain_core/chat_loaders +title: langchain_core.chat_loaders +--- + +Chat loaders. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseChatLoader`](#langchain_core-chat_loaders-BaseChatLoader) | Base class for chat loaders. | + +### API + + + + + +```python +class langchain_core.chat_loaders.BaseChatLoader() +``` + + + + + + +Abstract + +Base class for chat loaders. + + + + + + +```python +langchain_core.chat_loaders.BaseChatLoader.lazy_load() -> collections.abc.Iterator[langchain_core.chat_sessions.ChatSession] +``` + + + + + + +abstract + +Lazy load the chat sessions. + +**Returns:** `Iterator[ChatSession]` + +An iterator of chat sessions. + + + + + + + +```python +langchain_core.chat_loaders.BaseChatLoader.load() -> list[langchain_core.chat_sessions.ChatSession] +``` + + + + + + +Eagerly load the chat sessions into memory. + +**Returns:** `list[ChatSession]` + +A list of chat sessions. + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/chat_sessions.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/chat_sessions.mdx new file mode 100644 index 0000000..5114755 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/chat_sessions.mdx @@ -0,0 +1,46 @@ +--- +layout: overview +slug: langchain-core/langchain_core/chat_sessions +title: langchain_core.chat_sessions +--- + +**Chat Sessions** are a collection of messages and function calls. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ChatSession`](#langchain_core-chat_sessions-ChatSession) | Chat Session. | + +### API + + + + + +```python +class langchain_core.chat_sessions.ChatSession +``` + + + + + + +**Bases:** `typing.TypedDict` + +Chat Session. + +Chat Session represents a single conversation, channel, or other group of messages. + + +A sequence of the function calling specs for the messages. + + + +A sequence of the LangChain chat messages loaded from the source. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/document_loaders.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/document_loaders.mdx new file mode 100644 index 0000000..8d6b0d6 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/document_loaders.mdx @@ -0,0 +1,88 @@ +--- +layout: overview +slug: langchain-core/langchain_core/document_loaders +title: langchain_core.document_loaders +--- + +Document loaders. + +## Submodules + +- **[`langchain_core.document_loaders.base`](/langchain-core/langchain_core/document_loaders/base)** +- **[`langchain_core.document_loaders.blob_loaders`](/langchain-core/langchain_core/document_loaders/blob_loaders)** +- **[`langchain_core.document_loaders.langsmith`](/langchain-core/langchain_core/document_loaders/langsmith)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`__dir__`](#langchain_core-document_loaders-__dir__) | - | +| [`__getattr__`](#langchain_core-document_loaders-__getattr__) | - | + +### Data + +[`__all__`](#langchain_core-document_loaders-__all__) + +[`_dynamic_imports`](#langchain_core-document_loaders-_dynamic_imports) + +### API + + + + + +```python +langchain_core.document_loaders.__dir__() -> list[str] +``` + + + + + + + + + + + + + +```python +langchain_core.document_loaders.__getattr__( + attr_name: str +) -> object +``` + + + + + + + + + + + + + +```python +langchain_core.document_loaders.__all__ = ('BaseBlobParser', 'BaseLoader', 'Blob', 'BlobLoader', 'LangSmithLoader', 'PathL... +``` + + + + + + + + + +```python +langchain_core.document_loaders._dynamic_imports = {'BaseBlobParser': 'base', 'BaseLoader': 'base', 'Blob': 'blob_loaders', 'BlobLo... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/document_loaders/base.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/document_loaders/base.mdx new file mode 100644 index 0000000..e36719a --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/document_loaders/base.mdx @@ -0,0 +1,273 @@ +--- +layout: overview +slug: langchain-core/langchain_core/document_loaders/base +title: langchain_core.document_loaders.base +--- + +Abstract interface for document loader implementations. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseBlobParser`](#langchain_core-document_loaders-base-BaseBlobParser) | Abstract interface for blob parsers. | +| [`BaseLoader`](#langchain_core-document_loaders-base-BaseLoader) | Interface for document loader. | + +### Data + +[`_HAS_TEXT_SPLITTERS`](#langchain_core-document_loaders-base-_HAS_TEXT_SPLITTERS) + +### API + + + + + +```python +class langchain_core.document_loaders.base.BaseBlobParser() +``` + + + + + + +Abstract + +Abstract interface for blob parsers. + +A blob parser provides a way to parse raw data stored in a blob into one or more +`Document` objects. + +The parser can be composed with blob loaders, making it easy to reuse a parser +independent of how the blob was originally loaded. + + + + + + +```python +langchain_core.document_loaders.base.BaseBlobParser.lazy_parse( + blob: langchain_core.documents.base.Blob +) -> collections.abc.Iterator[langchain_core.documents.Document] +``` + + + + + + +abstract + +Lazy parsing interface. + +Subclasses are required to implement this method. + +**Parameters:** + + +`Blob` instance + + +**Returns:** `Iterator[Document]` + +Generator of `Document` objects + + + + + + + +```python +langchain_core.document_loaders.base.BaseBlobParser.parse( + blob: langchain_core.documents.base.Blob +) -> list[langchain_core.documents.Document] +``` + + + + + + +Eagerly parse the blob into a `Document` or list of `Document` objects. + +This is a convenience method for interactive development environment. + +Production applications should favor the `lazy_parse` method instead. + +Subclasses should generally not over-ride this parse method. + +**Parameters:** + + +`Blob` instance + + +**Returns:** `list[Document]` + +List of `Document` objects + + + + + + + + + +```python +class langchain_core.document_loaders.base.BaseLoader() +``` + + + + + + +Abstract + +Interface for document loader. + +Implementations should implement the lazy-loading method using generators to avoid +loading all documents into memory at once. + +`load` is provided just for user convenience and should not be overridden. + + + + + + +```python +langchain_core.document_loaders.base.BaseLoader.alazy_load() -> collections.abc.AsyncIterator[langchain_core.documents.Document] +``` + + + + + + +async + +A lazy loader for `Document`. + + + + + + + +```python +langchain_core.document_loaders.base.BaseLoader.aload() -> list[langchain_core.documents.Document] +``` + + + + + + +async + +Load data into `Document` objects. + +**Returns:** `list[Document]` + +The documents. + + + + + + + +```python +langchain_core.document_loaders.base.BaseLoader.lazy_load() -> collections.abc.Iterator[langchain_core.documents.Document] +``` + + + + + + +A lazy loader for `Document`. + + + + + + + +```python +langchain_core.document_loaders.base.BaseLoader.load() -> list[langchain_core.documents.Document] +``` + + + + + + +Load data into `Document` objects. + +**Returns:** `list[Document]` + +The documents. + + + + + + + +```python +langchain_core.document_loaders.base.BaseLoader.load_and_split( + text_splitter: langchain_text_splitters.TextSplitter | None = None +) -> list[langchain_core.documents.Document] +``` + + + + + + +Load `Document` and split into chunks. Chunks are returned as `Document`. + +!!! danger + + Do not override this method. It should be considered to be deprecated! + +**Parameters:** + + +`TextSplitter` instance to use for splitting documents. + +Defaults to `RecursiveCharacterTextSplitter`. + + +**Returns:** `list[Document]` + +List of `Document` objects. + +**Raises:** + +- `ImportError`: If `langchain-text-splitters` is not installed and no +`text_splitter` is provided. + + + + + + + + + +```python +langchain_core.document_loaders.base._HAS_TEXT_SPLITTERS = True +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/document_loaders/blob_loaders.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/document_loaders/blob_loaders.mdx new file mode 100644 index 0000000..89dd5e6 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/document_loaders/blob_loaders.mdx @@ -0,0 +1,78 @@ +--- +layout: overview +slug: langchain-core/langchain_core/document_loaders/blob_loaders +title: langchain_core.document_loaders.blob_loaders +--- + +Schema for Blobs and Blob Loaders. + +The goal is to facilitate decoupling of content loading from content parsing code. In +addition, content loading code should provide a lazy loading interface by default. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BlobLoader`](#langchain_core-document_loaders-blob_loaders-BlobLoader) | Abstract interface for blob loaders implementation. | + +### Data + +[`__all__`](#langchain_core-document_loaders-blob_loaders-__all__) + +### API + + + + + +```python +class langchain_core.document_loaders.blob_loaders.BlobLoader() +``` + + + + + + +Abstract + +Abstract interface for blob loaders implementation. + +Implementer should be able to load raw content from a storage system according to +some criteria and return the raw content lazily as a stream of blobs. + + + + + + +```python +langchain_core.document_loaders.blob_loaders.BlobLoader.yield_blobs() -> collections.abc.Iterator[langchain_core.documents.base.Blob] +``` + + + + + + +abstract + +A lazy loader for raw data represented by LangChain's `Blob` object. + + + + + + + + + +```python +langchain_core.document_loaders.blob_loaders.__all__ = ['Blob', 'BlobLoader', 'PathLike'] +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/document_loaders/langsmith.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/document_loaders/langsmith.mdx new file mode 100644 index 0000000..183c26d --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/document_loaders/langsmith.mdx @@ -0,0 +1,118 @@ +--- +layout: overview +slug: langchain-core/langchain_core/document_loaders/langsmith +title: langchain_core.document_loaders.langsmith +--- + +LangSmith document loader. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`LangSmithLoader`](#langchain_core-document_loaders-langsmith-LangSmithLoader) | Load LangSmith Dataset examples as `Document` objects. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_stringify`](#langchain_core-document_loaders-langsmith-_stringify) | - | + +### API + + + + + +```python +class langchain_core.document_loaders.langsmith.LangSmithLoader( + dataset_id: uuid.UUID | str | None = None, + dataset_name: str | None = None, + example_ids: collections.abc.Sequence[uuid.UUID | str] | None = None, + as_of: datetime.datetime | str | None = None, + splits: collections.abc.Sequence[str] | None = None, + inline_s3_urls: bool = True, + offset: int = 0, + limit: int | None = None, + metadata: dict | None = None, + filter: str | None = None, + content_key: str = '', + format_content: collections.abc.Callable[..., str] | None = None, + client: langsmith.Client | None = None, + client_kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [BaseLoader](/langchain-core/langchain_core/document_loaders/base#langchain_core-document_loaders-base-BaseLoader) + +Load LangSmith Dataset examples as `Document` objects. + +Loads the example inputs as the `Document` page content and places the entire +example into the `Document` metadata. This allows you to easily create few-shot +example retrievers from the loaded documents. + +??? example "Lazy loading" + + ```python + from langchain_core.document_loaders import LangSmithLoader + + loader = LangSmithLoader(dataset_id="...", limit=100) + docs = [] + for doc in loader.lazy_load(): + docs.append(doc) + ``` + + ```python + # -> [Document("...", metadata={"inputs": {...}, "outputs": {...}, ...}), ...] + ``` + + + + + + + + + + + + + + +```python +langchain_core.document_loaders.langsmith.LangSmithLoader.lazy_load() -> collections.abc.Iterator[langchain_core.documents.Document] +``` + + + + + + + + + + + + + + +```python +langchain_core.document_loaders.langsmith._stringify( + x: str | dict[str, typing.Any] +) -> str +``` + + + + + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/documents.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/documents.mdx new file mode 100644 index 0000000..23c8e82 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/documents.mdx @@ -0,0 +1,113 @@ +--- +layout: overview +slug: langchain-core/langchain_core/documents +title: langchain_core.documents +--- + +Documents module for data retrieval and processing workflows. + +This module provides core abstractions for handling data in retrieval-augmented +generation (RAG) pipelines, vector stores, and document processing workflows. + +!!! warning "Documents vs. message content" + + This module is distinct from `langchain_core.messages.content`, which provides + multimodal content blocks for **LLM chat I/O** (text, images, audio, etc. within + messages). + + **Key distinction:** + + - **Documents** (this module): For **data retrieval and processing workflows** + - Vector stores, retrievers, RAG pipelines + - Text chunking, embedding, and semantic search + - Example: Chunks of a PDF stored in a vector database + + - **Content Blocks** (`messages.content`): For **LLM conversational I/O** + - Multimodal message content sent to/from models + - Tool calls, reasoning, citations within chat + - Example: An image sent to a vision model in a chat message (via + [`ImageContentBlock`][langchain.messages.ImageContentBlock]) + + While both can represent similar data types (text, files), they serve different + architectural purposes in LangChain applications. + +## Submodules + +- **[`langchain_core.documents.base`](/langchain-core/langchain_core/documents/base)** +- **[`langchain_core.documents.compressor`](/langchain-core/langchain_core/documents/compressor)** +- **[`langchain_core.documents.transformers`](/langchain-core/langchain_core/documents/transformers)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`__dir__`](#langchain_core-documents-__dir__) | - | +| [`__getattr__`](#langchain_core-documents-__getattr__) | - | + +### Data + +[`__all__`](#langchain_core-documents-__all__) + +[`_dynamic_imports`](#langchain_core-documents-_dynamic_imports) + +### API + + + + + +```python +langchain_core.documents.__dir__() -> list[str] +``` + + + + + + + + + + + + + +```python +langchain_core.documents.__getattr__( + attr_name: str +) -> object +``` + + + + + + + + + + + + + +```python +langchain_core.documents.__all__ = ('BaseDocumentCompressor', 'BaseDocumentTransformer', 'Document') +``` + + + + + + + + + +```python +langchain_core.documents._dynamic_imports = {'Document': 'base', 'BaseDocumentCompressor': 'compressor', 'BaseDocumentTransf... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/documents/base.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/documents/base.mdx new file mode 100644 index 0000000..7e0b022 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/documents/base.mdx @@ -0,0 +1,495 @@ +--- +layout: overview +slug: langchain-core/langchain_core/documents/base +title: langchain_core.documents.base +--- + +Base classes for media and documents. + +This module contains core abstractions for **data retrieval and processing workflows**: + +- `BaseMedia`: Base class providing `id` and `metadata` fields +- `Blob`: Raw data loading (files, binary data) - used by document loaders +- `Document`: Text content for retrieval (RAG, vector stores, semantic search) + +!!! note "Not for LLM chat messages" + + These classes are for data processing pipelines, not LLM I/O. For multimodal + content in chat messages (images, audio in conversations), see + `langchain.messages` content blocks instead. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseMedia`](#langchain_core-documents-base-BaseMedia) | Base class for content used in retrieval and data processing workflows. | +| [`Blob`](#langchain_core-documents-base-Blob) | Raw data abstraction for document loading and file processing. | +| [`Document`](#langchain_core-documents-base-Document) | Class for storing a piece of text and associated metadata. | + +### Data + +[`PathLike`](#langchain_core-documents-base-PathLike) + +### API + + + + + +```python +class langchain_core.documents.base.BaseMedia() +``` + + + + + + +**Bases:** [Serializable](/langchain-core/langchain_core/load/serializable#langchain_core-load-serializable-Serializable) + +Base class for content used in retrieval and data processing workflows. + +Provides common fields for content that needs to be stored, indexed, or searched. + +!!! note + + For multimodal content in **chat messages** (images, audio sent to/from LLMs), + use `langchain.messages` content blocks instead. + + + +An optional identifier for the document. + +Ideally this should be unique across the document collection and formatted +as a UUID, but this will not be enforced. + + + +Arbitrary metadata associated with the content. + + + + + + + +```python +class langchain_core.documents.base.Blob() +``` + + + + + + +**Bases:** [BaseMedia](#langchain_core-documents-base-BaseMedia) + +Raw data abstraction for document loading and file processing. + +Represents raw bytes or text, either in-memory or by file reference. Used +primarily by document loaders to decouple data loading from parsing. + +Inspired by [Mozilla's `Blob`](https://developer.mozilla.org/en-US/docs/Web/API/Blob) + +???+ example "Initialize a blob from in-memory data" + + ```python + from langchain_core.documents import Blob + + blob = Blob.from_data("Hello, world!") + + # Read the blob as a string + print(blob.as_string()) + + # Read the blob as bytes + print(blob.as_bytes()) + + # Read the blob as a byte stream + with blob.as_bytes_io() as f: + print(f.read()) + ``` + +??? example "Load from memory and specify MIME type and metadata" + + ```python + from langchain_core.documents import Blob + + blob = Blob.from_data( + data="Hello, world!", + mime_type="text/plain", + metadata={"source": "https://example.com"}, + ) + ``` + +??? example "Load the blob from a file" + + ```python + from langchain_core.documents import Blob + + blob = Blob.from_path("path/to/file.txt") + + # Read the blob as a string + print(blob.as_string()) + + # Read the blob as bytes + print(blob.as_bytes()) + + # Read the blob as a byte stream + with blob.as_bytes_io() as f: + print(f.read()) + ``` + + + +Raw data associated with the `Blob`. + + + +Encoding to use if decoding the bytes into a string. + +Uses `utf-8` as default encoding if decoding to string. + + + +MIME type, not to be confused with a file extension. + + + + + + +Location where the original content was found. + + + +The source location of the blob as string if known otherwise none. + +If a path is associated with the `Blob`, it will default to the path location. + +Unless explicitly set via a metadata field called `'source'`, in which +case that value will be used instead. + + + + + +```python +langchain_core.documents.base.Blob.__repr__() -> str +``` + + + + + + +Return the blob representation. + + + + + + + +```python +langchain_core.documents.base.Blob.as_bytes() -> bytes +``` + + + + + + +Read data as bytes. + +**Returns:** `bytes` + +The data as bytes. + +**Raises:** + +- `ValueError`: If the blob cannot be represented as bytes. + + + + + + + +```python +langchain_core.documents.base.Blob.as_bytes_io() -> collections.abc.Generator[io.BytesIO | io.BufferedReader, None, None] +``` + + + + + + +Read data as a byte stream. + +**Raises:** + +- `NotImplementedError`: If the blob cannot be represented as a byte stream. + + + + + + + +```python +langchain_core.documents.base.Blob.as_string() -> str +``` + + + + + + +Read data as a string. + +**Returns:** `str` + +The data as a string. + +**Raises:** + +- `ValueError`: If the blob cannot be represented as a string. + + + + + + + +```python +langchain_core.documents.base.Blob.check_blob_is_valid( + values: dict[str, typing.Any] +) -> typing.Any +``` + + + + + + +classmethod + +Verify that either data or path is provided. + + + + + + + +```python +langchain_core.documents.base.Blob.from_data( + data: str | bytes, + encoding: str = 'utf-8', + mime_type: str | None = None, + path: str | None = None, + metadata: dict | None = None +) -> langchain_core.documents.base.Blob +``` + + + + + + +classmethod + +Initialize the `Blob` from in-memory data. + +**Parameters:** + + +The in-memory data associated with the `Blob` + + + +Encoding to use if decoding the bytes into a string + + + +If provided, will be set as the MIME type of the data + + + +If provided, will be set as the source from which the data came + + + +Metadata to associate with the `Blob` + + +**Returns:** `Blob` + +`Blob` instance + + + + + + + +```python +langchain_core.documents.base.Blob.from_path( + path: langchain_core.documents.base.PathLike, + encoding: str = 'utf-8', + mime_type: str | None = None, + guess_type: bool = True, + metadata: dict | None = None +) -> langchain_core.documents.base.Blob +``` + + + + + + +classmethod + +Load the blob from a path like object. + +**Parameters:** + + +Path-like object to file to be read + + + +Encoding to use if decoding the bytes into a string + + + +If provided, will be set as the MIME type of the data + + + +If `True`, the MIME type will be guessed from the file +extension, if a MIME type was not provided + + + +Metadata to associate with the `Blob` + + +**Returns:** `Blob` + +`Blob` instance + + + + + + + + + +```python +class langchain_core.documents.base.Document( + page_content: str, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [BaseMedia](#langchain_core-documents-base-BaseMedia) + +Class for storing a piece of text and associated metadata. + +!!! note + + `Document` is for **retrieval workflows**, not chat I/O. For sending text + to an LLM in a conversation, use message types from `langchain.messages`. + + + +String text. + + + + + + + + +```python +langchain_core.documents.base.Document.__str__() -> str +``` + + + + + + +Override `__str__` to restrict it to page_content and metadata. + +**Returns:** `str` + +A string representation of the `Document`. + + + + + + + +```python +langchain_core.documents.base.Document.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "schema", "document"]` + + + + + + + +```python +langchain_core.documents.base.Document.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `True` as this class is serializable. + + + + + + + + + +```python +langchain_core.documents.base.PathLike = str | PurePath +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/documents/compressor.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/documents/compressor.mdx new file mode 100644 index 0000000..ac0d8ce --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/documents/compressor.mdx @@ -0,0 +1,133 @@ +--- +layout: overview +slug: langchain-core/langchain_core/documents/compressor +title: langchain_core.documents.compressor +--- + +Document compressor. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseDocumentCompressor`](#langchain_core-documents-compressor-BaseDocumentCompressor) | Base class for document compressors. | + +### API + + + + + +```python +class langchain_core.documents.compressor.BaseDocumentCompressor() +``` + + + + + + +Abstract + +**Bases:** `BaseModel` + +Base class for document compressors. + +This abstraction is primarily used for post-processing of retrieved documents. + +`Document` objects matching a given query are first retrieved. + +Then the list of documents can be further processed. + +For example, one could re-rank the retrieved documents using an LLM. + +!!! note + Users should favor using a `RunnableLambda` instead of sub-classing from this + interface. + + + + + + +```python +langchain_core.documents.compressor.BaseDocumentCompressor.acompress_documents( + documents: collections.abc.Sequence[langchain_core.documents.Document], + query: str, + callbacks: langchain_core.callbacks.Callbacks | None = None +) -> collections.abc.Sequence[langchain_core.documents.Document] +``` + + + + + + +async + +Async compress retrieved documents given the query context. + +**Parameters:** + + +The retrieved `Document` objects. + + + +The query context. + + + +Optional `Callbacks` to run during compression. + + +**Returns:** `Sequence[Document]` + +The compressed documents. + + + + + + + +```python +langchain_core.documents.compressor.BaseDocumentCompressor.compress_documents( + documents: collections.abc.Sequence[langchain_core.documents.Document], + query: str, + callbacks: langchain_core.callbacks.Callbacks | None = None +) -> collections.abc.Sequence[langchain_core.documents.Document] +``` + + + + + + +abstract + +Compress retrieved documents given the query context. + +**Parameters:** + + +The retrieved `Document` objects. + + + +The query context. + + + +Optional `Callbacks` to run during compression. + + +**Returns:** `Sequence[Document]` + +The compressed documents. + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/documents/transformers.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/documents/transformers.mdx new file mode 100644 index 0000000..c970bcd --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/documents/transformers.mdx @@ -0,0 +1,104 @@ +--- +layout: overview +slug: langchain-core/langchain_core/documents/transformers +title: langchain_core.documents.transformers +--- + +Document transformers. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseDocumentTransformer`](#langchain_core-documents-transformers-BaseDocumentTransformer) | Abstract base class for document transformation. | + +### API + + + + + +```python +class langchain_core.documents.transformers.BaseDocumentTransformer() +``` + + + + + + +Abstract + +Abstract base class for document transformation. + +A document transformation takes a sequence of `Document` objects and returns a +sequence of transformed `Document` objects. + + + + + + +```python +langchain_core.documents.transformers.BaseDocumentTransformer.atransform_documents( + documents: collections.abc.Sequence[langchain_core.documents.Document], + kwargs: typing.Any = {} +) -> collections.abc.Sequence[langchain_core.documents.Document] +``` + + + + + + +async + +Asynchronously transform a list of documents. + +**Parameters:** + + +A sequence of `Document` objects to be transformed. + + +**Returns:** `Sequence[Document]` + +A sequence of transformed `Document` objects. + + + + + + + +```python +langchain_core.documents.transformers.BaseDocumentTransformer.transform_documents( + documents: collections.abc.Sequence[langchain_core.documents.Document], + kwargs: typing.Any = {} +) -> collections.abc.Sequence[langchain_core.documents.Document] +``` + + + + + + +abstract + +Transform a list of documents. + +**Parameters:** + + +A sequence of `Document` objects to be transformed. + + +**Returns:** `Sequence[Document]` + +A sequence of transformed `Document` objects. + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/embeddings.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/embeddings.mdx new file mode 100644 index 0000000..1f52d16 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/embeddings.mdx @@ -0,0 +1,87 @@ +--- +layout: overview +slug: langchain-core/langchain_core/embeddings +title: langchain_core.embeddings +--- + +Embeddings. + +## Submodules + +- **[`langchain_core.embeddings.embeddings`](/langchain-core/langchain_core/embeddings/embeddings)** +- **[`langchain_core.embeddings.fake`](/langchain-core/langchain_core/embeddings/fake)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`__dir__`](#langchain_core-embeddings-__dir__) | - | +| [`__getattr__`](#langchain_core-embeddings-__getattr__) | - | + +### Data + +[`__all__`](#langchain_core-embeddings-__all__) + +[`_dynamic_imports`](#langchain_core-embeddings-_dynamic_imports) + +### API + + + + + +```python +langchain_core.embeddings.__dir__() -> list[str] +``` + + + + + + + + + + + + + +```python +langchain_core.embeddings.__getattr__( + attr_name: str +) -> object +``` + + + + + + + + + + + + + +```python +langchain_core.embeddings.__all__ = ('DeterministicFakeEmbedding', 'Embeddings', 'FakeEmbeddings') +``` + + + + + + + + + +```python +langchain_core.embeddings._dynamic_imports = {'Embeddings': 'embeddings', 'DeterministicFakeEmbedding': 'fake', 'FakeEmbeddin... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/embeddings/embeddings.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/embeddings/embeddings.mdx new file mode 100644 index 0000000..17715fa --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/embeddings/embeddings.mdx @@ -0,0 +1,185 @@ +--- +layout: overview +slug: langchain-core/langchain_core/embeddings/embeddings +title: langchain_core.embeddings.embeddings +--- + +**Embeddings** interface. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`Embeddings`](#langchain_core-embeddings-embeddings-Embeddings) | Interface for embedding models. | + +### API + + + + + +```python +class langchain_core.embeddings.embeddings.Embeddings() +``` + + + + + + +Abstract + +Interface for embedding models. + +This is an interface meant for implementing text embedding models. + +Text embedding models are used to map text to a vector (a point in n-dimensional +space). + +Texts that are similar will usually be mapped to points that are close to each +other in this space. The exact details of what's considered "similar" and how +"distance" is measured in this space are dependent on the specific embedding model. + +This abstraction contains a method for embedding a list of documents and a method +for embedding a query text. The embedding of a query text is expected to be a single +vector, while the embedding of a list of documents is expected to be a list of +vectors. + +Usually the query embedding is identical to the document embedding, but the +abstraction allows treating them independently. + +In addition to the synchronous methods, this interface also provides asynchronous +versions of the methods. + +By default, the asynchronous methods are implemented using the synchronous methods; +however, implementations may choose to override the asynchronous methods with +an async native implementation for performance reasons. + + + + + + +```python +langchain_core.embeddings.embeddings.Embeddings.aembed_documents( + texts: list[str] +) -> list[list[float]] +``` + + + + + + +async + +Asynchronous Embed search docs. + +**Parameters:** + + +List of text to embed. + + +**Returns:** `list[list[float]]` + +List of embeddings. + + + + + + + +```python +langchain_core.embeddings.embeddings.Embeddings.aembed_query( + text: str +) -> list[float] +``` + + + + + + +async + +Asynchronous Embed query text. + +**Parameters:** + + +Text to embed. + + +**Returns:** `list[float]` + +Embedding. + + + + + + + +```python +langchain_core.embeddings.embeddings.Embeddings.embed_documents( + texts: list[str] +) -> list[list[float]] +``` + + + + + + +abstract + +Embed search docs. + +**Parameters:** + + +List of text to embed. + + +**Returns:** `list[list[float]]` + +List of embeddings. + + + + + + + +```python +langchain_core.embeddings.embeddings.Embeddings.embed_query( + text: str +) -> list[float] +``` + + + + + + +abstract + +Embed query text. + +**Parameters:** + + +Text to embed. + + +**Returns:** `list[float]` + +Embedding. + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/embeddings/fake.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/embeddings/fake.mdx new file mode 100644 index 0000000..d3e0b7d --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/embeddings/fake.mdx @@ -0,0 +1,196 @@ +--- +layout: overview +slug: langchain-core/langchain_core/embeddings/fake +title: langchain_core.embeddings.fake +--- + +Module contains a few fake embedding models for testing purposes. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`DeterministicFakeEmbedding`](#langchain_core-embeddings-fake-DeterministicFakeEmbedding) | Deterministic fake embedding model for unit testing purposes. | +| [`FakeEmbeddings`](#langchain_core-embeddings-fake-FakeEmbeddings) | Fake embedding model for unit testing purposes. | + +### API + + + + + +```python +class langchain_core.embeddings.fake.DeterministicFakeEmbedding() +``` + + + + + + +**Bases:** [Embeddings](/langchain-core/langchain_core/embeddings/embeddings#langchain_core-embeddings-embeddings-Embeddings), `BaseModel` + +Deterministic fake embedding model for unit testing purposes. + +This embedding model creates embeddings by sampling from a normal distribution +with a seed based on the hash of the text. + +!!! danger "Toy model" + Do not use this outside of testing, as it is not a real embedding model. + + + +The size of the embedding vector. + + + + + +```python +langchain_core.embeddings.fake.DeterministicFakeEmbedding._get_embedding( + seed: int +) -> list[float] +``` + + + + + + + + + + + + +```python +langchain_core.embeddings.fake.DeterministicFakeEmbedding._get_seed( + text: str +) -> int +``` + + + + + + +staticmethod + +Get a seed for the random generator, using the hash of the text. + + + + + + + +```python +langchain_core.embeddings.fake.DeterministicFakeEmbedding.embed_documents( + texts: list[str] +) -> list[list[float]] +``` + + + + + + + + + + + + +```python +langchain_core.embeddings.fake.DeterministicFakeEmbedding.embed_query( + text: str +) -> list[float] +``` + + + + + + + + + + + + + + +```python +class langchain_core.embeddings.fake.FakeEmbeddings() +``` + + + + + + +**Bases:** [Embeddings](/langchain-core/langchain_core/embeddings/embeddings#langchain_core-embeddings-embeddings-Embeddings), `BaseModel` + +Fake embedding model for unit testing purposes. + +This embedding model creates embeddings by sampling from a normal distribution. + +!!! danger "Toy model" + Do not use this outside of testing, as it is not a real embedding model. + + + +The size of the embedding vector. + + + + + +```python +langchain_core.embeddings.fake.FakeEmbeddings._get_embedding() -> list[float] +``` + + + + + + + + + + + + +```python +langchain_core.embeddings.fake.FakeEmbeddings.embed_documents( + texts: list[str] +) -> list[list[float]] +``` + + + + + + + + + + + + +```python +langchain_core.embeddings.fake.FakeEmbeddings.embed_query( + text: str +) -> list[float] +``` + + + + + + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/env.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/env.mdx new file mode 100644 index 0000000..6774aeb --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/env.mdx @@ -0,0 +1,39 @@ +--- +layout: overview +slug: langchain-core/langchain_core/env +title: langchain_core.env +--- + +Utilities for getting information about the runtime environment. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`get_runtime_environment`](#langchain_core-env-get_runtime_environment) | Get information about the LangChain runtime environment. | + +### API + + + + + +```python +langchain_core.env.get_runtime_environment() -> dict +``` + + + + + + +Get information about the LangChain runtime environment. + +**Returns:** `dict` + +A dictionary with information about the runtime environment. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/example_selectors.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/example_selectors.mdx new file mode 100644 index 0000000..7f24749 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/example_selectors.mdx @@ -0,0 +1,91 @@ +--- +layout: overview +slug: langchain-core/langchain_core/example_selectors +title: langchain_core.example_selectors +--- + +Example selectors. + +**Example selector** implements logic for selecting examples to include them in prompts. +This allows us to select examples that are most relevant to the input. + +## Submodules + +- **[`langchain_core.example_selectors.base`](/langchain-core/langchain_core/example_selectors/base)** +- **[`langchain_core.example_selectors.length_based`](/langchain-core/langchain_core/example_selectors/length_based)** +- **[`langchain_core.example_selectors.semantic_similarity`](/langchain-core/langchain_core/example_selectors/semantic_similarity)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`__dir__`](#langchain_core-example_selectors-__dir__) | - | +| [`__getattr__`](#langchain_core-example_selectors-__getattr__) | - | + +### Data + +[`__all__`](#langchain_core-example_selectors-__all__) + +[`_dynamic_imports`](#langchain_core-example_selectors-_dynamic_imports) + +### API + + + + + +```python +langchain_core.example_selectors.__dir__() -> list[str] +``` + + + + + + + + + + + + + +```python +langchain_core.example_selectors.__getattr__( + attr_name: str +) -> object +``` + + + + + + + + + + + + + +```python +langchain_core.example_selectors.__all__ = ('BaseExampleSelector', 'LengthBasedExampleSelector', 'MaxMarginalRelevanceExamp... +``` + + + + + + + + + +```python +langchain_core.example_selectors._dynamic_imports = {'BaseExampleSelector': 'base', 'LengthBasedExampleSelector': 'length_based', 'M... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/example_selectors/base.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/example_selectors/base.mdx new file mode 100644 index 0000000..cc70087 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/example_selectors/base.mdx @@ -0,0 +1,165 @@ +--- +layout: overview +slug: langchain-core/langchain_core/example_selectors/base +title: langchain_core.example_selectors.base +--- + +Interface for selecting examples to include in prompts. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseExampleSelector`](#langchain_core-example_selectors-base-BaseExampleSelector) | Interface for selecting examples to include in prompts. | + +### API + + + + + +```python +class langchain_core.example_selectors.base.BaseExampleSelector() +``` + + + + + + +Abstract + +Interface for selecting examples to include in prompts. + + + + + + +```python +langchain_core.example_selectors.base.BaseExampleSelector.aadd_example( + example: dict[str, str] +) -> typing.Any +``` + + + + + + +async + +Async add new example to store. + +**Parameters:** + + +A dictionary with keys as input variables +and values as their values. + + +**Returns:** `Any` + +Any return value. + + + + + + + +```python +langchain_core.example_selectors.base.BaseExampleSelector.add_example( + example: dict[str, str] +) -> typing.Any +``` + + + + + + +abstract + +Add new example to store. + +**Parameters:** + + +A dictionary with keys as input variables +and values as their values. + + +**Returns:** `Any` + +Any return value. + + + + + + + +```python +langchain_core.example_selectors.base.BaseExampleSelector.aselect_examples( + input_variables: dict[str, str] +) -> list[dict] +``` + + + + + + +async + +Async select which examples to use based on the inputs. + +**Parameters:** + + +A dictionary with keys as input variables +and values as their values. + + +**Returns:** `list[dict]` + +A list of examples. + + + + + + + +```python +langchain_core.example_selectors.base.BaseExampleSelector.select_examples( + input_variables: dict[str, str] +) -> list[dict] +``` + + + + + + +abstract + +Select which examples to use based on the inputs. + +**Parameters:** + + +A dictionary with keys as input variables +and values as their values. + + +**Returns:** `list[dict]` + +A list of examples. + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/example_selectors/length_based.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/example_selectors/length_based.mdx new file mode 100644 index 0000000..91565c3 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/example_selectors/length_based.mdx @@ -0,0 +1,213 @@ +--- +layout: overview +slug: langchain-core/langchain_core/example_selectors/length_based +title: langchain_core.example_selectors.length_based +--- + +Select examples based on length. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`LengthBasedExampleSelector`](#langchain_core-example_selectors-length_based-LengthBasedExampleSelector) | Select examples based on length. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_length_based`](#langchain_core-example_selectors-length_based-_get_length_based) | - | + +### API + + + + + +```python +class langchain_core.example_selectors.length_based.LengthBasedExampleSelector() +``` + + + + + + +**Bases:** [BaseExampleSelector](/langchain-core/langchain_core/example_selectors/base#langchain_core-example_selectors-base-BaseExampleSelector), `BaseModel` + +Select examples based on length. + + + +Prompt template used to format the examples. + + + +Length of each example. + + + +A list of the examples that the prompt template expects. + + + +Function to measure prompt length. Defaults to word count. + + + +Max length for the prompt, beyond which examples are cut. + + + + + +```python +langchain_core.example_selectors.length_based.LengthBasedExampleSelector.aadd_example( + example: dict[str, str] +) -> None +``` + + + + + + +async + +Async add new example to list. + +**Parameters:** + + +A dictionary with keys as input variables +and values as their values. + + + + + + + + +```python +langchain_core.example_selectors.length_based.LengthBasedExampleSelector.add_example( + example: dict[str, str] +) -> None +``` + + + + + + +Add new example to list. + +**Parameters:** + + +A dictionary with keys as input variables +and values as their values. + + + + + + + + +```python +langchain_core.example_selectors.length_based.LengthBasedExampleSelector.aselect_examples( + input_variables: dict[str, str] +) -> list[dict] +``` + + + + + + +async + +Async select which examples to use based on the input lengths. + +**Parameters:** + + +A dictionary with keys as input variables +and values as their values. + + +**Returns:** `list[dict]` + +A list of examples to include in the prompt. + + + + + + + +```python +langchain_core.example_selectors.length_based.LengthBasedExampleSelector.post_init() -> typing_extensions.Self +``` + + + + + + +Validate that the examples are formatted correctly. + + + + + + + +```python +langchain_core.example_selectors.length_based.LengthBasedExampleSelector.select_examples( + input_variables: dict[str, str] +) -> list[dict] +``` + + + + + + +Select which examples to use based on the input lengths. + +**Parameters:** + + +A dictionary with keys as input variables +and values as their values. + + +**Returns:** `list[dict]` + +A list of examples to include in the prompt. + + + + + + + + + +```python +langchain_core.example_selectors.length_based._get_length_based( + text: str +) -> int +``` + + + + + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/example_selectors/semantic_similarity.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/example_selectors/semantic_similarity.mdx new file mode 100644 index 0000000..f9f0a96 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/example_selectors/semantic_similarity.mdx @@ -0,0 +1,656 @@ +--- +layout: overview +slug: langchain-core/langchain_core/example_selectors/semantic_similarity +title: langchain_core.example_selectors.semantic_similarity +--- + +Example selector that selects examples based on SemanticSimilarity. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MaxMarginalRelevanceExampleSelector`](#langchain_core-example_selectors-semantic_similarity-MaxMarginalRelevanceExampleSelector) | Select examples based on Max Marginal Relevance. | +| [`SemanticSimilarityExampleSelector`](#langchain_core-example_selectors-semantic_similarity-SemanticSimilarityExampleSelector) | Select examples based on semantic similarity. | +| [`_VectorStoreExampleSelector`](#langchain_core-example_selectors-semantic_similarity-_VectorStoreExampleSelector) | Example selector that selects examples based on SemanticSimilarity. | + +### Functions + +| Name | Description | +|------|-------------| +| [`sorted_values`](#langchain_core-example_selectors-semantic_similarity-sorted_values) | Return a list of values in dict sorted by key. | + +### API + + + + + +```python +class langchain_core.example_selectors.semantic_similarity.MaxMarginalRelevanceExampleSelector() +``` + + + + + + +**Bases:** [_VectorStoreExampleSelector](#langchain_core-example_selectors-semantic_similarity-_VectorStoreExampleSelector) + +Select examples based on Max Marginal Relevance. + +This was shown to improve performance in this paper: +https://arxiv.org/pdf/2211.13892.pdf + + + +Number of examples to fetch to rerank. + + + + + +```python +langchain_core.example_selectors.semantic_similarity.MaxMarginalRelevanceExampleSelector.afrom_examples( + examples: list[dict], + embeddings: langchain_core.embeddings.Embeddings, + vectorstore_cls: type[langchain_core.vectorstores.VectorStore], + k: int = 4, + input_keys: list[str] | None = None, + fetch_k: int = 20, + example_keys: list[str] | None = None, + vectorstore_kwargs: dict | None = None, + vectorstore_cls_kwargs: typing.Any = {} +) -> langchain_core.example_selectors.semantic_similarity.MaxMarginalRelevanceExampleSelector +``` + + + + + + +async classmethod + +Create k-shot example selector using example list and embeddings. + +Reshuffles examples dynamically based on Max Marginal Relevance. + +**Parameters:** + + +List of examples to use in the prompt. + + + +An initialized embedding API interface, e.g. OpenAIEmbeddings(). + + + +A vector store DB interface class, e.g. FAISS. + + + +Number of examples to select. + + + +Number of `Document` objects to fetch to pass to MMR algorithm. + + + +If provided, the search is based on the input variables +instead of all variables. + + + +If provided, keys to filter examples to. + + + +Extra arguments passed to similarity_search function +of the `VectorStore`. + + + +optional kwargs containing url for vector store + + +**Returns:** `MaxMarginalRelevanceExampleSelector` + +The ExampleSelector instantiated, backed by a vector store. + + + + + + + +```python +langchain_core.example_selectors.semantic_similarity.MaxMarginalRelevanceExampleSelector.aselect_examples( + input_variables: dict[str, str] +) -> list[dict] +``` + + + + + + +async + +Asynchronously select examples based on Max Marginal Relevance. + +**Parameters:** + + +The input variables to use for search. + + +**Returns:** `list[dict]` + +The selected examples. + + + + + + + +```python +langchain_core.example_selectors.semantic_similarity.MaxMarginalRelevanceExampleSelector.from_examples( + examples: list[dict], + embeddings: langchain_core.embeddings.Embeddings, + vectorstore_cls: type[langchain_core.vectorstores.VectorStore], + k: int = 4, + input_keys: list[str] | None = None, + fetch_k: int = 20, + example_keys: list[str] | None = None, + vectorstore_kwargs: dict | None = None, + vectorstore_cls_kwargs: typing.Any = {} +) -> langchain_core.example_selectors.semantic_similarity.MaxMarginalRelevanceExampleSelector +``` + + + + + + +classmethod + +Create k-shot example selector using example list and embeddings. + +Reshuffles examples dynamically based on Max Marginal Relevance. + +**Parameters:** + + +List of examples to use in the prompt. + + + +An initialized embedding API interface, e.g. OpenAIEmbeddings(). + + + +A vector store DB interface class, e.g. FAISS. + + + +Number of examples to select. + + + +Number of `Document` objects to fetch to pass to MMR algorithm. + + + +If provided, the search is based on the input variables +instead of all variables. + + + +If provided, keys to filter examples to. + + + +Extra arguments passed to similarity_search function +of the `VectorStore`. + + + +optional kwargs containing url for vector store + + +**Returns:** `MaxMarginalRelevanceExampleSelector` + +The ExampleSelector instantiated, backed by a vector store. + + + + + + + +```python +langchain_core.example_selectors.semantic_similarity.MaxMarginalRelevanceExampleSelector.select_examples( + input_variables: dict[str, str] +) -> list[dict] +``` + + + + + + +Select examples based on Max Marginal Relevance. + +**Parameters:** + + +The input variables to use for search. + + +**Returns:** `list[dict]` + +The selected examples. + + + + + + + + + +```python +class langchain_core.example_selectors.semantic_similarity.SemanticSimilarityExampleSelector() +``` + + + + + + +**Bases:** [_VectorStoreExampleSelector](#langchain_core-example_selectors-semantic_similarity-_VectorStoreExampleSelector) + +Select examples based on semantic similarity. + + + + + + +```python +langchain_core.example_selectors.semantic_similarity.SemanticSimilarityExampleSelector.afrom_examples( + examples: list[dict], + embeddings: langchain_core.embeddings.Embeddings, + vectorstore_cls: type[langchain_core.vectorstores.VectorStore], + k: int = 4, + input_keys: list[str] | None = None, + example_keys: list[str] | None = None, + vectorstore_kwargs: dict | None = None, + vectorstore_cls_kwargs: typing.Any = {} +) -> langchain_core.example_selectors.semantic_similarity.SemanticSimilarityExampleSelector +``` + + + + + + +async classmethod + +Async create k-shot example selector using example list and embeddings. + +Reshuffles examples dynamically based on query similarity. + +**Parameters:** + + +List of examples to use in the prompt. + + + +An initialized embedding API interface, e.g. OpenAIEmbeddings(). + + + +A vector store DB interface class, e.g. FAISS. + + + +Number of examples to select. + + + +If provided, the search is based on the input variables +instead of all variables. + + + +If provided, keys to filter examples to. + + + +Extra arguments passed to similarity_search function +of the `VectorStore`. + + + +optional kwargs containing url for vector store + + +**Returns:** `SemanticSimilarityExampleSelector` + +The ExampleSelector instantiated, backed by a vector store. + + + + + + + +```python +langchain_core.example_selectors.semantic_similarity.SemanticSimilarityExampleSelector.aselect_examples( + input_variables: dict[str, str] +) -> list[dict] +``` + + + + + + +async + +Asynchronously select examples based on semantic similarity. + +**Parameters:** + + +The input variables to use for search. + + +**Returns:** `list[dict]` + +The selected examples. + + + + + + + +```python +langchain_core.example_selectors.semantic_similarity.SemanticSimilarityExampleSelector.from_examples( + examples: list[dict], + embeddings: langchain_core.embeddings.Embeddings, + vectorstore_cls: type[langchain_core.vectorstores.VectorStore], + k: int = 4, + input_keys: list[str] | None = None, + example_keys: list[str] | None = None, + vectorstore_kwargs: dict | None = None, + vectorstore_cls_kwargs: typing.Any = {} +) -> langchain_core.example_selectors.semantic_similarity.SemanticSimilarityExampleSelector +``` + + + + + + +classmethod + +Create k-shot example selector using example list and embeddings. + +Reshuffles examples dynamically based on query similarity. + +**Parameters:** + + +List of examples to use in the prompt. + + + +An initialized embedding API interface, e.g. OpenAIEmbeddings(). + + + +A vector store DB interface class, e.g. FAISS. + + + +Number of examples to select. + + + +If provided, the search is based on the input variables +instead of all variables. + + + +If provided, keys to filter examples to. + + + +Extra arguments passed to similarity_search function +of the `VectorStore`. + + + +optional kwargs containing url for vector store + + +**Returns:** `SemanticSimilarityExampleSelector` + +The ExampleSelector instantiated, backed by a vector store. + + + + + + + +```python +langchain_core.example_selectors.semantic_similarity.SemanticSimilarityExampleSelector.select_examples( + input_variables: dict[str, str] +) -> list[dict] +``` + + + + + + +Select examples based on semantic similarity. + +**Parameters:** + + +The input variables to use for search. + + +**Returns:** `list[dict]` + +The selected examples. + + + + + + + + + +```python +class langchain_core.example_selectors.semantic_similarity._VectorStoreExampleSelector() +``` + + + + + + +Abstract + +**Bases:** [BaseExampleSelector](/langchain-core/langchain_core/example_selectors/base#langchain_core-example_selectors-base-BaseExampleSelector), `BaseModel` + +Example selector that selects examples based on SemanticSimilarity. + + + +Optional keys to filter examples to. + + + +Optional keys to filter input to. If provided, the search is based on +the input variables instead of all variables. + + + +Number of examples to select. + + + + + + +VectorStore that contains information about examples. + + + +Extra arguments passed to similarity_search function of the `VectorStore`. + + + + + +```python +langchain_core.example_selectors.semantic_similarity._VectorStoreExampleSelector._documents_to_examples( + documents: list[langchain_core.documents.Document] +) -> list[dict] +``` + + + + + + + + + + + + +```python +langchain_core.example_selectors.semantic_similarity._VectorStoreExampleSelector._example_to_text( + example: dict[str, str], + input_keys: list[str] | None +) -> str +``` + + + + + + +staticmethod + + + + + + + +```python +langchain_core.example_selectors.semantic_similarity._VectorStoreExampleSelector.aadd_example( + example: dict[str, str] +) -> str +``` + + + + + + +async + +Async add new example to vectorstore. + +**Parameters:** + + +A dictionary with keys as input variables +and values as their values. + + +**Returns:** `str` + +The ID of the added example. + + + + + + + +```python +langchain_core.example_selectors.semantic_similarity._VectorStoreExampleSelector.add_example( + example: dict[str, str] +) -> str +``` + + + + + + +Add a new example to vectorstore. + +**Parameters:** + + +A dictionary with keys as input variables +and values as their values. + + +**Returns:** `str` + +The ID of the added example. + + + + + + + + + +```python +langchain_core.example_selectors.semantic_similarity.sorted_values( + values: dict[str, str] +) -> list[typing.Any] +``` + + + + + + +Return a list of values in dict sorted by key. + +**Parameters:** + + +A dictionary with keys as input variables +and values as their values. + + +**Returns:** `list[Any]` + +A list of values in dict sorted by key. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/exceptions.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/exceptions.mdx new file mode 100644 index 0000000..30bdf64 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/exceptions.mdx @@ -0,0 +1,204 @@ +--- +layout: overview +slug: langchain-core/langchain_core/exceptions +title: langchain_core.exceptions +--- + +Custom **exceptions** for LangChain. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ContextOverflowError`](#langchain_core-exceptions-ContextOverflowError) | Exception raised when input exceeds the model's context limit. | +| [`ErrorCode`](#langchain_core-exceptions-ErrorCode) | Error codes. | +| [`LangChainException`](#langchain_core-exceptions-LangChainException) | General LangChain exception. | +| [`OutputParserException`](#langchain_core-exceptions-OutputParserException) | Exception that output parsers should raise to signify a parsing error. | +| [`TracerException`](#langchain_core-exceptions-TracerException) | Base class for exceptions in tracers module. | + +### Functions + +| Name | Description | +|------|-------------| +| [`create_message`](#langchain_core-exceptions-create_message) | Create a message with a link to the LangChain troubleshooting guide. | + +### API + + + + + +```python +class langchain_core.exceptions.ContextOverflowError() +``` + + + + + + +Exception + +**Bases:** [LangChainException](#langchain_core-exceptions-LangChainException) + +Exception raised when input exceeds the model's context limit. + +This exception is raised by chat models when the input tokens exceed +the maximum context window supported by the model. + + + + + + + + +```python +class langchain_core.exceptions.ErrorCode +``` + + + + + + +**Bases:** `enum.Enum` + +Error codes. + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class langchain_core.exceptions.LangChainException() +``` + + + + + + +Exception + +**Bases:** `Exception` + +General LangChain exception. + + + + + + + + +```python +class langchain_core.exceptions.OutputParserException( + error: typing.Any, + observation: str | None = None, + llm_output: str | None = None, + send_to_llm: bool = False +) +``` + + + + + + +Exception + +**Bases:** `ValueError`, [LangChainException](#langchain_core-exceptions-LangChainException) + +Exception that output parsers should raise to signify a parsing error. + +This exists to differentiate parsing errors from other code or execution errors +that also may arise inside the output parser. + +`OutputParserException` will be available to catch and handle in ways to fix the +parsing error, while other errors will be raised. + + + + + + + + +```python +class langchain_core.exceptions.TracerException() +``` + + + + + + +Exception + +**Bases:** [LangChainException](#langchain_core-exceptions-LangChainException) + +Base class for exceptions in tracers module. + + + + + + + + +```python +langchain_core.exceptions.create_message( + message: str, + error_code: langchain_core.exceptions.ErrorCode +) -> str +``` + + + + + + +Create a message with a link to the LangChain troubleshooting guide. + +**Parameters:** + + +The message to display. + + + +The error code to display. + + +**Returns:** `str` + +The full message with the troubleshooting link. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/globals.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/globals.mdx new file mode 100644 index 0000000..c3aa3c9 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/globals.mdx @@ -0,0 +1,210 @@ +--- +layout: overview +slug: langchain-core/langchain_core/globals +title: langchain_core.globals +--- + +Global values and configuration that apply to all of LangChain. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`get_debug`](#langchain_core-globals-get_debug) | Get the value of the `debug` global setting. | +| [`get_llm_cache`](#langchain_core-globals-get_llm_cache) | Get the value of the `llm_cache` global setting. | +| [`get_verbose`](#langchain_core-globals-get_verbose) | Get the value of the `verbose` global setting. | +| [`set_debug`](#langchain_core-globals-set_debug) | Set a new value for the `debug` global setting. | +| [`set_llm_cache`](#langchain_core-globals-set_llm_cache) | Set a new LLM cache, overwriting the previous value, if any. | +| [`set_verbose`](#langchain_core-globals-set_verbose) | Set a new value for the `verbose` global setting. | + +### Data + +[`_debug`](#langchain_core-globals-_debug) + +[`_llm_cache`](#langchain_core-globals-_llm_cache) + +[`_verbose`](#langchain_core-globals-_verbose) + +### API + + + + + +```python +langchain_core.globals.get_debug() -> bool +``` + + + + + + +Get the value of the `debug` global setting. + +**Returns:** `bool` + +The value of the `debug` global setting. + + + + + + + + +```python +langchain_core.globals.get_llm_cache() -> typing.Optional[langchain_core.caches.BaseCache] +``` + + + + + + +Get the value of the `llm_cache` global setting. + +**Returns:** `Optional[BaseCache]` + +The value of the `llm_cache` global setting. + + + + + + + + +```python +langchain_core.globals.get_verbose() -> bool +``` + + + + + + +Get the value of the `verbose` global setting. + +**Returns:** `bool` + +The value of the `verbose` global setting. + + + + + + + + +```python +langchain_core.globals.set_debug( + value: bool +) -> None +``` + + + + + + +Set a new value for the `debug` global setting. + +**Parameters:** + + +The new value for the `debug` global setting. + + + + + + + + + +```python +langchain_core.globals.set_llm_cache( + value: typing.Optional[langchain_core.caches.BaseCache] +) -> None +``` + + + + + + +Set a new LLM cache, overwriting the previous value, if any. + +**Parameters:** + + +The new LLM cache to use. If `None`, the LLM cache is disabled. + + + + + + + + + +```python +langchain_core.globals.set_verbose( + value: bool +) -> None +``` + + + + + + +Set a new value for the `verbose` global setting. + +**Parameters:** + + +The new value for the `verbose` global setting. + + + + + + + + + +```python +langchain_core.globals._debug: bool = False +``` + + + + + + + + + +```python +langchain_core.globals._llm_cache: Optional[BaseCache] = None +``` + + + + + + + + + +```python +langchain_core.globals._verbose: bool = False +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/indexing.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/indexing.mdx new file mode 100644 index 0000000..013d01d --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/indexing.mdx @@ -0,0 +1,92 @@ +--- +layout: overview +slug: langchain-core/langchain_core/indexing +title: langchain_core.indexing +--- + +Code to help indexing data into a vectorstore. + +This package contains helper logic to help deal with indexing data into +a `VectorStore` while avoiding duplicated content and over-writing content +if it's unchanged. + +## Submodules + +- **[`langchain_core.indexing.api`](/langchain-core/langchain_core/indexing/api)** +- **[`langchain_core.indexing.base`](/langchain-core/langchain_core/indexing/base)** +- **[`langchain_core.indexing.in_memory`](/langchain-core/langchain_core/indexing/in_memory)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`__dir__`](#langchain_core-indexing-__dir__) | - | +| [`__getattr__`](#langchain_core-indexing-__getattr__) | - | + +### Data + +[`__all__`](#langchain_core-indexing-__all__) + +[`_dynamic_imports`](#langchain_core-indexing-_dynamic_imports) + +### API + + + + + +```python +langchain_core.indexing.__dir__() -> list[str] +``` + + + + + + + + + + + + + +```python +langchain_core.indexing.__getattr__( + attr_name: str +) -> object +``` + + + + + + + + + + + + + +```python +langchain_core.indexing.__all__ = ('DeleteResponse', 'DocumentIndex', 'InMemoryRecordManager', 'IndexingResult', '... +``` + + + + + + + + + +```python +langchain_core.indexing._dynamic_imports = {'aindex': 'api', 'index': 'api', 'IndexingResult': 'api', 'DeleteResponse': 'ba... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/indexing/api.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/indexing/api.mdx new file mode 100644 index 0000000..801fb3a --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/indexing/api.mdx @@ -0,0 +1,789 @@ +--- +layout: overview +slug: langchain-core/langchain_core/indexing/api +title: langchain_core.indexing.api +--- + +Module contains logic for indexing documents into vector stores. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`IndexingException`](#langchain_core-indexing-api-IndexingException) | Raised when an indexing operation fails. | +| [`IndexingResult`](#langchain_core-indexing-api-IndexingResult) | Return a detailed a breakdown of the result of the indexing operation. | +| [`_HashedDocument`](#langchain_core-indexing-api-_HashedDocument) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_abatch`](#langchain_core-indexing-api-_abatch) | Utility batching function. | +| [`_adelete`](#langchain_core-indexing-api-_adelete) | - | +| [`_batch`](#langchain_core-indexing-api-_batch) | Utility batching function. | +| [`_calculate_hash`](#langchain_core-indexing-api-_calculate_hash) | Return a hexadecimal digest of *text* using *algorithm*. | +| [`_deduplicate_in_order`](#langchain_core-indexing-api-_deduplicate_in_order) | Deduplicate a list of hashed documents while preserving order. | +| [`_delete`](#langchain_core-indexing-api-_delete) | Delete documents from a vector store or document index by their IDs. | +| [`_get_document_with_hash`](#langchain_core-indexing-api-_get_document_with_hash) | Calculate a hash of the document, and assign it to the uid. | +| [`_get_source_id_assigner`](#langchain_core-indexing-api-_get_source_id_assigner) | Get the source id from the document. | +| [`_hash_nested_dict`](#langchain_core-indexing-api-_hash_nested_dict) | Hash a nested dictionary to a UUID using the configured algorithm. | +| [`_hash_string`](#langchain_core-indexing-api-_hash_string) | Hash *input_string* to a deterministic UUID using the configured algorithm. | +| [`_hash_string_to_uuid`](#langchain_core-indexing-api-_hash_string_to_uuid) | Hashes a string and returns the corresponding UUID. | +| [`_to_async_iterator`](#langchain_core-indexing-api-_to_async_iterator) | Convert an iterable to an async iterator. | +| [`_warn_about_sha1`](#langchain_core-indexing-api-_warn_about_sha1) | Emit a one-time warning about SHA-1 collision weaknesses. | +| [`aindex`](#langchain_core-indexing-api-aindex) | Async index data from the loader into the vector store. | +| [`index`](#langchain_core-indexing-api-index) | Index data from the loader into the vector store. | + +### Data + +[`NAMESPACE_UUID`](#langchain_core-indexing-api-NAMESPACE_UUID) + +[`T`](#langchain_core-indexing-api-T) + +[`_WARNED_ABOUT_SHA1`](#langchain_core-indexing-api-_WARNED_ABOUT_SHA1) + +### API + + + + + +```python +class langchain_core.indexing.api.IndexingException() +``` + + + + + + +Exception + +**Bases:** [LangChainException](/langchain-core/langchain_core/exceptions#langchain_core-exceptions-LangChainException) + +Raised when an indexing operation fails. + + + + + + + + +```python +class langchain_core.indexing.api.IndexingResult +``` + + + + + + +**Bases:** `typing.TypedDict` + +Return a detailed a breakdown of the result of the indexing operation. + + +Number of added documents. + + + +Number of deleted documents. + + + +Number of skipped documents because they were already up to date. + + + +Number of updated documents because they were not up to date. + + + + + + + + +```python +class langchain_core.indexing.api._HashedDocument( + args: typing.Any = (), + kwargs: typing.Any = {} +) +``` + + + + + + + + + + + + +```python +langchain_core.indexing.api._abatch( + size: int, + iterable: collections.abc.AsyncIterable[langchain_core.indexing.api.T] +) -> collections.abc.AsyncIterator[list[langchain_core.indexing.api.T]] +``` + + + + + + +async + +Utility batching function. + + + + + + + + +```python +langchain_core.indexing.api._adelete( + vector_store: langchain_core.vectorstores.VectorStore | langchain_core.indexing.base.DocumentIndex, + ids: list[str] +) -> None +``` + + + + + + +async + + + + + + + + +```python +langchain_core.indexing.api._batch( + size: int, + iterable: collections.abc.Iterable[langchain_core.indexing.api.T] +) -> collections.abc.Iterator[list[langchain_core.indexing.api.T]] +``` + + + + + + +Utility batching function. + + + + + + + + +```python +langchain_core.indexing.api._calculate_hash( + text: str, + algorithm: typing.Literal['sha1', 'sha256', 'sha512', 'blake2b'] +) -> str +``` + + + + + + +Return a hexadecimal digest of *text* using *algorithm*. + + + + + + + + +```python +langchain_core.indexing.api._deduplicate_in_order( + hashed_documents: collections.abc.Iterable[langchain_core.documents.Document] +) -> collections.abc.Iterator[langchain_core.documents.Document] +``` + + + + + + +Deduplicate a list of hashed documents while preserving order. + + + + + + + + +```python +langchain_core.indexing.api._delete( + vector_store: langchain_core.vectorstores.VectorStore | langchain_core.indexing.base.DocumentIndex, + ids: list[str] +) -> None +``` + + + + + + +Delete documents from a vector store or document index by their IDs. + +**Parameters:** + + +The vector store or document index to delete from. + + + +List of document IDs to delete. + + +**Raises:** + +- `IndexingException`: If the delete operation fails. +- `TypeError`: If the `vector_store` is neither a `VectorStore` nor a +`DocumentIndex`. + + + + + + + + +```python +langchain_core.indexing.api._get_document_with_hash( + document: langchain_core.documents.Document, + key_encoder: collections.abc.Callable[[Document], str] | typing.Literal['sha1', 'sha256', 'sha512', 'blake2b'] +) -> langchain_core.documents.Document +``` + + + + + + +Calculate a hash of the document, and assign it to the uid. + +When using one of the predefined hashing algorithms, the hash is calculated +by hashing the content and the metadata of the document. + +**Parameters:** + + +Document to hash. + + + +Hashing algorithm to use for hashing the document. +If not provided, a default encoder using SHA-1 will be used. +SHA-1 is not collision-resistant, and a motivated attacker +could craft two different texts that hash to the +same cache key. + +New applications should use one of the alternative encoders +or provide a custom and strong key encoder function to avoid this risk. + +When changing the key encoder, you must change the +index as well to avoid duplicated documents in the cache. + + +**Returns:** `Document` + +Document with a unique identifier based on the hash of the content and metadata. + +**Raises:** + +- `ValueError`: If the metadata cannot be serialized using json. + + + + + + + + +```python +langchain_core.indexing.api._get_source_id_assigner( + source_id_key: str | collections.abc.Callable[[Document], str] | None +) -> collections.abc.Callable[[Document], str | None] +``` + + + + + + +Get the source id from the document. + + + + + + + + +```python +langchain_core.indexing.api._hash_nested_dict( + data: dict[typing.Any, typing.Any], + algorithm: typing.Literal['sha1', 'sha256', 'sha512', 'blake2b'] +) -> uuid.UUID +``` + + + + + + +Hash a nested dictionary to a UUID using the configured algorithm. + + + + + + + + +```python +langchain_core.indexing.api._hash_string( + input_string: str, + algorithm: typing.Literal['sha1', 'sha256', 'sha512', 'blake2b'] +) -> uuid.UUID +``` + + + + + + +Hash *input_string* to a deterministic UUID using the configured algorithm. + + + + + + + + +```python +langchain_core.indexing.api._hash_string_to_uuid( + input_string: str +) -> str +``` + + + + + + +Hashes a string and returns the corresponding UUID. + + + + + + + + +```python +langchain_core.indexing.api._to_async_iterator( + iterator: collections.abc.Iterable[langchain_core.indexing.api.T] +) -> collections.abc.AsyncIterator[langchain_core.indexing.api.T] +``` + + + + + + +async + +Convert an iterable to an async iterator. + + + + + + + + +```python +langchain_core.indexing.api._warn_about_sha1() -> None +``` + + + + + + +Emit a one-time warning about SHA-1 collision weaknesses. + + + + + + + + +```python +langchain_core.indexing.api.aindex( + docs_source: langchain_core.document_loaders.base.BaseLoader | collections.abc.Iterable[langchain_core.documents.Document] | collections.abc.AsyncIterator[langchain_core.documents.Document], + record_manager: langchain_core.indexing.base.RecordManager, + vector_store: langchain_core.vectorstores.VectorStore | langchain_core.indexing.base.DocumentIndex, + batch_size: int = 100, + cleanup: typing.Literal['incremental', 'full', 'scoped_full'] | None = None, + source_id_key: str | collections.abc.Callable[[Document], str] | None = None, + cleanup_batch_size: int = 1000, + force_update: bool = False, + key_encoder: typing.Literal['sha1', 'sha256', 'sha512', 'blake2b'] | collections.abc.Callable[[Document], str] = 'sha1', + upsert_kwargs: dict[str, typing.Any] | None = None +) -> langchain_core.indexing.api.IndexingResult +``` + + + + + + +async + +Async index data from the loader into the vector store. + +Indexing functionality uses a manager to keep track of which documents +are in the vector store. + +This allows us to keep track of which documents were updated, and which +documents were deleted, which documents should be skipped. + +For the time being, documents are indexed using their hashes, and users +are not able to specify the uid of the document. + +!!! warning "Behavior changed in `langchain-core` 0.3.25" + + Added `scoped_full` cleanup mode. + +!!! warning + + * In full mode, the loader should be returning + the entire dataset, and not just a subset of the dataset. + Otherwise, the auto_cleanup will remove documents that it is not + supposed to. + * In incremental mode, if documents associated with a particular + source id appear across different batches, the indexing API + will do some redundant work. This will still result in the + correct end state of the index, but will unfortunately not be + 100% efficient. For example, if a given document is split into 15 + chunks, and we index them using a batch size of 5, we'll have 3 batches + all with the same source id. In general, to avoid doing too much + redundant work select as big a batch size as possible. + * The `scoped_full` mode is suitable if determining an appropriate batch size + is challenging or if your data loader cannot return the entire dataset at + once. This mode keeps track of source IDs in memory, which should be fine + for most use cases. If your dataset is large (10M+ docs), you will likely + need to parallelize the indexing process regardless. + +**Parameters:** + + +Data loader or iterable of documents to index. + + + +Timestamped set to keep track of which documents were +updated. + + + +`VectorStore` or DocumentIndex to index the documents into. + + + +Batch size to use when indexing. + + + +How to handle clean up of documents. + +- incremental: Cleans up all documents that haven't been updated AND + that are associated with source IDs that were seen during indexing. + Clean up is done continuously during indexing helping to minimize the + probability of users seeing duplicated content. +- full: Delete all documents that have not been returned by the loader + during this run of indexing. + Clean up runs after all documents have been indexed. + This means that users may see duplicated content during indexing. +- scoped_full: Similar to Full, but only deletes all documents + that haven't been updated AND that are associated with + source IDs that were seen during indexing. +- None: Do not delete any documents. + + + +Optional key that helps identify the original source +of the document. + + + +Batch size to use when cleaning up documents. + + + +Force update documents even if they are present in the +record manager. Useful if you are re-indexing with updated embeddings. + + + +Hashing algorithm to use for hashing the document content and +metadata. Options include "blake2b", "sha256", and "sha512". + +!!! version-added "Added in `langchain-core` 0.3.66" + + + +Hashing algorithm to use for hashing the document. +If not provided, a default encoder using SHA-1 will be used. +SHA-1 is not collision-resistant, and a motivated attacker +could craft two different texts that hash to the +same cache key. + +New applications should use one of the alternative encoders +or provide a custom and strong key encoder function to avoid this risk. + +When changing the key encoder, you must change the +index as well to avoid duplicated documents in the cache. + + + +Additional keyword arguments to pass to the add_documents +method of the `VectorStore` or the upsert method of the DocumentIndex. +For example, you can use this to specify a custom vector_field: +upsert_kwargs={"vector_field": "embedding"} +!!! version-added "Added in `langchain-core` 0.3.10" + + +**Returns:** `IndexingResult` + +Indexing result which contains information about how many documents + +**Raises:** + +- `ValueError`: If cleanup mode is not one of 'incremental', 'full' or None +- `ValueError`: If cleanup mode is incremental and source_id_key is None. +- `ValueError`: If `VectorStore` does not have +"adelete" and "aadd_documents" required methods. +- `ValueError`: If source_id_key is not None, but is not a string or callable. +- `TypeError`: If `vector_store` is not a `VectorStore` or DocumentIndex. +- `AssertionError`: If `source_id_key` is None when cleanup mode is +incremental or `scoped_full` (should be unreachable). + + + + + + + + +```python +langchain_core.indexing.api.index( + docs_source: langchain_core.document_loaders.base.BaseLoader | collections.abc.Iterable[langchain_core.documents.Document], + record_manager: langchain_core.indexing.base.RecordManager, + vector_store: langchain_core.vectorstores.VectorStore | langchain_core.indexing.base.DocumentIndex, + batch_size: int = 100, + cleanup: typing.Literal['incremental', 'full', 'scoped_full'] | None = None, + source_id_key: str | collections.abc.Callable[[Document], str] | None = None, + cleanup_batch_size: int = 1000, + force_update: bool = False, + key_encoder: typing.Literal['sha1', 'sha256', 'sha512', 'blake2b'] | collections.abc.Callable[[Document], str] = 'sha1', + upsert_kwargs: dict[str, typing.Any] | None = None +) -> langchain_core.indexing.api.IndexingResult +``` + + + + + + +Index data from the loader into the vector store. + +Indexing functionality uses a manager to keep track of which documents +are in the vector store. + +This allows us to keep track of which documents were updated, and which +documents were deleted, which documents should be skipped. + +For the time being, documents are indexed using their hashes, and users +are not able to specify the uid of the document. + +!!! warning "Behavior changed in `langchain-core` 0.3.25" + + Added `scoped_full` cleanup mode. + +!!! warning + + * In full mode, the loader should be returning + the entire dataset, and not just a subset of the dataset. + Otherwise, the auto_cleanup will remove documents that it is not + supposed to. + * In incremental mode, if documents associated with a particular + source id appear across different batches, the indexing API + will do some redundant work. This will still result in the + correct end state of the index, but will unfortunately not be + 100% efficient. For example, if a given document is split into 15 + chunks, and we index them using a batch size of 5, we'll have 3 batches + all with the same source id. In general, to avoid doing too much + redundant work select as big a batch size as possible. + * The `scoped_full` mode is suitable if determining an appropriate batch size + is challenging or if your data loader cannot return the entire dataset at + once. This mode keeps track of source IDs in memory, which should be fine + for most use cases. If your dataset is large (10M+ docs), you will likely + need to parallelize the indexing process regardless. + +**Parameters:** + + +Data loader or iterable of documents to index. + + + +Timestamped set to keep track of which documents were +updated. + + + +`VectorStore` or DocumentIndex to index the documents into. + + + +Batch size to use when indexing. + + + +How to handle clean up of documents. + +- incremental: Cleans up all documents that haven't been updated AND + that are associated with source IDs that were seen during indexing. + Clean up is done continuously during indexing helping to minimize the + probability of users seeing duplicated content. +- full: Delete all documents that have not been returned by the loader + during this run of indexing. + Clean up runs after all documents have been indexed. + This means that users may see duplicated content during indexing. +- scoped_full: Similar to Full, but only deletes all documents + that haven't been updated AND that are associated with + source IDs that were seen during indexing. +- None: Do not delete any documents. + + + +Optional key that helps identify the original source +of the document. + + + +Batch size to use when cleaning up documents. + + + +Force update documents even if they are present in the +record manager. Useful if you are re-indexing with updated embeddings. + + + +Hashing algorithm to use for hashing the document content and +metadata. Options include "blake2b", "sha256", and "sha512". + +!!! version-added "Added in `langchain-core` 0.3.66" + + + +Hashing algorithm to use for hashing the document. +If not provided, a default encoder using SHA-1 will be used. +SHA-1 is not collision-resistant, and a motivated attacker +could craft two different texts that hash to the +same cache key. + +New applications should use one of the alternative encoders +or provide a custom and strong key encoder function to avoid this risk. + +When changing the key encoder, you must change the +index as well to avoid duplicated documents in the cache. + + + +Additional keyword arguments to pass to the add_documents +method of the `VectorStore` or the upsert method of the DocumentIndex. +For example, you can use this to specify a custom vector_field: +upsert_kwargs={"vector_field": "embedding"} +!!! version-added "Added in `langchain-core` 0.3.10" + + +**Returns:** `IndexingResult` + +Indexing result which contains information about how many documents + +**Raises:** + +- `ValueError`: If cleanup mode is not one of 'incremental', 'full' or None +- `ValueError`: If cleanup mode is incremental and source_id_key is None. +- `ValueError`: If `VectorStore` does not have +"delete" and "add_documents" required methods. +- `ValueError`: If source_id_key is not None, but is not a string or callable. +- `TypeError`: If `vectorstore` is not a `VectorStore` or a DocumentIndex. +- `AssertionError`: If `source_id` is None when cleanup mode is incremental. +(should be unreachable code). + + + + + + + + +```python +langchain_core.indexing.api.NAMESPACE_UUID = uuid.UUID(int=1984) +``` + + + + + + + + + +```python +langchain_core.indexing.api.T = TypeVar('T') +``` + + + + + + + + + +```python +langchain_core.indexing.api._WARNED_ABOUT_SHA1: bool = False +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/indexing/base.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/indexing/base.mdx new file mode 100644 index 0000000..88355ec --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/indexing/base.mdx @@ -0,0 +1,1269 @@ +--- +layout: overview +slug: langchain-core/langchain_core/indexing/base +title: langchain_core.indexing.base +--- + +Base classes for indexing. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`DeleteResponse`](#langchain_core-indexing-base-DeleteResponse) | A generic response for delete operation. | +| [`DocumentIndex`](#langchain_core-indexing-base-DocumentIndex) | A document retriever that supports indexing operations. | +| [`InMemoryRecordManager`](#langchain_core-indexing-base-InMemoryRecordManager) | An in-memory record manager for testing purposes. | +| [`RecordManager`](#langchain_core-indexing-base-RecordManager) | Abstract base class representing the interface for a record manager. | +| [`UpsertResponse`](#langchain_core-indexing-base-UpsertResponse) | A generic response for upsert operations. | +| [`_Record`](#langchain_core-indexing-base-_Record) | - | + +### API + + + + + +```python +class langchain_core.indexing.base.DeleteResponse +``` + + + + + + +**Bases:** `typing.TypedDict` + +A generic response for delete operation. + +The fields in this response are optional and whether the `VectorStore` +returns them or not is up to the implementation. + + +The IDs that failed to be deleted. + +!!! warning + Deleting an ID that does not exist is **NOT** considered a failure. + + + +The number of items that were successfully deleted. + +If returned, this should only include *actual* deletions. + +If the ID did not exist to begin with, +it should not be included in this count. + + + +The number of items that failed to be deleted. + + + +The IDs that were successfully deleted. + +If returned, this should only include *actual* deletions. + +If the ID did not exist to begin with, +it should not be included in this list. + + + + + + + + +```python +class langchain_core.indexing.base.DocumentIndex() +``` + + + + + + +**Bases:** [BaseRetriever](/langchain-core/langchain_core/retrievers#langchain_core-retrievers-BaseRetriever) + +A document retriever that supports indexing operations. + +This indexing interface is designed to be a generic abstraction for storing and +querying documents that has an ID and metadata associated with it. + +The interface is designed to be agnostic to the underlying implementation of the +indexing system. + +The interface is designed to support the following operations: + +1. Storing document in the index. +2. Fetching document by ID. +3. Searching for document using a query. + + + + + + +```python +langchain_core.indexing.base.DocumentIndex.adelete( + ids: list[str] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.indexing.base.DeleteResponse +``` + + + + + + +async + +Delete by IDs or other criteria. Async variant. + +Calling adelete without any input parameters should raise a ValueError! + +**Parameters:** + + +List of IDs to delete. + + + +Additional keyword arguments. This is up to the implementation. +For example, can include an option to delete the entire index. + + +**Returns:** `DeleteResponse` + +A response object that contains the list of IDs that were + + + + + + + +```python +langchain_core.indexing.base.DocumentIndex.aget( + ids: collections.abc.Sequence[str], + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + +async + +Get documents by id. + +Fewer documents may be returned than requested if some IDs are not found or +if there are duplicated IDs. + +Users should not assume that the order of the returned documents matches +the order of the input IDs. Instead, users should rely on the ID field of the +returned documents. + +This method should **NOT** raise exceptions if no documents are found for +some IDs. + +**Parameters:** + + +List of IDs to get. + + + +Additional keyword arguments. These are up to the implementation. + + +**Returns:** `list[Document]` + +List of documents that were found. + + + + + + + +```python +langchain_core.indexing.base.DocumentIndex.aupsert( + items: collections.abc.Sequence[langchain_core.documents.Document], + kwargs: typing.Any = {} +) -> langchain_core.indexing.base.UpsertResponse +``` + + + + + + +async + +Add or update documents in the `VectorStore`. Async version of `upsert`. + +The upsert functionality should utilize the ID field of the item +if it is provided. If the ID is not provided, the upsert method is free +to generate an ID for the item. + +When an ID is specified and the item already exists in the `VectorStore`, +the upsert method should update the item with the new data. If the item +does not exist, the upsert method should add the item to the `VectorStore`. + +**Parameters:** + + +Sequence of documents to add to the `VectorStore`. + + + +Additional keyword arguments. + + +**Returns:** `UpsertResponse` + +A response object that contains the list of IDs that were + + + + + + + +```python +langchain_core.indexing.base.DocumentIndex.delete( + ids: list[str] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.indexing.base.DeleteResponse +``` + + + + + + +abstract + +Delete by IDs or other criteria. + +Calling delete without any input parameters should raise a ValueError! + +**Parameters:** + + +List of IDs to delete. + + + +Additional keyword arguments. This is up to the implementation. +For example, can include an option to delete the entire index, +or else issue a non-blocking delete etc. + + +**Returns:** `DeleteResponse` + +A response object that contains the list of IDs that were + + + + + + + +```python +langchain_core.indexing.base.DocumentIndex.get( + ids: collections.abc.Sequence[str], + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + +abstract + +Get documents by id. + +Fewer documents may be returned than requested if some IDs are not found or +if there are duplicated IDs. + +Users should not assume that the order of the returned documents matches +the order of the input IDs. Instead, users should rely on the ID field of the +returned documents. + +This method should **NOT** raise exceptions if no documents are found for +some IDs. + +**Parameters:** + + +List of IDs to get. + + + +Additional keyword arguments. These are up to the implementation. + + +**Returns:** `list[Document]` + +List of documents that were found. + + + + + + + +```python +langchain_core.indexing.base.DocumentIndex.upsert( + items: collections.abc.Sequence[langchain_core.documents.Document], + kwargs: typing.Any = {} +) -> langchain_core.indexing.base.UpsertResponse +``` + + + + + + +abstract + +Upsert documents into the index. + +The upsert functionality should utilize the ID field of the content object +if it is provided. If the ID is not provided, the upsert method is free +to generate an ID for the content. + +When an ID is specified and the content already exists in the `VectorStore`, +the upsert method should update the content with the new data. If the content +does not exist, the upsert method should add the item to the `VectorStore`. + +**Parameters:** + + +Sequence of documents to add to the `VectorStore`. + + + +Additional keyword arguments. + + +**Returns:** `UpsertResponse` + +A response object that contains the list of IDs that were + + + + + + + + + +```python +class langchain_core.indexing.base.InMemoryRecordManager( + namespace: str +) +``` + + + + + + +**Bases:** [RecordManager](#langchain_core-indexing-base-RecordManager) + +An in-memory record manager for testing purposes. + + + + + + + + +```python +langchain_core.indexing.base.InMemoryRecordManager.acreate_schema() -> None +``` + + + + + + +async + +In-memory schema creation is simply ensuring the structure is initialized. + + + + + + + +```python +langchain_core.indexing.base.InMemoryRecordManager.adelete_keys( + keys: collections.abc.Sequence[str] +) -> None +``` + + + + + + +async + +Async delete specified records from the database. + +**Parameters:** + + +A list of keys to delete. + + + + + + + + +```python +langchain_core.indexing.base.InMemoryRecordManager.aexists( + keys: collections.abc.Sequence[str] +) -> list[bool] +``` + + + + + + +async + +Async check if the provided keys exist in the database. + +**Parameters:** + + +A list of keys to check. + + +**Returns:** `list[bool]` + +A list of boolean values indicating the existence of each key. + + + + + + + +```python +langchain_core.indexing.base.InMemoryRecordManager.aget_time() -> float +``` + + + + + + +async + + + + + + + +```python +langchain_core.indexing.base.InMemoryRecordManager.alist_keys( + before: float | None = None, + after: float | None = None, + group_ids: collections.abc.Sequence[str] | None = None, + limit: int | None = None +) -> list[str] +``` + + + + + + +async + +Async list records in the database based on the provided filters. + +**Parameters:** + + +Filter to list records updated before this time. + + + +Filter to list records updated after this time. + + + +Filter to list records with specific group IDs. + + + +optional limit on the number of records to return. + + +**Returns:** `list[str]` + +A list of keys for the matching records. + + + + + + + +```python +langchain_core.indexing.base.InMemoryRecordManager.aupdate( + keys: collections.abc.Sequence[str], + group_ids: collections.abc.Sequence[str | None] | None = None, + time_at_least: float | None = None +) -> None +``` + + + + + + +async + +Async upsert records into the database. + +**Parameters:** + + +A list of record keys to upsert. + + + +A list of group IDs corresponding to the keys. + + + +Optional timestamp. Implementation can use this +to optionally verify that the timestamp IS at least this time +in the system that stores. +E.g., use to validate that the time in the postgres database +is equal to or larger than the given timestamp, if not +raise an error. +This is meant to help prevent time-drift issues since +time may not be monotonically increasing! + + + + + + + + +```python +langchain_core.indexing.base.InMemoryRecordManager.create_schema() -> None +``` + + + + + + +In-memory schema creation is simply ensuring the structure is initialized. + + + + + + + +```python +langchain_core.indexing.base.InMemoryRecordManager.delete_keys( + keys: collections.abc.Sequence[str] +) -> None +``` + + + + + + +Delete specified records from the database. + +**Parameters:** + + +A list of keys to delete. + + + + + + + + +```python +langchain_core.indexing.base.InMemoryRecordManager.exists( + keys: collections.abc.Sequence[str] +) -> list[bool] +``` + + + + + + +Check if the provided keys exist in the database. + +**Parameters:** + + +A list of keys to check. + + +**Returns:** `list[bool]` + +A list of boolean values indicating the existence of each key. + + + + + + + +```python +langchain_core.indexing.base.InMemoryRecordManager.get_time() -> float +``` + + + + + + + + + + + + +```python +langchain_core.indexing.base.InMemoryRecordManager.list_keys( + before: float | None = None, + after: float | None = None, + group_ids: collections.abc.Sequence[str] | None = None, + limit: int | None = None +) -> list[str] +``` + + + + + + +List records in the database based on the provided filters. + +**Parameters:** + + +Filter to list records updated before this time. + + + +Filter to list records updated after this time. + + + +Filter to list records with specific group IDs. + + + +optional limit on the number of records to return. + + +**Returns:** `list[str]` + +A list of keys for the matching records. + + + + + + + +```python +langchain_core.indexing.base.InMemoryRecordManager.update( + keys: collections.abc.Sequence[str], + group_ids: collections.abc.Sequence[str | None] | None = None, + time_at_least: float | None = None +) -> None +``` + + + + + + +Upsert records into the database. + +**Parameters:** + + +A list of record keys to upsert. + + + +A list of group IDs corresponding to the keys. + + + +Optional timestamp. Implementation can use this +to optionally verify that the timestamp IS at least this time +in the system that stores. +E.g., use to validate that the time in the postgres database +is equal to or larger than the given timestamp, if not +raise an error. +This is meant to help prevent time-drift issues since +time may not be monotonically increasing! + + +**Raises:** + +- `ValueError`: If the length of keys doesn't match the length of group +ids. +- `ValueError`: If time_at_least is in the future. + + + + + + + + + +```python +class langchain_core.indexing.base.RecordManager( + namespace: str +) +``` + + + + + + +Abstract + +Abstract base class representing the interface for a record manager. + +The record manager abstraction is used by the langchain indexing API. + +The record manager keeps track of which documents have been +written into a `VectorStore` and when they were written. + +The indexing API computes hashes for each document and stores the hash +together with the write time and the source id in the record manager. + +On subsequent indexing runs, the indexing API can check the record manager +to determine which documents have already been indexed and which have not. + +This allows the indexing API to avoid re-indexing documents that have +already been indexed, and to only index new documents. + +The main benefit of this abstraction is that it works across many vectorstores. +To be supported, a `VectorStore` needs to only support the ability to add and +delete documents by ID. Using the record manager, the indexing API will +be able to delete outdated documents and avoid redundant indexing of documents +that have already been indexed. + +The main constraints of this abstraction are: + +1. It relies on the time-stamps to determine which documents have been + indexed and which have not. This means that the time-stamps must be + monotonically increasing. The timestamp should be the timestamp + as measured by the server to minimize issues. +2. The record manager is currently implemented separately from the + vectorstore, which means that the overall system becomes distributed + and may create issues with consistency. For example, writing to + record manager succeeds, but corresponding writing to `VectorStore` fails. + + + + + + +```python +langchain_core.indexing.base.RecordManager.acreate_schema() -> None +``` + + + + + + +async abstract + +Asynchronously create the database schema for the record manager. + + + + + + + +```python +langchain_core.indexing.base.RecordManager.adelete_keys( + keys: collections.abc.Sequence[str] +) -> None +``` + + + + + + +async abstract + +Asynchronously delete specified records from the database. + +**Parameters:** + + +A list of keys to delete. + + + + + + + + +```python +langchain_core.indexing.base.RecordManager.aexists( + keys: collections.abc.Sequence[str] +) -> list[bool] +``` + + + + + + +async abstract + +Asynchronously check if the provided keys exist in the database. + +**Parameters:** + + +A list of keys to check. + + +**Returns:** `list[bool]` + +A list of boolean values indicating the existence of each key. + + + + + + + +```python +langchain_core.indexing.base.RecordManager.aget_time() -> float +``` + + + + + + +async abstract + +Asynchronously get the current server time as a high resolution timestamp. + +It's important to get this from the server to ensure a monotonic clock, +otherwise there may be data loss when cleaning up old documents! + +**Returns:** `float` + +The current server time as a float timestamp. + + + + + + + +```python +langchain_core.indexing.base.RecordManager.alist_keys( + before: float | None = None, + after: float | None = None, + group_ids: collections.abc.Sequence[str] | None = None, + limit: int | None = None +) -> list[str] +``` + + + + + + +async abstract + +Asynchronously list records in the database based on the provided filters. + +**Parameters:** + + +Filter to list records updated before this time. + + + +Filter to list records updated after this time. + + + +Filter to list records with specific group IDs. + + + +optional limit on the number of records to return. + + +**Returns:** `list[str]` + +A list of keys for the matching records. + + + + + + + +```python +langchain_core.indexing.base.RecordManager.aupdate( + keys: collections.abc.Sequence[str], + group_ids: collections.abc.Sequence[str | None] | None = None, + time_at_least: float | None = None +) -> None +``` + + + + + + +async abstract + +Asynchronously upsert records into the database. + +**Parameters:** + + +A list of record keys to upsert. + + + +A list of group IDs corresponding to the keys. + + + +Optional timestamp. Implementation can use this +to optionally verify that the timestamp IS at least this time +in the system that stores the data. + +e.g., use to validate that the time in the postgres database +is equal to or larger than the given timestamp, if not +raise an error. + +This is meant to help prevent time-drift issues since +time may not be monotonically increasing! + + +**Raises:** + +- `ValueError`: If the length of keys doesn't match the length of group_ids. + + + + + + + +```python +langchain_core.indexing.base.RecordManager.create_schema() -> None +``` + + + + + + +abstract + +Create the database schema for the record manager. + + + + + + + +```python +langchain_core.indexing.base.RecordManager.delete_keys( + keys: collections.abc.Sequence[str] +) -> None +``` + + + + + + +abstract + +Delete specified records from the database. + +**Parameters:** + + +A list of keys to delete. + + + + + + + + +```python +langchain_core.indexing.base.RecordManager.exists( + keys: collections.abc.Sequence[str] +) -> list[bool] +``` + + + + + + +abstract + +Check if the provided keys exist in the database. + +**Parameters:** + + +A list of keys to check. + + +**Returns:** `list[bool]` + +A list of boolean values indicating the existence of each key. + + + + + + + +```python +langchain_core.indexing.base.RecordManager.get_time() -> float +``` + + + + + + +abstract + +Get the current server time as a high resolution timestamp! + +It's important to get this from the server to ensure a monotonic clock, +otherwise there may be data loss when cleaning up old documents! + +**Returns:** `float` + +The current server time as a float timestamp. + + + + + + + +```python +langchain_core.indexing.base.RecordManager.list_keys( + before: float | None = None, + after: float | None = None, + group_ids: collections.abc.Sequence[str] | None = None, + limit: int | None = None +) -> list[str] +``` + + + + + + +abstract + +List records in the database based on the provided filters. + +**Parameters:** + + +Filter to list records updated before this time. + + + +Filter to list records updated after this time. + + + +Filter to list records with specific group IDs. + + + +optional limit on the number of records to return. + + +**Returns:** `list[str]` + +A list of keys for the matching records. + + + + + + + +```python +langchain_core.indexing.base.RecordManager.update( + keys: collections.abc.Sequence[str], + group_ids: collections.abc.Sequence[str | None] | None = None, + time_at_least: float | None = None +) -> None +``` + + + + + + +abstract + +Upsert records into the database. + +**Parameters:** + + +A list of record keys to upsert. + + + +A list of group IDs corresponding to the keys. + + + +Optional timestamp. Implementation can use this +to optionally verify that the timestamp IS at least this time +in the system that stores the data. + +e.g., use to validate that the time in the postgres database +is equal to or larger than the given timestamp, if not +raise an error. + +This is meant to help prevent time-drift issues since +time may not be monotonically increasing! + + +**Raises:** + +- `ValueError`: If the length of keys doesn't match the length of group_ids. + + + + + + + + + +```python +class langchain_core.indexing.base.UpsertResponse +``` + + + + + + +**Bases:** `typing.TypedDict` + +A generic response for upsert operations. + +The upsert response will be used by abstractions that implement an upsert +operation for content that can be upserted by ID. + +Upsert APIs that accept inputs with IDs and generate IDs internally +will return a response that includes the IDs that succeeded and the IDs +that failed. + +If there are no failures, the failed list will be empty, and the order +of the IDs in the succeeded list will match the order of the input documents. + +If there are failures, the response becomes ill defined, and a user of the API +cannot determine which generated ID corresponds to which input document. + +It is recommended for users explicitly attach the IDs to the items being +indexed to avoid this issue. + + +The IDs that failed to index. + + + +The IDs that were successfully indexed. + + + + + + + + +```python +class langchain_core.indexing.base._Record +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/indexing/in_memory.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/indexing/in_memory.mdx new file mode 100644 index 0000000..e6e28ea --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/indexing/in_memory.mdx @@ -0,0 +1,151 @@ +--- +layout: overview +slug: langchain-core/langchain_core/indexing/in_memory +title: langchain_core.indexing.in_memory +--- + +In memory document index. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`InMemoryDocumentIndex`](#langchain_core-indexing-in_memory-InMemoryDocumentIndex) | In memory document index. | + +### API + + + + + +```python +class langchain_core.indexing.in_memory.InMemoryDocumentIndex() +``` + + + + + + +**Bases:** [DocumentIndex](/langchain-core/langchain_core/indexing/base#langchain_core-indexing-base-DocumentIndex) + +In memory document index. + +This is an in-memory document index that stores documents in a dictionary. + +It provides a simple search API that returns documents by the number of +counts the given query appears in the document. + + + + + + + + + + + +```python +langchain_core.indexing.in_memory.InMemoryDocumentIndex._get_relevant_documents( + query: str, + run_manager: langchain_core.callbacks.CallbackManagerForRetrieverRun +) -> list[langchain_core.documents.Document] +``` + + + + + + + + + + + + +```python +langchain_core.indexing.in_memory.InMemoryDocumentIndex.delete( + ids: list[str] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.indexing.base.DeleteResponse +``` + + + + + + +Delete by IDs. + +**Parameters:** + + +List of IDs to delete. + + +**Returns:** `DeleteResponse` + +A response object that contains the list of IDs that were successfully + +**Raises:** + +- `ValueError`: If IDs is None. + + + + + + + +```python +langchain_core.indexing.in_memory.InMemoryDocumentIndex.get( + ids: collections.abc.Sequence[str], + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + + + + + + + +```python +langchain_core.indexing.in_memory.InMemoryDocumentIndex.upsert( + items: collections.abc.Sequence[langchain_core.documents.Document], + kwargs: typing.Any = {} +) -> langchain_core.indexing.UpsertResponse +``` + + + + + + +Upsert documents into the index. + +**Parameters:** + + +Sequence of documents to add to the index. + + + +Additional keyword arguments. + + +**Returns:** `UpsertResponse` + +A response object that contains the list of IDs that were + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models.mdx new file mode 100644 index 0000000..ffc54f3 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models.mdx @@ -0,0 +1,120 @@ +--- +layout: overview +slug: langchain-core/langchain_core/language_models +title: langchain_core.language_models +--- + +Core language model abstractions. + +LangChain has two main classes to work with language models: chat models and +"old-fashioned" LLMs (string-in, string-out). + +**Chat models** + +Language models that use a sequence of messages as inputs and return chat messages +as outputs (as opposed to using plain text). + +Chat models support the assignment of distinct roles to conversation messages, helping +to distinguish messages from the AI, users, and instructions such as system messages. + +The key abstraction for chat models is +[`BaseChatModel`][langchain_core.language_models.BaseChatModel]. Implementations should +inherit from this class. + +See existing [chat model integrations](https://docs.langchain.com/oss/python/integrations/chat). + +**LLMs (legacy)** + +Language models that takes a string as input and returns a string. + +These are traditionally older models (newer models generally are chat models). + +Although the underlying models are string in, string out, the LangChain wrappers also +allow these models to take messages as input. This gives them the same interface as +chat models. When messages are passed in as input, they will be formatted into a string +under the hood before being passed to the underlying model. + +## Submodules + +- **[`langchain_core.language_models._utils`](/langchain-core/langchain_core/language_models/_utils)** +- **[`langchain_core.language_models.base`](/langchain-core/langchain_core/language_models/base)** +- **[`langchain_core.language_models.chat_models`](/langchain-core/langchain_core/language_models/chat_models)** +- **[`langchain_core.language_models.fake`](/langchain-core/langchain_core/language_models/fake)** +- **[`langchain_core.language_models.fake_chat_models`](/langchain-core/langchain_core/language_models/fake_chat_models)** +- **[`langchain_core.language_models.llms`](/langchain-core/langchain_core/language_models/llms)** +- **[`langchain_core.language_models.model_profile`](/langchain-core/langchain_core/language_models/model_profile)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`__dir__`](#langchain_core-language_models-__dir__) | - | +| [`__getattr__`](#langchain_core-language_models-__getattr__) | - | + +### Data + +[`__all__`](#langchain_core-language_models-__all__) + +[`_dynamic_imports`](#langchain_core-language_models-_dynamic_imports) + +### API + + + + + +```python +langchain_core.language_models.__dir__() -> list[str] +``` + + + + + + + + + + + + + +```python +langchain_core.language_models.__getattr__( + attr_name: str +) -> object +``` + + + + + + + + + + + + + +```python +langchain_core.language_models.__all__ = ('LLM', 'BaseChatModel', 'BaseLLM', 'BaseLanguageModel', 'FakeListChatModel', 'F... +``` + + + + + + + + + +```python +langchain_core.language_models._dynamic_imports = {'BaseLanguageModel': 'base', 'LangSmithParams': 'base', 'LanguageModelInput': '... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/_utils.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/_utils.mdx new file mode 100644 index 0000000..9a5bf1c --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/_utils.mdx @@ -0,0 +1,327 @@ +--- +layout: overview +slug: langchain-core/langchain_core/language_models/_utils +title: langchain_core.language_models._utils +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ParsedDataUri`](#langchain_core-language_models-_utils-ParsedDataUri) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_ensure_message_copy`](#langchain_core-language_models-_utils-_ensure_message_copy) | Create a copy of the message if it hasn't been copied yet. | +| [`_normalize_messages`](#langchain_core-language_models-_utils-_normalize_messages) | Normalize message formats to LangChain v1 standard content blocks. | +| [`_parse_data_uri`](#langchain_core-language_models-_utils-_parse_data_uri) | Parse a data URI into its components. | +| [`_update_content_block`](#langchain_core-language_models-_utils-_update_content_block) | Update a content block at the given index, handling type issues. | +| [`_update_message_content_to_blocks`](#langchain_core-language_models-_utils-_update_message_content_to_blocks) | - | +| [`is_openai_data_block`](#langchain_core-language_models-_utils-is_openai_data_block) | Check whether a block contains multimodal data in OpenAI Chat Completions format. | + +### Data + +[`T`](#langchain_core-language_models-_utils-T) + +### API + + + + + +```python +class langchain_core.language_models._utils.ParsedDataUri +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + +```python +langchain_core.language_models._utils._ensure_message_copy( + message: langchain_core.language_models._utils.T, + formatted_message: langchain_core.language_models._utils.T +) -> langchain_core.language_models._utils.T +``` + + + + + + +Create a copy of the message if it hasn't been copied yet. + + + + + + + + +```python +langchain_core.language_models._utils._normalize_messages( + messages: collections.abc.Sequence[langchain_core.messages.BaseMessage] +) -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +Normalize message formats to LangChain v1 standard content blocks. + +Chat models already implement support for: +- Images in OpenAI Chat Completions format + These will be passed through unchanged +- LangChain v1 standard content blocks + +This function extends support to: +- `[Audio](https://platform.openai.com/docs/api-reference/chat/create) and + `[file](https://platform.openai.com/docs/api-reference/files) data in OpenAI + Chat Completions format + - Images are technically supported but we expect chat models to handle them + directly; this may change in the future +- LangChain v0 standard content blocks for backward compatibility + +!!! warning "Behavior changed in `langchain-core` 1.0.0" + + In previous versions, this function returned messages in LangChain v0 format. + Now, it returns messages in LangChain v1 format, which upgraded chat models now + expect to receive when passing back in message history. For backward + compatibility, this function will convert v0 message content to v1 format. + +??? note "v0 Content Block Schemas" + + `URLContentBlock`: + + ```python + { + mime_type: NotRequired[str] + type: Literal['image', 'audio', 'file'], + source_type: Literal['url'], + url: str, + } + ``` + + `Base64ContentBlock`: + + ```python + { + mime_type: NotRequired[str] + type: Literal['image', 'audio', 'file'], + source_type: Literal['base64'], + data: str, + } + ``` + + `IDContentBlock`: + + (In practice, this was never used) + + ```python + { + type: Literal["image", "audio", "file"], + source_type: Literal["id"], + id: str, + } + ``` + + `PlainTextContentBlock`: + + ```python + { + mime_type: NotRequired[str] + type: Literal['file'], + source_type: Literal['text'], + url: str, + } + ``` + +If a v1 message is passed in, it will be returned as-is, meaning it is safe to +always pass in v1 messages to this function for assurance. + +For posterity, here are the OpenAI Chat Completions schemas we expect: + +Chat Completions image. Can be URL-based or base64-encoded. Supports MIME types +png, jpeg/jpg, webp, static gif: +{ + "type": Literal['image_url'], + "image_url": { + "url": Union["data:$MIME_TYPE;base64,$BASE64_ENCODED_IMAGE", "$IMAGE_URL"], + "detail": Literal['low', 'high', 'auto'] = 'auto', # Supported by OpenAI + } +} + +Chat Completions audio: +{ + "type": Literal['input_audio'], + "input_audio": { + "format": Literal['wav', 'mp3'], + "data": str = "$BASE64_ENCODED_AUDIO", + }, +} + +Chat Completions files: either base64 or pre-uploaded file ID +{ + "type": Literal['file'], + "file": Union[ + { + "filename": str | None = "$FILENAME", + "file_data": str = "$BASE64_ENCODED_FILE", + }, + { + "file_id": str = "$FILE_ID", # For pre-uploaded files to OpenAI + }, + ], +} + + + + + + + + +```python +langchain_core.language_models._utils._parse_data_uri( + uri: str +) -> langchain_core.language_models._utils.ParsedDataUri | None +``` + + + + + + +Parse a data URI into its components. + +If parsing fails, return `None`. If either MIME type or data is missing, return +`None`. + + + + + + + + +```python +langchain_core.language_models._utils._update_content_block( + formatted_message: langchain_core.messages.BaseMessage, + idx: int, + new_block: langchain_core.messages.content.ContentBlock | dict +) -> None +``` + + + + + + +Update a content block at the given index, handling type issues. + + + + + + + + +```python +langchain_core.language_models._utils._update_message_content_to_blocks( + message: langchain_core.language_models._utils.T, + output_version: str +) -> langchain_core.language_models._utils.T +``` + + + + + + + + + + + + + +```python +langchain_core.language_models._utils.is_openai_data_block( + block: dict, + filter_: typing.Literal['image', 'audio', 'file'] | None = None +) -> bool +``` + + + + + + +Check whether a block contains multimodal data in OpenAI Chat Completions format. + +Supports both data and ID-style blocks (e.g. `'file_data'` and `'file_id'`) + +If additional keys are present, they are ignored / will not affect outcome as long +as the required keys are present and valid. + +**Parameters:** + + +The content block to check. + + + +If provided, only return True for blocks matching this specific type. +- "image": Only match image_url blocks +- "audio": Only match input_audio blocks +- "file": Only match file blocks +If `None`, match any valid OpenAI data block type. Note that this means that +if the block has a valid OpenAI data type but the filter_ is set to a +different type, this function will return False. + + +**Returns:** `bool` + +`True` if the block is a valid OpenAI data block and matches the filter_ + + + + + + + + +```python +langchain_core.language_models._utils.T = TypeVar('T', bound='BaseMessage') +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/base.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/base.mdx new file mode 100644 index 0000000..16d0736 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/base.mdx @@ -0,0 +1,608 @@ +--- +layout: overview +slug: langchain-core/langchain_core/language_models/base +title: langchain_core.language_models.base +--- + +Base language models class. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseLanguageModel`](#langchain_core-language_models-base-BaseLanguageModel) | Abstract base class for interfacing with language models. | +| [`LangSmithParams`](#langchain_core-language_models-base-LangSmithParams) | LangSmith parameters for tracing. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_token_ids_default_method`](#langchain_core-language_models-base-_get_token_ids_default_method) | Encode the text into token IDs using the fallback GPT-2 tokenizer. | +| [`_get_verbosity`](#langchain_core-language_models-base-_get_verbosity) | - | +| [`get_tokenizer`](#langchain_core-language_models-base-get_tokenizer) | Get a GPT-2 tokenizer instance. | + +### Data + +[`LanguageModelInput`](#langchain_core-language_models-base-LanguageModelInput) + +[`LanguageModelLike`](#langchain_core-language_models-base-LanguageModelLike) + +[`LanguageModelOutput`](#langchain_core-language_models-base-LanguageModelOutput) + +[`LanguageModelOutputVar`](#langchain_core-language_models-base-LanguageModelOutputVar) + +[`_GPT2_TOKENIZER_WARNED`](#langchain_core-language_models-base-_GPT2_TOKENIZER_WARNED) + +[`_HAS_TRANSFORMERS`](#langchain_core-language_models-base-_HAS_TRANSFORMERS) + +### API + + + + + +```python +class langchain_core.language_models.base.BaseLanguageModel() +``` + + + + + + +Abstract + +**Bases:** [RunnableSerializable[LanguageModelInput, LanguageModelOutputVar]](/langchain-core/langchain_core/runnables/base#langchain_core-runnables-base-RunnableSerializable) + +Abstract base class for interfacing with language models. + +All language model wrappers inherited from `BaseLanguageModel`. + + + +Get the input type for this `Runnable`. + + + +Get the identifying parameters. + + + +Whether to cache the response. + +* If `True`, will use the global cache. +* If `False`, will not use a cache +* If `None`, will use the global cache if it's set, otherwise no cache. +* If instance of `BaseCache`, will use the provided cache. + +Caching is not currently supported for streaming methods of models. + + + +Callbacks to add to the run trace. + + + +Optional encoder to use for counting tokens. + + + +Metadata to add to the run trace. + + + + + + +Tags to add to the run trace. + + + +Whether to print out response text. + + + + + +```python +langchain_core.language_models.base.BaseLanguageModel.agenerate_prompt( + prompts: list[langchain_core.prompt_values.PromptValue], + stop: list[str] | None = None, + callbacks: langchain_core.callbacks.Callbacks = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.LLMResult +``` + + + + + + +async abstract + +Asynchronously pass a sequence of prompts and return model generations. + +This method should make use of batched calls for models that expose a batched +API. + +Use this method when you want to: + +1. Take advantage of batched calls, +2. Need more output from the model than just the top generated value, +3. Are building chains that are agnostic to the underlying language model + type (e.g., pure text completion models vs chat models). + +**Parameters:** + + +List of `PromptValue` objects. + +A `PromptValue` is an object that can be converted to match the format +of any language model (string for pure text generation models and +`BaseMessage` objects for chat models). + + + +Stop words to use when generating. + +Model output is cut off at the first occurrence of any of these +substrings. + + + +`Callbacks` to pass through. + +Used for executing additional functionality, such as logging or +streaming, throughout generation. + + + +Arbitrary additional keyword arguments. + +These are usually passed to the model provider API call. + + +**Returns:** `LLMResult` + +An `LLMResult`, which contains a list of candidate `Generation` objects for +each input prompt and additional model provider-specific output. + + + + + + + +```python +langchain_core.language_models.base.BaseLanguageModel.generate_prompt( + prompts: list[langchain_core.prompt_values.PromptValue], + stop: list[str] | None = None, + callbacks: langchain_core.callbacks.Callbacks = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.LLMResult +``` + + + + + + +abstract + +Pass a sequence of prompts to the model and return model generations. + +This method should make use of batched calls for models that expose a batched +API. + +Use this method when you want to: + +1. Take advantage of batched calls, +2. Need more output from the model than just the top generated value, +3. Are building chains that are agnostic to the underlying language model + type (e.g., pure text completion models vs chat models). + +**Parameters:** + + +List of `PromptValue` objects. + +A `PromptValue` is an object that can be converted to match the format +of any language model (string for pure text generation models and +`BaseMessage` objects for chat models). + + + +Stop words to use when generating. + +Model output is cut off at the first occurrence of any of these +substrings. + + + +`Callbacks` to pass through. + +Used for executing additional functionality, such as logging or +streaming, throughout generation. + + + +Arbitrary additional keyword arguments. + +These are usually passed to the model provider API call. + + +**Returns:** `LLMResult` + +An `LLMResult`, which contains a list of candidate `Generation` objects for +each input prompt and additional model provider-specific output. + + + + + + + +```python +langchain_core.language_models.base.BaseLanguageModel.get_num_tokens( + text: str +) -> int +``` + + + + + + +Get the number of tokens present in the text. + +Useful for checking if an input fits in a model's context window. + +This should be overridden by model-specific implementations to provide accurate +token counts via model-specific tokenizers. + +**Parameters:** + + +The string input to tokenize. + + +**Returns:** `int` + +The integer number of tokens in the text. + + + + + + + +```python +langchain_core.language_models.base.BaseLanguageModel.get_num_tokens_from_messages( + messages: list[langchain_core.messages.BaseMessage], + tools: collections.abc.Sequence | None = None +) -> int +``` + + + + + + +Get the number of tokens in the messages. + +Useful for checking if an input fits in a model's context window. + +This should be overridden by model-specific implementations to provide accurate +token counts via model-specific tokenizers. + +!!! note + + * The base implementation of `get_num_tokens_from_messages` ignores tool + schemas. + * The base implementation of `get_num_tokens_from_messages` adds additional + prefixes to messages in represent user roles, which will add to the + overall token count. Model-specific implementations may choose to + handle this differently. + +**Parameters:** + + +The message inputs to tokenize. + + + +If provided, sequence of dict, `BaseModel`, function, or +`BaseTool` objects to be converted to tool schemas. + + +**Returns:** `int` + +The sum of the number of tokens across the messages. + + + + + + + +```python +langchain_core.language_models.base.BaseLanguageModel.get_token_ids( + text: str +) -> list[int] +``` + + + + + + +Return the ordered IDs of the tokens in a text. + +**Parameters:** + + +The string input to tokenize. + + +**Returns:** `list[int]` + +A list of IDs corresponding to the tokens in the text, in order they occur +in the text. + + + + + + + +```python +langchain_core.language_models.base.BaseLanguageModel.set_verbose( + verbose: bool | None +) -> bool +``` + + + + + + +If verbose is `None`, set it. + +This allows users to pass in `None` as verbose to access the global setting. + +**Parameters:** + + +The verbosity setting to use. + + +**Returns:** `bool` + +The verbosity setting to use. + + + + + + + +```python +langchain_core.language_models.base.BaseLanguageModel.with_structured_output( + schema: dict | type, + kwargs: typing.Any = {} +) -> langchain_core.runnables.Runnable[langchain_core.language_models.base.LanguageModelInput, dict | pydantic.BaseModel] +``` + + + + + + +Not implemented on this class. + + + + + + + + + +```python +class langchain_core.language_models.base.LangSmithParams +``` + + + + + + +**Bases:** `typing.TypedDict` + +LangSmith parameters for tracing. + + +Max tokens for generation. + + + +Name of the model. + + + +Type of the model. + +Should be `'chat'` or `'llm'`. + + + +Provider of the model. + + + +Stop words for generation. + + + +Temperature for generation. + + + + + + + + +```python +langchain_core.language_models.base._get_token_ids_default_method( + text: str +) -> list[int] +``` + + + + + + +Encode the text into token IDs using the fallback GPT-2 tokenizer. + + + + + + + + +```python +langchain_core.language_models.base._get_verbosity() -> bool +``` + + + + + + + + + + + + + +```python +langchain_core.language_models.base.get_tokenizer() -> typing.Any +``` + + + + + + +Get a GPT-2 tokenizer instance. + +This function is cached to avoid re-loading the tokenizer every time it is called. + +**Returns:** `Any` + +The GPT-2 tokenizer instance. + +**Raises:** + +- `ImportError`: If the transformers package is not installed. + + + + + + + + +```python +langchain_core.language_models.base.LanguageModelInput = PromptValue | str | Sequence[MessageLikeRepresentation] +``` + + + + + + +Input to a language model. + + + + + + + +```python +langchain_core.language_models.base.LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput] +``` + + + + + + +Input/output interface for a language model. + + + + + + + +```python +langchain_core.language_models.base.LanguageModelOutput = BaseMessage | str +``` + + + + + + +Output from a language model. + + + + + + + +```python +langchain_core.language_models.base.LanguageModelOutputVar = TypeVar('LanguageModelOutputVar', AIMessage, str) +``` + + + + + + +Type variable for the output of a language model. + + + + + + + +```python +langchain_core.language_models.base._GPT2_TOKENIZER_WARNED = False +``` + + + + + + + + + +```python +langchain_core.language_models.base._HAS_TRANSFORMERS = True +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/chat_models.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/chat_models.mdx new file mode 100644 index 0000000..903310c --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/chat_models.mdx @@ -0,0 +1,1297 @@ +--- +layout: overview +slug: langchain-core/langchain_core/language_models/chat_models +title: langchain_core.language_models.chat_models +--- + +Chat models for conversational AI. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseChatModel`](#langchain_core-language_models-chat_models-BaseChatModel) | Base class for chat models. | +| [`SimpleChatModel`](#langchain_core-language_models-chat_models-SimpleChatModel) | Simplified implementation for a chat model to inherit from. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_cleanup_llm_representation`](#langchain_core-language_models-chat_models-_cleanup_llm_representation) | Remove non-serializable objects from a serialized object. | +| [`_format_for_tracing`](#langchain_core-language_models-chat_models-_format_for_tracing) | Format messages for tracing in `on_chat_model_start`. | +| [`_format_ls_structured_output`](#langchain_core-language_models-chat_models-_format_ls_structured_output) | - | +| [`_gen_info_and_msg_metadata`](#langchain_core-language_models-chat_models-_gen_info_and_msg_metadata) | - | +| [`_generate_response_from_error`](#langchain_core-language_models-chat_models-_generate_response_from_error) | - | +| [`agenerate_from_stream`](#langchain_core-language_models-chat_models-agenerate_from_stream) | Async generate from a stream. | +| [`generate_from_stream`](#langchain_core-language_models-chat_models-generate_from_stream) | Generate from a stream. | + +### Data + +[`_MAX_CLEANUP_DEPTH`](#langchain_core-language_models-chat_models-_MAX_CLEANUP_DEPTH) + +### API + + + + + +```python +class langchain_core.language_models.chat_models.BaseChatModel() +``` + + + + + + +Abstract + +**Bases:** [BaseLanguageModel[AIMessage]](/langchain-core/langchain_core/language_models/base#langchain_core-language_models-base-BaseLanguageModel) + +Base class for chat models. + + + +Get the output type for this `Runnable`. + + + +Return type of chat model. + + + + + + +Whether to disable streaming for this model. + +If streaming is bypassed, then `stream`/`astream`/`astream_events` will +defer to `invoke`/`ainvoke`. + +- If `True`, will always bypass streaming case. +- If `'tool_calling'`, will bypass streaming case only when the model is called + with a `tools` keyword argument. In other words, LangChain will automatically + switch to non-streaming behavior (`invoke`) only when the tools argument is + provided. This offers the best of both worlds. +- If `False` (Default), will always use streaming case if available. + +The main reason for this flag is that code might be written using `stream` and +a user may want to swap out a given model for another model whose implementation +does not properly support streaming. + + + + + + +Version of `AIMessage` output format to store in message content. + +`AIMessage.content_blocks` will lazily parse the contents of `content` into a +standard format. This flag can be used to additionally store the standard format +in message content, e.g., for serialization purposes. + +Supported values: + +- `'v0'`: provider-specific format in content (can lazily-parse with + `content_blocks`) +- `'v1'`: standardized format in content (consistent with `content_blocks`) + +Partner packages (e.g., +[`langchain-openai`](https://pypi.org/project/langchain-openai)) can also use this +field to roll out new content formats in a backward-compatible way. + +!!! version-added "Added in `langchain-core` 1.0.0" + + + +Profile detailing model capabilities. + +!!! warning "Beta feature" + + This is a beta feature. The format of model profiles is subject to change. + +If not specified, automatically loaded from the provider package on initialization +if data is available. + +Example profile data includes context window sizes, supported modalities, or support +for tool calling, structured output, and other features. + +!!! version-added "Added in `langchain-core` 1.1.0" + + + +An optional rate limiter to use for limiting the number of requests. + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel._agenerate( + messages: list[langchain_core.messages.BaseMessage], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.AsyncCallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.ChatResult +``` + + + + + + +async + +Generate the result. + +**Parameters:** + + +The messages to generate from. + + + +Optional list of stop words to use when generating. + + + +Optional callback manager to use for this call. + + + +Additional keyword arguments to pass to the model. + + +**Returns:** `ChatResult` + +The chat result. + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel._agenerate_with_cache( + messages: list[langchain_core.messages.BaseMessage], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.AsyncCallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.ChatResult +``` + + + + + + +async + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel._astream( + messages: list[langchain_core.messages.BaseMessage], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.AsyncCallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[langchain_core.outputs.ChatGenerationChunk] +``` + + + + + + +async + +Stream the output of the model. + +**Parameters:** + + +The messages to generate from. + + + +Optional list of stop words to use when generating. + + + +Optional callback manager to use for this call. + + + +Additional keyword arguments to pass to the model. + + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel._call_async( + messages: list[langchain_core.messages.BaseMessage], + stop: list[str] | None = None, + callbacks: langchain_core.callbacks.Callbacks = None, + kwargs: typing.Any = {} +) -> langchain_core.messages.BaseMessage +``` + + + + + + +async + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel._combine_llm_outputs( + _llm_outputs: list[langchain_core.language_models.chat_models.BaseChatModel.dict | None] +) -> langchain_core.language_models.chat_models.BaseChatModel.dict +``` + + + + + + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel._convert_cached_generations( + cache_val: list +) -> list[langchain_core.outputs.ChatGeneration] +``` + + + + + + +Convert cached Generation objects to ChatGeneration objects. + +Handle case where cache contains Generation objects instead of +ChatGeneration objects. This can happen due to serialization/deserialization +issues or legacy cache data (see #22389). + +**Parameters:** + + +List of cached generation objects. + + +**Returns:** `list[ChatGeneration]` + +List of ChatGeneration objects. + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel._convert_input( + model_input: langchain_core.language_models.base.LanguageModelInput +) -> langchain_core.prompt_values.PromptValue +``` + + + + + + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel._generate( + messages: list[langchain_core.messages.BaseMessage], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.CallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.ChatResult +``` + + + + + + +abstract + +Generate the result. + +**Parameters:** + + +The messages to generate from. + + + +Optional list of stop words to use when generating. + + + +Optional callback manager to use for this call. + + + +Additional keyword arguments to pass to the model. + + +**Returns:** `ChatResult` + +The chat result. + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel._generate_with_cache( + messages: list[langchain_core.messages.BaseMessage], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.CallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.ChatResult +``` + + + + + + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel._get_invocation_params( + stop: list[str] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.language_models.chat_models.BaseChatModel.dict +``` + + + + + + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel._get_llm_string( + stop: list[str] | None = None, + kwargs: typing.Any = {} +) -> str +``` + + + + + + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel._get_ls_params( + stop: list[str] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.language_models.base.LangSmithParams +``` + + + + + + +Get standard params for tracing. + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel._should_stream( + async_api: bool, + run_manager: langchain_core.callbacks.CallbackManagerForLLMRun | langchain_core.callbacks.AsyncCallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> bool +``` + + + + + + +Determine if a given model call should hit the streaming API. + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel._stream( + messages: list[langchain_core.messages.BaseMessage], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.CallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> collections.abc.Iterator[langchain_core.outputs.ChatGenerationChunk] +``` + + + + + + +Stream the output of the model. + +**Parameters:** + + +The messages to generate from. + + + +Optional list of stop words to use when generating. + + + +Optional callback manager to use for this call. + + + +Additional keyword arguments to pass to the model. + + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel.agenerate( + messages: list[list[langchain_core.messages.BaseMessage]], + stop: list[str] | None = None, + callbacks: langchain_core.callbacks.Callbacks = None, + tags: list[str] | None = None, + metadata: langchain_core.language_models.chat_models.BaseChatModel.dict[str, typing.Any] | None = None, + run_name: str | None = None, + run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.LLMResult +``` + + + + + + +async + +Asynchronously pass a sequence of prompts to a model and return generations. + +This method should make use of batched calls for models that expose a batched +API. + +Use this method when you want to: + +1. Take advantage of batched calls, +2. Need more output from the model than just the top generated value, +3. Are building chains that are agnostic to the underlying language model + type (e.g., pure text completion models vs chat models). + +**Parameters:** + + +List of list of messages. + + + +Stop words to use when generating. + +Model output is cut off at the first occurrence of any of these +substrings. + + + +`Callbacks` to pass through. + +Used for executing additional functionality, such as logging or +streaming, throughout generation. + + + +The tags to apply. + + + +The metadata to apply. + + + +The name of the run. + + + +The ID of the run. + + + +Arbitrary additional keyword arguments. + +These are usually passed to the model provider API call. + + +**Returns:** `LLMResult` + +An `LLMResult`, which contains a list of candidate `Generations` for each +input prompt and additional model provider-specific output. + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel.agenerate_prompt( + prompts: list[langchain_core.prompt_values.PromptValue], + stop: list[str] | None = None, + callbacks: langchain_core.callbacks.Callbacks = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.LLMResult +``` + + + + + + +async + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel.ainvoke( + input: langchain_core.language_models.base.LanguageModelInput, + config: langchain_core.runnables.RunnableConfig | None = None, + stop: list[str] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.messages.AIMessage +``` + + + + + + +async + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel.astream( + input: langchain_core.language_models.base.LanguageModelInput, + config: langchain_core.runnables.RunnableConfig | None = None, + stop: list[str] | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[langchain_core.messages.AIMessageChunk] +``` + + + + + + +async + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel.bind_tools( + tools: collections.abc.Sequence[builtins.dict[str, typing.Any] | type | collections.abc.Callable | langchain_core.tools.BaseTool], + tool_choice: str | None = None, + kwargs: typing.Any = {} +) -> langchain_core.runnables.Runnable[langchain_core.language_models.base.LanguageModelInput, langchain_core.messages.AIMessage] +``` + + + + + + +Bind tools to the model. + +**Parameters:** + + +Sequence of tools to bind to the model. + + + +The tool to use. If "any" then any tool can be used. + + +**Returns:** `Runnable[LanguageModelInput, AIMessage]` + +A Runnable that returns a message. + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel.dict( + kwargs: typing.Any = {} +) -> langchain_core.language_models.chat_models.BaseChatModel.dict +``` + + + + + + +Return a dictionary of the LLM. + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel.generate( + messages: list[list[langchain_core.messages.BaseMessage]], + stop: list[str] | None = None, + callbacks: langchain_core.callbacks.Callbacks = None, + tags: list[str] | None = None, + metadata: langchain_core.language_models.chat_models.BaseChatModel.dict[str, typing.Any] | None = None, + run_name: str | None = None, + run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.LLMResult +``` + + + + + + +Pass a sequence of prompts to the model and return model generations. + +This method should make use of batched calls for models that expose a batched +API. + +Use this method when you want to: + +1. Take advantage of batched calls, +2. Need more output from the model than just the top generated value, +3. Are building chains that are agnostic to the underlying language model + type (e.g., pure text completion models vs chat models). + +**Parameters:** + + +List of list of messages. + + + +Stop words to use when generating. + +Model output is cut off at the first occurrence of any of these +substrings. + + + +`Callbacks` to pass through. + +Used for executing additional functionality, such as logging or +streaming, throughout generation. + + + +The tags to apply. + + + +The metadata to apply. + + + +The name of the run. + + + +The ID of the run. + + + +Arbitrary additional keyword arguments. + +These are usually passed to the model provider API call. + + +**Returns:** `LLMResult` + +An `LLMResult`, which contains a list of candidate `Generations` for each +input prompt and additional model provider-specific output. + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel.generate_prompt( + prompts: list[langchain_core.prompt_values.PromptValue], + stop: list[str] | None = None, + callbacks: langchain_core.callbacks.Callbacks = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.LLMResult +``` + + + + + + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel.invoke( + input: langchain_core.language_models.base.LanguageModelInput, + config: langchain_core.runnables.RunnableConfig | None = None, + stop: list[str] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.messages.AIMessage +``` + + + + + + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel.stream( + input: langchain_core.language_models.base.LanguageModelInput, + config: langchain_core.runnables.RunnableConfig | None = None, + stop: list[str] | None = None, + kwargs: typing.Any = {} +) -> collections.abc.Iterator[langchain_core.messages.AIMessageChunk] +``` + + + + + + + + + + + + +```python +langchain_core.language_models.chat_models.BaseChatModel.with_structured_output( + schema: builtins.dict[str, typing.Any] | type, + include_raw: bool = False, + kwargs: typing.Any = {} +) -> langchain_core.runnables.Runnable[langchain_core.language_models.base.LanguageModelInput, builtins.dict[str, typing.Any] | pydantic.BaseModel] +``` + + + + + + +Model wrapper that returns outputs formatted to match the given schema. + +???+ example "Pydantic schema (`include_raw=False`)" + + ```python + from pydantic import BaseModel + + + class AnswerWithJustification(BaseModel): + '''An answer to the user question along with justification for the answer.''' + + answer: str + justification: str + + + model = ChatModel(model="model-name", temperature=0) + structured_model = model.with_structured_output(AnswerWithJustification) + + structured_model.invoke( + "What weighs more a pound of bricks or a pound of feathers" + ) + + # -> AnswerWithJustification( + # answer='They weigh the same', + # justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.' + # ) + ``` + +??? example "Pydantic schema (`include_raw=True`)" + + ```python + from pydantic import BaseModel + + + class AnswerWithJustification(BaseModel): + '''An answer to the user question along with justification for the answer.''' + + answer: str + justification: str + + + model = ChatModel(model="model-name", temperature=0) + structured_model = model.with_structured_output( + AnswerWithJustification, include_raw=True + ) + + structured_model.invoke( + "What weighs more a pound of bricks or a pound of feathers" + ) + # -> { + # 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}), + # 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'), + # 'parsing_error': None + # } + ``` + +??? example "Dictionary schema (`include_raw=False`)" + + ```python + from pydantic import BaseModel + from langchain_core.utils.function_calling import convert_to_openai_tool + + + class AnswerWithJustification(BaseModel): + '''An answer to the user question along with justification for the answer.''' + + answer: str + justification: str + + + dict_schema = convert_to_openai_tool(AnswerWithJustification) + model = ChatModel(model="model-name", temperature=0) + structured_model = model.with_structured_output(dict_schema) + + structured_model.invoke( + "What weighs more a pound of bricks or a pound of feathers" + ) + # -> { + # 'answer': 'They weigh the same', + # 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.' + # } + ``` + +!!! warning "Behavior changed in `langchain-core` 0.2.26" + + Added support for `TypedDict` class. + +**Parameters:** + + +The output schema. Can be passed in as: + +- An OpenAI function/tool schema, +- A JSON Schema, +- A `TypedDict` class, +- Or a Pydantic class. + +If `schema` is a Pydantic class then the model output will be a +Pydantic instance of that class, and the model-generated fields will be +validated by the Pydantic class. Otherwise the model output will be a +dict and will not be validated. + +See `langchain_core.utils.function_calling.convert_to_openai_tool` for +more on how to properly specify types and descriptions of schema fields +when specifying a Pydantic or `TypedDict` class. + + + + +If `False` then only the parsed structured output is returned. + +If an error occurs during model output parsing it will be raised. + +If `True` then both the raw model response (a `BaseMessage`) and the +parsed model response will be returned. + +If an error occurs during output parsing it will be caught and returned +as well. + +The final output is always a `dict` with keys `'raw'`, `'parsed'`, and +`'parsing_error'`. + + +**Returns:** `Runnable[LanguageModelInput, builtins.dict[str, Any] | BaseModel]` + +A `Runnable` that takes same inputs as a +`langchain_core.language_models.chat.BaseChatModel`. If `include_raw` is +`False` and `schema` is a Pydantic class, `Runnable` outputs an instance +of `schema` (i.e., a Pydantic object). Otherwise, if `include_raw` is +`False` then `Runnable` outputs a `dict`. + +If `include_raw` is `True`, then `Runnable` outputs a `dict` with keys: + +- `'raw'`: `BaseMessage` +- `'parsed'`: `None` if there was a parsing error, otherwise the type + depends on the `schema` as described above. +- `'parsing_error'`: `BaseException | None` + +**Raises:** + +- `ValueError`: If there are any unsupported `kwargs`. +- `NotImplementedError`: If the model does not implement +`with_structured_output()`. + + + + + + + + + +```python +class langchain_core.language_models.chat_models.SimpleChatModel() +``` + + + + + + +**Bases:** [BaseChatModel](#langchain_core-language_models-chat_models-BaseChatModel) + +Simplified implementation for a chat model to inherit from. + +!!! note + This implementation is primarily here for backwards compatibility. For new + implementations, please use `BaseChatModel` directly. + + + + + + +```python +langchain_core.language_models.chat_models.SimpleChatModel._agenerate( + messages: list[langchain_core.messages.BaseMessage], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.AsyncCallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.ChatResult +``` + + + + + + +async + + + + + + + +```python +langchain_core.language_models.chat_models.SimpleChatModel._call( + messages: list[langchain_core.messages.BaseMessage], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.CallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> str +``` + + + + + + +abstract + +Simpler interface. + + + + + + + +```python +langchain_core.language_models.chat_models.SimpleChatModel._generate( + messages: list[langchain_core.messages.BaseMessage], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.CallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.ChatResult +``` + + + + + + + + + + + + + + +```python +langchain_core.language_models.chat_models._cleanup_llm_representation( + serialized: typing.Any, + depth: int +) -> None +``` + + + + + + +Remove non-serializable objects from a serialized object. + + + + + + + + +```python +langchain_core.language_models.chat_models._format_for_tracing( + messages: list[langchain_core.messages.BaseMessage] +) -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +Format messages for tracing in `on_chat_model_start`. + +- Update image content blocks to OpenAI Chat Completions format (backward +compatibility). +- Add `type` key to content blocks that have a single key. + +**Parameters:** + + +List of messages to format. + + +**Returns:** `list[BaseMessage]` + +List of messages formatted for tracing. + + + + + + + + +```python +langchain_core.language_models.chat_models._format_ls_structured_output( + ls_structured_output_format: dict | None +) -> dict +``` + + + + + + + + + + + + + +```python +langchain_core.language_models.chat_models._gen_info_and_msg_metadata( + generation: langchain_core.outputs.ChatGeneration | langchain_core.outputs.ChatGenerationChunk +) -> dict +``` + + + + + + + + + + + + + +```python +langchain_core.language_models.chat_models._generate_response_from_error( + error: BaseException +) -> list[langchain_core.outputs.ChatGeneration] +``` + + + + + + + + + + + + + +```python +langchain_core.language_models.chat_models.agenerate_from_stream( + stream: collections.abc.AsyncIterator[langchain_core.outputs.ChatGenerationChunk] +) -> langchain_core.outputs.ChatResult +``` + + + + + + +async + +Async generate from a stream. + +**Parameters:** + + +AsyncIterator of `ChatGenerationChunk`. + + +**Returns:** `ChatResult` + +Chat result. + + + + + + + + +```python +langchain_core.language_models.chat_models.generate_from_stream( + stream: collections.abc.Iterator[langchain_core.outputs.ChatGenerationChunk] +) -> langchain_core.outputs.ChatResult +``` + + + + + + +Generate from a stream. + +**Parameters:** + + +Iterator of `ChatGenerationChunk`. + + +**Returns:** `ChatResult` + +Chat result. + +**Raises:** + +- `ValueError`: If no generations are found in the stream. + + + + + + + + +```python +langchain_core.language_models.chat_models._MAX_CLEANUP_DEPTH = 100 +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/fake.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/fake.mdx new file mode 100644 index 0000000..2ccf6dc --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/fake.mdx @@ -0,0 +1,199 @@ +--- +layout: overview +slug: langchain-core/langchain_core/language_models/fake +title: langchain_core.language_models.fake +--- + +Fake LLMs for testing purposes. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`FakeListLLM`](#langchain_core-language_models-fake-FakeListLLM) | Fake LLM for testing purposes. | +| [`FakeListLLMError`](#langchain_core-language_models-fake-FakeListLLMError) | Fake error for testing purposes. | +| [`FakeStreamingListLLM`](#langchain_core-language_models-fake-FakeStreamingListLLM) | Fake streaming list LLM for testing purposes. | + +### API + + + + + +```python +class langchain_core.language_models.fake.FakeListLLM() +``` + + + + + + +**Bases:** [LLM](/langchain-core/langchain_core/language_models/llms#langchain_core-language_models-llms-LLM) + +Fake LLM for testing purposes. + + + + + + +Return type of llm. + + + +Internally incremented after every model invocation. + +Useful primarily for testing purposes. + + + +List of responses to return in order. + + + +Sleep time in seconds between responses. + +Ignored by FakeListLLM, but used by sub-classes. + + + + + +```python +langchain_core.language_models.fake.FakeListLLM._acall( + prompt: str, + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.AsyncCallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> str +``` + + + + + + +async + +Return next response. + + + + + + + +```python +langchain_core.language_models.fake.FakeListLLM._call( + prompt: str, + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.CallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> str +``` + + + + + + +Return next response. + + + + + + + + + +```python +class langchain_core.language_models.fake.FakeListLLMError() +``` + + + + + + +Exception + +**Bases:** `Exception` + +Fake error for testing purposes. + + + + + + + + +```python +class langchain_core.language_models.fake.FakeStreamingListLLM() +``` + + + + + + +**Bases:** [FakeListLLM](#langchain_core-language_models-fake-FakeListLLM) + +Fake streaming list LLM for testing purposes. + +An LLM that will return responses from a list in order. + +This model also supports optionally sleeping between successive +chunks in a streaming implementation. + + + +If set, will raise an exception on the specified chunk number. + + + + + +```python +langchain_core.language_models.fake.FakeStreamingListLLM.astream( + input: langchain_core.language_models.LanguageModelInput, + config: langchain_core.runnables.RunnableConfig | None = None, + stop: list[str] | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[str] +``` + + + + + + +async + + + + + + + +```python +langchain_core.language_models.fake.FakeStreamingListLLM.stream( + input: langchain_core.language_models.LanguageModelInput, + config: langchain_core.runnables.RunnableConfig | None = None, + stop: list[str] | None = None, + kwargs: typing.Any = {} +) -> collections.abc.Iterator[str] +``` + + + + + + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/fake_chat_models.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/fake_chat_models.mdx new file mode 100644 index 0000000..8331358 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/fake_chat_models.mdx @@ -0,0 +1,437 @@ +--- +layout: overview +slug: langchain-core/langchain_core/language_models/fake_chat_models +title: langchain_core.language_models.fake_chat_models +--- + +Fake chat models for testing purposes. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`FakeChatModel`](#langchain_core-language_models-fake_chat_models-FakeChatModel) | Fake Chat Model wrapper for testing purposes. | +| [`FakeListChatModel`](#langchain_core-language_models-fake_chat_models-FakeListChatModel) | Fake chat model for testing purposes. | +| [`FakeListChatModelError`](#langchain_core-language_models-fake_chat_models-FakeListChatModelError) | Fake error for testing purposes. | +| [`FakeMessagesListChatModel`](#langchain_core-language_models-fake_chat_models-FakeMessagesListChatModel) | Fake chat model for testing purposes. | +| [`GenericFakeChatModel`](#langchain_core-language_models-fake_chat_models-GenericFakeChatModel) | Generic fake chat model that can be used to test the chat model interface. | +| [`ParrotFakeChatModel`](#langchain_core-language_models-fake_chat_models-ParrotFakeChatModel) | Generic fake chat model that can be used to test the chat model interface. | + +### API + + + + + +```python +class langchain_core.language_models.fake_chat_models.FakeChatModel() +``` + + + + + + +**Bases:** [SimpleChatModel](/langchain-core/langchain_core/language_models/chat_models#langchain_core-language_models-chat_models-SimpleChatModel) + +Fake Chat Model wrapper for testing purposes. + + + + + + + + + + + +```python +langchain_core.language_models.fake_chat_models.FakeChatModel._agenerate( + messages: list[langchain_core.messages.BaseMessage], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.AsyncCallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.ChatResult +``` + + + + + + +async + + + + + + + +```python +langchain_core.language_models.fake_chat_models.FakeChatModel._call( + messages: list[langchain_core.messages.BaseMessage], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.CallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> str +``` + + + + + + + + + + + + + + +```python +class langchain_core.language_models.fake_chat_models.FakeListChatModel() +``` + + + + + + +**Bases:** [SimpleChatModel](/langchain-core/langchain_core/language_models/chat_models#langchain_core-language_models-chat_models-SimpleChatModel) + +Fake chat model for testing purposes. + + + + + + + + + +If set, raise an error on the specified chunk number during streaming. + + + +Internally incremented after every model invocation. + + + +List of responses to **cycle** through in order. + + + + + + + + +```python +langchain_core.language_models.fake_chat_models.FakeListChatModel._astream( + messages: list[langchain_core.messages.BaseMessage], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.AsyncCallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[langchain_core.outputs.ChatGenerationChunk] +``` + + + + + + +async + + + + + + + +```python +langchain_core.language_models.fake_chat_models.FakeListChatModel._call( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> str +``` + + + + + + +Return the next response in the list. + +Cycle back to the start if at the end. + + + + + + + +```python +langchain_core.language_models.fake_chat_models.FakeListChatModel._stream( + messages: list[langchain_core.messages.BaseMessage], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.CallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> collections.abc.Iterator[langchain_core.outputs.ChatGenerationChunk] +``` + + + + + + + + + + + + +```python +langchain_core.language_models.fake_chat_models.FakeListChatModel.abatch( + inputs: list[typing.Any], + config: langchain_core.runnables.RunnableConfig | list[langchain_core.runnables.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any = {} +) -> list[langchain_core.messages.AIMessage] +``` + + + + + + +async + + + + + + + +```python +langchain_core.language_models.fake_chat_models.FakeListChatModel.batch( + inputs: list[typing.Any], + config: langchain_core.runnables.RunnableConfig | list[langchain_core.runnables.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any = {} +) -> list[langchain_core.messages.AIMessage] +``` + + + + + + + + + + + + + + +```python +class langchain_core.language_models.fake_chat_models.FakeListChatModelError() +``` + + + + + + +Exception + +**Bases:** `Exception` + +Fake error for testing purposes. + + + + + + + + +```python +class langchain_core.language_models.fake_chat_models.FakeMessagesListChatModel() +``` + + + + + + +**Bases:** [BaseChatModel](/langchain-core/langchain_core/language_models/chat_models#langchain_core-language_models-chat_models-BaseChatModel) + +Fake chat model for testing purposes. + + + + + + +Internally incremented after every model invocation. + + + +List of responses to **cycle** through in order. + + + +Sleep time in seconds between responses. + + + + + +```python +langchain_core.language_models.fake_chat_models.FakeMessagesListChatModel._generate( + messages: list[langchain_core.messages.BaseMessage], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.CallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.ChatResult +``` + + + + + + + + + + + + + + +```python +class langchain_core.language_models.fake_chat_models.GenericFakeChatModel() +``` + + + + + + +**Bases:** [BaseChatModel](/langchain-core/langchain_core/language_models/chat_models#langchain_core-language_models-chat_models-BaseChatModel) + +Generic fake chat model that can be used to test the chat model interface. + +* Chat model should be usable in both sync and async tests +* Invokes `on_llm_new_token` to allow for testing of callback related code for new + tokens. +* Includes logic to break messages into message chunk to facilitate testing of + streaming. + + + + + + +Get an iterator over messages. + +This can be expanded to accept other types like Callables / dicts / strings +to make the interface more generic if needed. + +!!! note + if you want to pass a list, you can use `iter` to convert it to an iterator. + +!!! warning + Streaming is not implemented yet. We should try to implement it in the future by + delegating to invoke and then breaking the resulting output into message chunks. + + + + + +```python +langchain_core.language_models.fake_chat_models.GenericFakeChatModel._generate( + messages: list[langchain_core.messages.BaseMessage], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.CallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.ChatResult +``` + + + + + + + + + + + + +```python +langchain_core.language_models.fake_chat_models.GenericFakeChatModel._stream( + messages: list[langchain_core.messages.BaseMessage], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.CallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> collections.abc.Iterator[langchain_core.outputs.ChatGenerationChunk] +``` + + + + + + + + + + + + + + +```python +class langchain_core.language_models.fake_chat_models.ParrotFakeChatModel() +``` + + + + + + +**Bases:** [BaseChatModel](/langchain-core/langchain_core/language_models/chat_models#langchain_core-language_models-chat_models-BaseChatModel) + +Generic fake chat model that can be used to test the chat model interface. + +* Chat model should be usable in both sync and async tests + + + + + + + + +```python +langchain_core.language_models.fake_chat_models.ParrotFakeChatModel._generate( + messages: list[langchain_core.messages.BaseMessage], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.CallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.ChatResult +``` + + + + + + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/llms.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/llms.mdx new file mode 100644 index 0000000..3793d5c --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/llms.mdx @@ -0,0 +1,1335 @@ +--- +layout: overview +slug: langchain-core/langchain_core/language_models/llms +title: langchain_core.language_models.llms +--- + +Base interface for traditional large language models (LLMs) to expose. + +These are traditionally older models (newer models generally are chat models). + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseLLM`](#langchain_core-language_models-llms-BaseLLM) | Base LLM abstract interface. | +| [`LLM`](#langchain_core-language_models-llms-LLM) | Simple interface for implementing a custom LLM. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_log_error_once`](#langchain_core-language_models-llms-_log_error_once) | Log an error once. | +| [`_resolve_cache`](#langchain_core-language_models-llms-_resolve_cache) | Resolve the cache. | +| [`aget_prompts`](#langchain_core-language_models-llms-aget_prompts) | Get prompts that are already cached. Async version. | +| [`aupdate_cache`](#langchain_core-language_models-llms-aupdate_cache) | Update the cache and get the LLM output. Async version. | +| [`create_base_retry_decorator`](#langchain_core-language_models-llms-create_base_retry_decorator) | Create a retry decorator for a given LLM and provided a list of error types. | +| [`get_prompts`](#langchain_core-language_models-llms-get_prompts) | Get prompts that are already cached. | +| [`update_cache`](#langchain_core-language_models-llms-update_cache) | Update the cache and get the LLM output. | + +### Data + +[`_background_tasks`](#langchain_core-language_models-llms-_background_tasks) + +[`logger`](#langchain_core-language_models-llms-logger) + +### API + + + + + +```python +class langchain_core.language_models.llms.BaseLLM() +``` + + + + + + +Abstract + +**Bases:** [BaseLanguageModel[str]](/langchain-core/langchain_core/language_models/base#langchain_core-language_models-base-BaseLanguageModel) + +Base LLM abstract interface. + +It should take in a prompt and return a string. + + + +Get the output type for this `Runnable`. + + + +Return type of llm. + + + + + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM.__str__() -> str +``` + + + + + + +Return a string representation of the object for printing. + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM._agenerate( + prompts: list[str], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.AsyncCallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.LLMResult +``` + + + + + + +async + +Run the LLM on the given prompts. + +**Parameters:** + + +The prompts to generate from. + + + +Stop words to use when generating. + +Model output is cut off at the first occurrence of any of these +substrings. + +If stop tokens are not supported consider raising `NotImplementedError`. + + + +Callback manager for the run. + + +**Returns:** `LLMResult` + +The LLM result. + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM._agenerate_helper( + prompts: list[str], + stop: list[str] | None, + run_managers: list[langchain_core.callbacks.AsyncCallbackManagerForLLMRun], + new_arg_supported: bool, + kwargs: typing.Any = {} +) -> langchain_core.outputs.LLMResult +``` + + + + + + +async + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM._astream( + prompt: str, + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.AsyncCallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[langchain_core.outputs.GenerationChunk] +``` + + + + + + +async + +An async version of the _stream method. + +The default implementation uses the synchronous _stream method and wraps it in +an async iterator. Subclasses that need to provide a true async implementation +should override this method. + +**Parameters:** + + +The prompt to generate from. + + + +Stop words to use when generating. + +Model output is cut off at the first occurrence of any of these +substrings. + + + +Callback manager for the run. + + + +Arbitrary additional keyword arguments. + +These are usually passed to the model provider API call. + + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM._call_async( + prompt: str, + stop: list[str] | None = None, + callbacks: langchain_core.callbacks.Callbacks = None, + tags: list[str] | None = None, + metadata: langchain_core.language_models.llms.BaseLLM.dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> str +``` + + + + + + +async + +Check Cache and run the LLM on the given prompt and input. + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM._convert_input( + model_input: langchain_core.language_models.base.LanguageModelInput +) -> langchain_core.prompt_values.PromptValue +``` + + + + + + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM._generate( + prompts: list[str], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.CallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.LLMResult +``` + + + + + + +abstract + +Run the LLM on the given prompts. + +**Parameters:** + + +The prompts to generate from. + + + +Stop words to use when generating. + +Model output is cut off at the first occurrence of any of these +substrings. + +If stop tokens are not supported consider raising `NotImplementedError`. + + + +Callback manager for the run. + + +**Returns:** `LLMResult` + +The LLM result. + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM._generate_helper( + prompts: list[str], + stop: list[str] | None, + run_managers: list[langchain_core.callbacks.CallbackManagerForLLMRun], + new_arg_supported: bool, + kwargs: typing.Any = {} +) -> langchain_core.outputs.LLMResult +``` + + + + + + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM._get_ls_params( + stop: list[str] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.language_models.base.LangSmithParams +``` + + + + + + +Get standard params for tracing. + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM._get_run_ids_list( + run_id: uuid.UUID | list[uuid.UUID | None] | None, + prompts: list +) -> list +``` + + + + + + +staticmethod + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM._stream( + prompt: str, + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.CallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> collections.abc.Iterator[langchain_core.outputs.GenerationChunk] +``` + + + + + + +Stream the LLM on the given prompt. + +This method should be overridden by subclasses that support streaming. + +If not implemented, the default behavior of calls to stream will be to +fallback to the non-streaming version of the model and return +the output as a single chunk. + +**Parameters:** + + +The prompt to generate from. + + + +Stop words to use when generating. + +Model output is cut off at the first occurrence of any of these +substrings. + + + +Callback manager for the run. + + + +Arbitrary additional keyword arguments. + +These are usually passed to the model provider API call. + + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM.abatch( + inputs: list[langchain_core.language_models.base.LanguageModelInput], + config: langchain_core.runnables.RunnableConfig | list[langchain_core.runnables.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any = {} +) -> list[str] +``` + + + + + + +async + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM.agenerate( + prompts: list[str], + stop: list[str] | None = None, + callbacks: langchain_core.callbacks.Callbacks | list[langchain_core.callbacks.Callbacks] | None = None, + tags: list[str] | list[list[str]] | None = None, + metadata: langchain_core.language_models.llms.BaseLLM.dict[str, typing.Any] | list[langchain_core.language_models.llms.BaseLLM.dict[str, typing.Any]] | None = None, + run_name: str | list[str] | None = None, + run_id: uuid.UUID | list[uuid.UUID | None] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.LLMResult +``` + + + + + + +async + +Asynchronously pass a sequence of prompts to a model and return generations. + +This method should make use of batched calls for models that expose a batched +API. + +Use this method when you want to: + +1. Take advantage of batched calls, +2. Need more output from the model than just the top generated value, +3. Are building chains that are agnostic to the underlying language model + type (e.g., pure text completion models vs chat models). + +**Parameters:** + + +List of string prompts. + + + +Stop words to use when generating. + +Model output is cut off at the first occurrence of any of these +substrings. + + + +`Callbacks` to pass through. + +Used for executing additional functionality, such as logging or +streaming, throughout generation. + + + +List of tags to associate with each prompt. If provided, the length +of the list must match the length of the prompts list. + + + +List of metadata dictionaries to associate with each prompt. If +provided, the length of the list must match the length of the prompts +list. + + + +List of run names to associate with each prompt. If provided, the +length of the list must match the length of the prompts list. + + + +List of run IDs to associate with each prompt. If provided, the +length of the list must match the length of the prompts list. + + + +Arbitrary additional keyword arguments. + +These are usually passed to the model provider API call. + + +**Returns:** `LLMResult` + +An `LLMResult`, which contains a list of candidate `Generations` for each +input prompt and additional model provider-specific output. + +**Raises:** + +- `ValueError`: If the length of `callbacks`, `tags`, `metadata`, or +`run_name` (if provided) does not match the length of prompts. + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM.agenerate_prompt( + prompts: list[langchain_core.prompt_values.PromptValue], + stop: list[str] | None = None, + callbacks: langchain_core.callbacks.Callbacks | list[langchain_core.callbacks.Callbacks] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.LLMResult +``` + + + + + + +async + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM.ainvoke( + input: langchain_core.language_models.base.LanguageModelInput, + config: langchain_core.runnables.RunnableConfig | None = None, + stop: list[str] | None = None, + kwargs: typing.Any = {} +) -> str +``` + + + + + + +async + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM.astream( + input: langchain_core.language_models.base.LanguageModelInput, + config: langchain_core.runnables.RunnableConfig | None = None, + stop: list[str] | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[str] +``` + + + + + + +async + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM.batch( + inputs: list[langchain_core.language_models.base.LanguageModelInput], + config: langchain_core.runnables.RunnableConfig | list[langchain_core.runnables.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any = {} +) -> list[str] +``` + + + + + + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM.dict( + kwargs: typing.Any = {} +) -> langchain_core.language_models.llms.BaseLLM.dict +``` + + + + + + +Return a dictionary of the LLM. + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM.generate( + prompts: list[str], + stop: list[str] | None = None, + callbacks: langchain_core.callbacks.Callbacks | list[langchain_core.callbacks.Callbacks] | None = None, + tags: list[str] | list[list[str]] | None = None, + metadata: langchain_core.language_models.llms.BaseLLM.dict[str, typing.Any] | list[langchain_core.language_models.llms.BaseLLM.dict[str, typing.Any]] | None = None, + run_name: str | list[str] | None = None, + run_id: uuid.UUID | list[uuid.UUID | None] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.LLMResult +``` + + + + + + +Pass a sequence of prompts to a model and return generations. + +This method should make use of batched calls for models that expose a batched +API. + +Use this method when you want to: + +1. Take advantage of batched calls, +2. Need more output from the model than just the top generated value, +3. Are building chains that are agnostic to the underlying language model + type (e.g., pure text completion models vs chat models). + +**Parameters:** + + +List of string prompts. + + + +Stop words to use when generating. + +Model output is cut off at the first occurrence of any of these +substrings. + + + +`Callbacks` to pass through. + +Used for executing additional functionality, such as logging or +streaming, throughout generation. + + + +List of tags to associate with each prompt. If provided, the length +of the list must match the length of the prompts list. + + + +List of metadata dictionaries to associate with each prompt. If +provided, the length of the list must match the length of the prompts +list. + + + +List of run names to associate with each prompt. If provided, the +length of the list must match the length of the prompts list. + + + +List of run IDs to associate with each prompt. If provided, the +length of the list must match the length of the prompts list. + + + +Arbitrary additional keyword arguments. + +These are usually passed to the model provider API call. + + +**Returns:** `LLMResult` + +An `LLMResult`, which contains a list of candidate `Generations` for each +input prompt and additional model provider-specific output. + +**Raises:** + +- `ValueError`: If prompts is not a list. +- `ValueError`: If the length of `callbacks`, `tags`, `metadata`, or +`run_name` (if provided) does not match the length of prompts. + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM.generate_prompt( + prompts: list[langchain_core.prompt_values.PromptValue], + stop: list[str] | None = None, + callbacks: langchain_core.callbacks.Callbacks | list[langchain_core.callbacks.Callbacks] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.LLMResult +``` + + + + + + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM.invoke( + input: langchain_core.language_models.base.LanguageModelInput, + config: langchain_core.runnables.RunnableConfig | None = None, + stop: list[str] | None = None, + kwargs: typing.Any = {} +) -> str +``` + + + + + + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM.save( + file_path: pathlib.Path | str +) -> None +``` + + + + + + +Save the LLM. + +**Parameters:** + + +Path to file to save the LLM to. + + +**Raises:** + +- `ValueError`: If the file path is not a string or Path object. + + + + + + + +```python +langchain_core.language_models.llms.BaseLLM.stream( + input: langchain_core.language_models.base.LanguageModelInput, + config: langchain_core.runnables.RunnableConfig | None = None, + stop: list[str] | None = None, + kwargs: typing.Any = {} +) -> collections.abc.Iterator[str] +``` + + + + + + + + + + + + + + +```python +class langchain_core.language_models.llms.LLM() +``` + + + + + + +**Bases:** [BaseLLM](#langchain_core-language_models-llms-BaseLLM) + +Simple interface for implementing a custom LLM. + +You should subclass this class and implement the following: + +- `_call` method: Run the LLM on the given prompt and input (used by `invoke`). +- `_identifying_params` property: Return a dictionary of the identifying parameters + This is critical for caching and tracing purposes. Identifying parameters + is a dict that identifies the LLM. + It should mostly include a `model_name`. + +Optional: Override the following methods to provide more optimizations: + +- `_acall`: Provide a native async version of the `_call` method. + If not provided, will delegate to the synchronous version using + `run_in_executor`. (Used by `ainvoke`). +- `_stream`: Stream the LLM on the given prompt and input. + `stream` will use `_stream` if provided, otherwise it + use `_call` and output will arrive in one chunk. +- `_astream`: Override to provide a native async version of the `_stream` method. + `astream` will use `_astream` if provided, otherwise it will implement + a fallback behavior that will use `_stream` if `_stream` is implemented, + and use `_acall` if `_stream` is not implemented. + + + + + + +```python +langchain_core.language_models.llms.LLM._acall( + prompt: str, + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.AsyncCallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> str +``` + + + + + + +async + +Async version of the _call method. + +The default implementation delegates to the synchronous _call method using +`run_in_executor`. Subclasses that need to provide a true async implementation +should override this method to reduce the overhead of using `run_in_executor`. + +**Parameters:** + + +The prompt to generate from. + + + +Stop words to use when generating. + +Model output is cut off at the first occurrence of any of these +substrings. + +If stop tokens are not supported consider raising `NotImplementedError`. + + + +Callback manager for the run. + + + +Arbitrary additional keyword arguments. + +These are usually passed to the model provider API call. + + +**Returns:** `str` + +The model output as a string. SHOULD NOT include the prompt. + + + + + + + +```python +langchain_core.language_models.llms.LLM._agenerate( + prompts: list[str], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.AsyncCallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.LLMResult +``` + + + + + + +async + + + + + + + +```python +langchain_core.language_models.llms.LLM._call( + prompt: str, + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.CallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> str +``` + + + + + + +abstract + +Run the LLM on the given input. + +Override this method to implement the LLM logic. + +**Parameters:** + + +The prompt to generate from. + + + +Stop words to use when generating. + +Model output is cut off at the first occurrence of any of these +substrings. + +If stop tokens are not supported consider raising `NotImplementedError`. + + + +Callback manager for the run. + + + +Arbitrary additional keyword arguments. + +These are usually passed to the model provider API call. + + +**Returns:** `str` + +The model output as a string. SHOULD NOT include the prompt. + + + + + + + +```python +langchain_core.language_models.llms.LLM._generate( + prompts: list[str], + stop: list[str] | None = None, + run_manager: langchain_core.callbacks.CallbackManagerForLLMRun | None = None, + kwargs: typing.Any = {} +) -> langchain_core.outputs.LLMResult +``` + + + + + + + + + + + + + + +```python +langchain_core.language_models.llms._log_error_once( + msg: str +) -> None +``` + + + + + + +Log an error once. + + + + + + + + +```python +langchain_core.language_models.llms._resolve_cache( + cache: langchain_core.caches.BaseCache | bool | None +) -> langchain_core.caches.BaseCache | None +``` + + + + + + +Resolve the cache. + + + + + + + + +```python +langchain_core.language_models.llms.aget_prompts( + params: dict[str, typing.Any], + prompts: list[str], + cache: langchain_core.caches.BaseCache | bool | None = None +) -> tuple[dict[int, list], str, list[int], list[str]] +``` + + + + + + +async + +Get prompts that are already cached. Async version. + +**Parameters:** + + +Dictionary of parameters. + + + +List of prompts. + + + +Cache object. + + +**Returns:** `tuple[dict[int, list], str, list[int], list[str]]` + +A tuple of existing prompts, llm_string, missing prompt indexes, +and missing prompts. + +**Raises:** + +- `ValueError`: If the cache is not set and cache is True. + + + + + + + + +```python +langchain_core.language_models.llms.aupdate_cache( + cache: langchain_core.caches.BaseCache | bool | None, + existing_prompts: dict[int, list], + llm_string: str, + missing_prompt_idxs: list[int], + new_results: langchain_core.outputs.LLMResult, + prompts: list[str] +) -> dict | None +``` + + + + + + +async + +Update the cache and get the LLM output. Async version. + +**Parameters:** + + +Cache object. + + + +Dictionary of existing prompts. + + + +LLM string. + + + +List of missing prompt indexes. + + + +LLMResult object. + + + +List of prompts. + + +**Returns:** `dict | None` + +LLM output. + +**Raises:** + +- `ValueError`: If the cache is not set and cache is True. + + + + + + + + +```python +langchain_core.language_models.llms.create_base_retry_decorator( + error_types: list[type[BaseException]], + max_retries: int = 1, + run_manager: langchain_core.callbacks.AsyncCallbackManagerForLLMRun | langchain_core.callbacks.CallbackManagerForLLMRun | None = None +) -> collections.abc.Callable[[Any], typing.Any] +``` + + + + + + +Create a retry decorator for a given LLM and provided a list of error types. + +**Parameters:** + + +List of error types to retry on. + + + +Number of retries. + + + +Callback manager for the run. + + +**Returns:** `Callable[[Any], Any]` + +A retry decorator. + +**Raises:** + +- `ValueError`: If the cache is not set and cache is True. + + + + + + + + +```python +langchain_core.language_models.llms.get_prompts( + params: dict[str, typing.Any], + prompts: list[str], + cache: langchain_core.caches.BaseCache | bool | None = None +) -> tuple[dict[int, list], str, list[int], list[str]] +``` + + + + + + +Get prompts that are already cached. + +**Parameters:** + + +Dictionary of parameters. + + + +List of prompts. + + + +Cache object. + + +**Returns:** `tuple[dict[int, list], str, list[int], list[str]]` + +A tuple of existing prompts, llm_string, missing prompt indexes, +and missing prompts. + +**Raises:** + +- `ValueError`: If the cache is not set and cache is True. + + + + + + + + +```python +langchain_core.language_models.llms.update_cache( + cache: langchain_core.caches.BaseCache | bool | None, + existing_prompts: dict[int, list], + llm_string: str, + missing_prompt_idxs: list[int], + new_results: langchain_core.outputs.LLMResult, + prompts: list[str] +) -> dict | None +``` + + + + + + +Update the cache and get the LLM output. + +**Parameters:** + + +Cache object. + + + +Dictionary of existing prompts. + + + +LLM string. + + + +List of missing prompt indexes. + + + +LLMResult object. + + + +List of prompts. + + +**Returns:** `dict | None` + +LLM output. + +**Raises:** + +- `ValueError`: If the cache is not set and cache is True. + + + + + + + + +```python +langchain_core.language_models.llms._background_tasks: set[Task] = set() +``` + + + + + + + + + +```python +langchain_core.language_models.llms.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/model_profile.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/model_profile.mdx new file mode 100644 index 0000000..70be5ed --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/language_models/model_profile.mdx @@ -0,0 +1,144 @@ +--- +layout: overview +slug: langchain-core/langchain_core/language_models/model_profile +title: langchain_core.language_models.model_profile +--- + +Model profile types and utilities. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ModelProfile`](#langchain_core-language_models-model_profile-ModelProfile) | Model profile. | + +### Data + +[`ModelProfileRegistry`](#langchain_core-language_models-model_profile-ModelProfileRegistry) + +### API + + + + + +```python +class langchain_core.language_models.model_profile.ModelProfile +``` + + + + + + +**Bases:** `typing.TypedDict` + +Model profile. + +!!! warning "Beta feature" + + This is a beta feature. The format of model profiles is subject to change. + +Provides information about chat model capabilities, such as context window sizes +and supported features. + + +Whether [audio inputs](https://docs.langchain.com/oss/python/langchain/models#multimodal) +are supported. + + + +Whether [audio outputs](https://docs.langchain.com/oss/python/langchain/models#multimodal) +are supported. + + + +Whether image inputs are supported. + + + +Whether [image outputs](https://docs.langchain.com/oss/python/langchain/models#multimodal) +are supported. + + + +Whether images can be included in tool messages. + + + +Whether [image URL inputs](https://docs.langchain.com/oss/python/langchain/models#multimodal) +are supported. + + + +Maximum context window (tokens) + + + +Maximum output tokens + + + +Whether [PDF inputs](https://docs.langchain.com/oss/python/langchain/models#multimodal) +are supported. + + + +Whether PDFs can be included in tool messages. + + + +Whether the model supports [reasoning / chain-of-thought](https://docs.langchain.com/oss/python/langchain/models#reasoning) + + + +Whether the model supports a native [structured output](https://docs.langchain.com/oss/python/langchain/models#structured-outputs) +feature + + + +Whether text inputs are supported. + + + +Whether text outputs are supported. + + + +Whether the model supports [tool calling](https://docs.langchain.com/oss/python/langchain/models#tool-calling) + + + +Whether the model supports [tool choice](https://docs.langchain.com/oss/python/langchain/models#forcing-tool-calls) + + + +Whether [video inputs](https://docs.langchain.com/oss/python/langchain/models#multimodal) +are supported. + + + +Whether [video outputs](https://docs.langchain.com/oss/python/langchain/models#multimodal) +are supported. + + + + + + + + +```python +langchain_core.language_models.model_profile.ModelProfileRegistry = dict[str, ModelProfile] +``` + + + + + + +Registry mapping model identifiers or names to their ModelProfile. + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load.mdx new file mode 100644 index 0000000..9b31041 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load.mdx @@ -0,0 +1,90 @@ +--- +layout: overview +slug: langchain-core/langchain_core/load +title: langchain_core.load +--- + +**Load** module helps with serialization and deserialization. + +## Submodules + +- **[`langchain_core.load._validation`](/langchain-core/langchain_core/load/_validation)** +- **[`langchain_core.load.dump`](/langchain-core/langchain_core/load/dump)** +- **[`langchain_core.load.load`](/langchain-core/langchain_core/load/load)** +- **[`langchain_core.load.mapping`](/langchain-core/langchain_core/load/mapping)** +- **[`langchain_core.load.serializable`](/langchain-core/langchain_core/load/serializable)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`__dir__`](#langchain_core-load-__dir__) | - | +| [`__getattr__`](#langchain_core-load-__getattr__) | - | + +### Data + +[`__all__`](#langchain_core-load-__all__) + +[`_dynamic_imports`](#langchain_core-load-_dynamic_imports) + +### API + + + + + +```python +langchain_core.load.__dir__() -> list[str] +``` + + + + + + + + + + + + + +```python +langchain_core.load.__getattr__( + attr_name: str +) -> object +``` + + + + + + + + + + + + + +```python +langchain_core.load.__all__ = ('InitValidator', 'Serializable', 'dumpd', 'dumps', 'load', 'loads') +``` + + + + + + + + + +```python +langchain_core.load._dynamic_imports = {'dumpd': 'dump', 'dumps': 'dump', 'InitValidator': 'load', 'loads': 'load', 'Se... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load/_validation.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load/_validation.mdx new file mode 100644 index 0000000..e0e5b71 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load/_validation.mdx @@ -0,0 +1,252 @@ +--- +layout: overview +slug: langchain-core/langchain_core/load/_validation +title: langchain_core.load._validation +--- + +Validation utilities for LangChain serialization. + +Provides escape-based protection against injection attacks in serialized objects. The +approach uses an allowlist design: only dicts explicitly produced by +`Serializable.to_json()` are treated as LC objects during deserialization. + +## How escaping works + +During serialization, plain dicts (user data) that contain an `'lc'` key are wrapped: + + + +```python +{"lc": 1, ...} # user data that looks like LC object +# becomes: +{"__lc_escaped__": {"lc": 1, ...}} +``` + + + +During deserialization, escaped dicts are unwrapped and returned as plain dicts, +NOT instantiated as LC objects. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_escape_dict`](#langchain_core-load-_validation-_escape_dict) | Wrap a dict in the escape marker. | +| [`_is_escaped_dict`](#langchain_core-load-_validation-_is_escaped_dict) | Check if a dict is an escaped user dict. | +| [`_is_lc_secret`](#langchain_core-load-_validation-_is_lc_secret) | Check if an object is a LangChain secret marker. | +| [`_needs_escaping`](#langchain_core-load-_validation-_needs_escaping) | Check if a dict needs escaping to prevent confusion with LC objects. | +| [`_serialize_lc_object`](#langchain_core-load-_validation-_serialize_lc_object) | Serialize a `Serializable` object with escaping of user data in kwargs. | +| [`_serialize_value`](#langchain_core-load-_validation-_serialize_value) | Serialize a value with escaping of user dicts. | +| [`_unescape_value`](#langchain_core-load-_validation-_unescape_value) | Unescape a value, processing escape markers in dict values and lists. | + +### Data + +[`_LC_ESCAPED_KEY`](#langchain_core-load-_validation-_LC_ESCAPED_KEY) + +### API + + + + + +```python +langchain_core.load._validation._escape_dict( + obj: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + +Wrap a dict in the escape marker. + + + + + + + + +```python +langchain_core.load._validation._is_escaped_dict( + obj: dict[str, typing.Any] +) -> bool +``` + + + + + + +Check if a dict is an escaped user dict. + + + + + + + + +```python +langchain_core.load._validation._is_lc_secret( + obj: typing.Any +) -> bool +``` + + + + + + +Check if an object is a LangChain secret marker. + + + + + + + + +```python +langchain_core.load._validation._needs_escaping( + obj: dict[str, typing.Any] +) -> bool +``` + + + + + + +Check if a dict needs escaping to prevent confusion with LC objects. + +A dict needs escaping if: + +1. It has an `'lc'` key (could be confused with LC serialization format) +2. It has only the escape key (would be mistaken for an escaped dict) + + + + + + + + +```python +langchain_core.load._validation._serialize_lc_object( + obj: typing.Any +) -> dict[str, typing.Any] +``` + + + + + + +Serialize a `Serializable` object with escaping of user data in kwargs. + +**Parameters:** + + +The `Serializable` object to serialize. + + +**Returns:** `dict[str, Any]` + +The serialized dict with user data in kwargs escaped as needed. + + + + + + + + +```python +langchain_core.load._validation._serialize_value( + obj: typing.Any +) -> typing.Any +``` + + + + + + +Serialize a value with escaping of user dicts. + +Called recursively on kwarg values to escape any plain dicts that could be confused +with LC objects. + +**Parameters:** + + +The value to serialize. + + +**Returns:** `Any` + +The serialized value with user dicts escaped as needed. + + + + + + + + +```python +langchain_core.load._validation._unescape_value( + obj: typing.Any +) -> typing.Any +``` + + + + + + +Unescape a value, processing escape markers in dict values and lists. + +When an escaped dict is encountered (`{"__lc_escaped__": ...}`), it's +unwrapped and the contents are returned AS-IS (no further processing). +The contents represent user data that should not be modified. + +For regular dicts and lists, we recurse to find any nested escape markers. + +**Parameters:** + + +The value to unescape. + + +**Returns:** `Any` + +The unescaped value. + + + + + + + + +```python +langchain_core.load._validation._LC_ESCAPED_KEY = '__lc_escaped__' +``` + + + + + + +Sentinel key used to mark escaped user dicts during serialization. + +When a plain dict contains 'lc' key (which could be confused with LC objects), +we wrap it as {"__lc_escaped__": {...original...}}. + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load/dump.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load/dump.mdx new file mode 100644 index 0000000..f660b38 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load/dump.mdx @@ -0,0 +1,177 @@ +--- +layout: overview +slug: langchain-core/langchain_core/load/dump +title: langchain_core.load.dump +--- + +Serialize LangChain objects to JSON. + +Provides `dumps` (to JSON string) and `dumpd` (to dict) for serializing +`Serializable` objects. + +## Escaping + +During serialization, plain dicts (user data) that contain an `'lc'` key are escaped +by wrapping them: `{"__lc_escaped__": {...original...}}`. This prevents injection +attacks where malicious data could trick the deserializer into instantiating +arbitrary classes. The escape marker is removed during deserialization. + +This is an allowlist approach: only dicts explicitly produced by +`Serializable.to_json()` are treated as LC objects; everything else is escaped if it +could be confused with the LC format. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_dump_pydantic_models`](#langchain_core-load-dump-_dump_pydantic_models) | Convert nested Pydantic models to dicts for JSON serialization. | +| [`default`](#langchain_core-load-dump-default) | Return a default value for an object. | +| [`dumpd`](#langchain_core-load-dump-dumpd) | Return a dict representation of an object. | +| [`dumps`](#langchain_core-load-dump-dumps) | Return a JSON string representation of an object. | + +### API + + + + + +```python +langchain_core.load.dump._dump_pydantic_models( + obj: typing.Any +) -> typing.Any +``` + + + + + + +Convert nested Pydantic models to dicts for JSON serialization. + +Handles the special case where a `ChatGeneration` contains an `AIMessage` +with a parsed Pydantic model in `additional_kwargs["parsed"]`. Since +Pydantic models aren't directly JSON serializable, this converts them to +dicts. + +**Parameters:** + + +The object to process. + + +**Returns:** `Any` + +A copy of the object with nested Pydantic models converted to dicts, or +the original object unchanged if no conversion was needed. + + + + + + + + +```python +langchain_core.load.dump.default( + obj: typing.Any +) -> typing.Any +``` + + + + + + +Return a default value for an object. + +**Parameters:** + + +The object to serialize to json if it is a Serializable object. + + +**Returns:** `Any` + +A JSON serializable object or a SerializedNotImplemented object. + + + + + + + + +```python +langchain_core.load.dump.dumpd( + obj: typing.Any +) -> typing.Any +``` + + + + + + +Return a dict representation of an object. + +**Parameters:** + + +The object to dump. + + +**Returns:** `Any` + +Dictionary that can be serialized to json using `json.dumps`. + + + + + + + + +```python +langchain_core.load.dump.dumps( + obj: typing.Any, + pretty: bool = False, + kwargs: typing.Any = {} +) -> str +``` + + + + + + +Return a JSON string representation of an object. + +**Parameters:** + + +The object to dump. + + + +Whether to pretty print the json. + +If `True`, the json will be indented by either 2 spaces or the amount +provided in the `indent` kwarg. + + + +Additional arguments to pass to `json.dumps` + + +**Returns:** `str` + +A JSON string representation of the object. + +**Raises:** + +- `ValueError`: If `default` is passed as a kwarg. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load/load.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load/load.mdx new file mode 100644 index 0000000..3d448d9 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load/load.mdx @@ -0,0 +1,692 @@ +--- +layout: overview +slug: langchain-core/langchain_core/load/load +title: langchain_core.load.load +--- + +Load LangChain objects from JSON strings or objects. + +## How it works + +Each `Serializable` LangChain object has a unique identifier (its "class path"), which +is a list of strings representing the module path and class name. For example: + +- `AIMessage` -> `["langchain_core", "messages", "ai", "AIMessage"]` +- `ChatPromptTemplate` -> `["langchain_core", "prompts", "chat", "ChatPromptTemplate"]` + +When deserializing, the class path from the JSON `'id'` field is checked against an +allowlist. If the class is not in the allowlist, deserialization raises a `ValueError`. + +## Security model + +!!! warning "Exercise caution with untrusted input" + + These functions deserialize by instantiating Python objects, which means + constructors (`__init__`) and validators may run and can trigger side effects. + With the default settings, deserialization is restricted to a core allowlist + of `langchain_core` types (for example: messages, documents, and prompts) + defined in `langchain_core.load.mapping`. + + If you broaden `allowed_objects` (for example, by using `'all'` or adding + additional classes), treat the serialized payload as a manifest and only + deserialize data that comes from a trusted source. A crafted payload that + is allowed to instantiate unintended classes could cause network calls, + file operations, or environment variable access during `__init__`. + +The `allowed_objects` parameter controls which classes can be deserialized: + +- **`'core'` (default)**: Allow classes defined in the serialization mappings for + langchain_core. +- **`'all'`**: Allow classes defined in the serialization mappings. This + includes core LangChain types (messages, prompts, documents, etc.) and trusted + partner integrations. See `langchain_core.load.mapping` for the full list. +- **Explicit list of classes**: Only those specific classes are allowed. + +For simple data types like messages and documents, the default allowlist is safe to use. +These classes do not perform side effects during initialization. + +!!! note "Side effects in allowed classes" + + Deserialization calls `__init__` on allowed classes. If those classes perform side + effects during initialization (network calls, file operations, etc.), those side + effects will occur. The allowlist prevents instantiation of classes outside the + allowlist, but does not sandbox the allowed classes themselves. + +Import paths are also validated against trusted namespaces before any module is +imported. + +### Best practices + +- Use the most restrictive `allowed_objects` possible. Prefer an explicit list + of classes over `'core'` or `'all'`. +- Keep `secrets_from_env` set to `False` (the default). If you must use it, + ensure the serialized data comes from a fully trusted source, as a crafted + payload can read arbitrary environment variables. +- When using `secrets_map`, include only the specific secrets that the + serialized object requires. + +### Injection protection (escape-based) + +During serialization, plain dicts that contain an `'lc'` key are escaped by wrapping +them: `{"__lc_escaped__": {...}}`. During deserialization, escaped dicts are unwrapped +and returned as plain dicts, NOT instantiated as LC objects. + +This is an allowlist approach: only dicts explicitly produced by +`Serializable.to_json()` (which are NOT escaped) are treated as LC objects; +everything else is user data. + +Even if an attacker's payload includes `__lc_escaped__` wrappers, it will be unwrapped +to plain dicts and NOT instantiated as malicious objects. + +## Examples + + + +```python +from langchain_core.load import load +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.messages import AIMessage, HumanMessage + +# Use default allowlist (classes from mappings) - recommended +obj = load(data) + +# Allow only specific classes (most restrictive) +obj = load( + data, + allowed_objects=[ + ChatPromptTemplate, + AIMessage, + HumanMessage, + ], +) +``` + + + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`Reviver`](#langchain_core-load-load-Reviver) | Reviver for JSON objects. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_block_jinja2_templates`](#langchain_core-load-load-_block_jinja2_templates) | Block jinja2 templates during deserialization for security. | +| [`_compute_allowed_class_paths`](#langchain_core-load-load-_compute_allowed_class_paths) | Return allowed class paths from an explicit list of classes. | +| [`_get_default_allowed_class_paths`](#langchain_core-load-load-_get_default_allowed_class_paths) | Get the default allowed class paths from the serialization mappings. | +| [`default_init_validator`](#langchain_core-load-load-default_init_validator) | Default init validator that blocks jinja2 templates. | +| [`load`](#langchain_core-load-load-load) | Revive a LangChain class from a JSON object. | +| [`loads`](#langchain_core-load-load-loads) | Revive a LangChain class from a JSON string. | + +### Data + +[`ALL_SERIALIZABLE_MAPPINGS`](#langchain_core-load-load-ALL_SERIALIZABLE_MAPPINGS) + +[`AllowedObject`](#langchain_core-load-load-AllowedObject) + +[`DEFAULT_NAMESPACES`](#langchain_core-load-load-DEFAULT_NAMESPACES) + +[`DISALLOW_LOAD_FROM_PATH`](#langchain_core-load-load-DISALLOW_LOAD_FROM_PATH) + +[`InitValidator`](#langchain_core-load-load-InitValidator) + +[`_default_class_paths_cache`](#langchain_core-load-load-_default_class_paths_cache) + +### API + + + + + +```python +class langchain_core.load.load.Reviver( + allowed_objects: collections.abc.Iterable[langchain_core.load.load.AllowedObject] | typing.Literal['all', 'core'] = 'core', + secrets_map: dict[str, str] | None = None, + valid_namespaces: list[str] | None = None, + secrets_from_env: bool = False, + additional_import_mappings: dict[tuple[str, ...], tuple[str, ...]] | None = None, + ignore_unserializable_fields: bool = False, + init_validator: langchain_core.load.load.InitValidator | None = default_init_validator +) +``` + + + + + + +Reviver for JSON objects. + +Used as the `object_hook` for `json.loads` to reconstruct LangChain objects from +their serialized JSON representation. + +Only classes in the allowlist can be instantiated. + + + + + + + + + + + + + + + + + + + + +```python +langchain_core.load.load.Reviver.__call__( + value: dict[str, typing.Any] +) -> typing.Any +``` + + + + + + +Revive the value. + +**Parameters:** + + +The value to revive. + + +**Returns:** `Any` + +The revived value. + +**Raises:** + +- `ValueError`: If the namespace is invalid. +- `ValueError`: If trying to deserialize something that cannot +be deserialized in the current version of langchain-core. +- `NotImplementedError`: If the object is not implemented and +`ignore_unserializable_fields` is False. + + + + + + + + + +```python +langchain_core.load.load._block_jinja2_templates( + class_path: tuple[str, ...], + kwargs: dict[str, typing.Any] +) -> None +``` + + + + + + +Block jinja2 templates during deserialization for security. + +Jinja2 templates can execute arbitrary code, so they are blocked by default when +deserializing objects with `template_format='jinja2'`. + +**Parameters:** + + +The class path tuple being deserialized (unused). + + + +The kwargs dict for the class constructor. + + +**Raises:** + +- `ValueError`: If `template_format` is `'jinja2'`. + + + + + + + + +```python +langchain_core.load.load._compute_allowed_class_paths( + allowed_objects: collections.abc.Iterable[langchain_core.load.load.AllowedObject], + import_mappings: dict[tuple[str, ...], tuple[str, ...]] +) -> set[tuple[str, ...]] +``` + + + + + + +Return allowed class paths from an explicit list of classes. + +A class path is a tuple of strings identifying a serializable class, derived from +`Serializable.lc_id()`. For example: `('langchain_core', 'messages', 'AIMessage')`. + +**Parameters:** + + +Iterable of `Serializable` subclasses to allow. + + + +Mapping of legacy class paths to current class paths. + + +**Returns:** `set[tuple[str, ...]]` + +Set of allowed class paths. + + + + + + + + +```python +langchain_core.load.load._get_default_allowed_class_paths( + allowed_object_mode: typing.Literal['all', 'core'] +) -> set[tuple[str, ...]] +``` + + + + + + +Get the default allowed class paths from the serialization mappings. + +This uses the mappings as the source of truth for what classes are allowed +by default. Both the legacy paths (keys) and current paths (values) are included. + +**Parameters:** + + +either `'all'` or `'core'`. + + +**Returns:** `set[tuple[str, ...]]` + +Set of class path tuples that are allowed by default. + + + + + + + + +```python +langchain_core.load.load.default_init_validator( + class_path: tuple[str, ...], + kwargs: dict[str, typing.Any] +) -> None +``` + + + + + + +Default init validator that blocks jinja2 templates. + +This is the default validator used by `load()` and `loads()` when no custom +validator is provided. + +**Parameters:** + + +The class path tuple being deserialized. + + + +The kwargs dict for the class constructor. + + +**Raises:** + +- `ValueError`: If template_format is `'jinja2'`. + + + + + + + + +```python +langchain_core.load.load.load( + obj: typing.Any, + allowed_objects: collections.abc.Iterable[langchain_core.load.load.AllowedObject] | typing.Literal['all', 'core'] = 'core', + secrets_map: dict[str, str] | None = None, + valid_namespaces: list[str] | None = None, + secrets_from_env: bool = False, + additional_import_mappings: dict[tuple[str, ...], tuple[str, ...]] | None = None, + ignore_unserializable_fields: bool = False, + init_validator: langchain_core.load.load.InitValidator | None = default_init_validator +) -> typing.Any +``` + + + + + + +Revive a LangChain class from a JSON object. + +Use this if you already have a parsed JSON object, eg. from `json.load` or +`orjson.loads`. + +Only classes in the allowlist can be instantiated. The default allowlist includes +core LangChain types (messages, prompts, documents, etc.). See +`langchain_core.load.mapping` for the full list. + +!!! warning "Do not use with untrusted input" + + This function instantiates Python objects and can trigger side effects + during deserialization. **Never call `load()` on data from an untrusted + or unauthenticated source.** See the module-level security model + documentation for details and best practices. + +**Parameters:** + + +The object to load. + + + +Allowlist of classes that can be deserialized. + +- `'core'` (default): Allow classes defined in the serialization mappings + for `langchain_core`. +- `'all'`: Allow classes defined in the serialization mappings. + + This includes core LangChain types (messages, prompts, documents, etc.) + and trusted partner integrations. See `langchain_core.load.mapping` for + the full list. + +- Explicit list of classes: Only those specific classes are allowed. +- `[]`: Disallow all deserialization (will raise on any object). + + + +A map of secrets to load. + +Only include the specific secrets the serialized object requires. + +If a secret is not found in the map, it will be loaded from the environment +if `secrets_from_env` is `True`. + + + +Additional namespaces (modules) to allow during +deserialization, beyond the default trusted namespaces. + + + +Whether to load secrets from the environment. + +A crafted payload can name arbitrary environment variables in its +`secret` fields, so enabling this on untrusted data can leak +sensitive values. Keep this `False` (the default) unless the +serialized data is fully trusted. + + + +A dictionary of additional namespace mappings. + +You can use this to override default mappings or add new mappings. + +When `allowed_objects` is `None` (using defaults), paths from these +mappings are also added to the allowed class paths. + + + +Whether to ignore unserializable fields. + + + +Optional callable to validate kwargs before instantiation. + +If provided, this function is called with `(class_path, kwargs)` where +`class_path` is the class path tuple and `kwargs` is the kwargs dict. +The validator should raise an exception if the object should not be +deserialized, otherwise return `None`. + +Defaults to `default_init_validator` which blocks jinja2 templates. + + +**Returns:** `Any` + +Revived LangChain objects. + +**Raises:** + +- `ValueError`: If an object's class path is not in the `allowed_objects` allowlist. + + + + + + + + +```python +langchain_core.load.load.loads( + text: str, + allowed_objects: collections.abc.Iterable[langchain_core.load.load.AllowedObject] | typing.Literal['all', 'core'] = 'core', + secrets_map: dict[str, str] | None = None, + valid_namespaces: list[str] | None = None, + secrets_from_env: bool = False, + additional_import_mappings: dict[tuple[str, ...], tuple[str, ...]] | None = None, + ignore_unserializable_fields: bool = False, + init_validator: langchain_core.load.load.InitValidator | None = default_init_validator +) -> typing.Any +``` + + + + + + +Revive a LangChain class from a JSON string. + +Equivalent to `load(json.loads(text))`. + +Only classes in the allowlist can be instantiated. The default allowlist includes +core LangChain types (messages, prompts, documents, etc.). See +`langchain_core.load.mapping` for the full list. + +!!! warning "Do not use with untrusted input" + + This function instantiates Python objects and can trigger side effects + during deserialization. **Never call `loads()` on data from an untrusted + or unauthenticated source.** See the module-level security model + documentation for details and best practices. + +**Parameters:** + + +The string to load. + + + +Allowlist of classes that can be deserialized. + +- `'core'` (default): Allow classes defined in the serialization mappings + for `langchain_core`. +- `'all'`: Allow classes defined in the serialization mappings. + + This includes core LangChain types (messages, prompts, documents, etc.) + and trusted partner integrations. See `langchain_core.load.mapping` for + the full list. + +- Explicit list of classes: Only those specific classes are allowed. +- `[]`: Disallow all deserialization (will raise on any object). + + + +A map of secrets to load. + +Only include the specific secrets the serialized object requires. If +a secret is not found in the map, it will be loaded from the +environment if `secrets_from_env` is `True`. + + + +Additional namespaces (modules) to allow during +deserialization, beyond the default trusted namespaces. + + + +Whether to load secrets from the environment. + +A crafted payload can name arbitrary environment variables in its +`secret` fields, so enabling this on untrusted data can leak +sensitive values. Keep this `False` (the default) unless the +serialized data is fully trusted. + + + +A dictionary of additional namespace mappings. + +You can use this to override default mappings or add new mappings. + +When `allowed_objects` is `None` (using defaults), paths from these +mappings are also added to the allowed class paths. + + + +Whether to ignore unserializable fields. + + + +Optional callable to validate kwargs before instantiation. + +If provided, this function is called with `(class_path, kwargs)` where +`class_path` is the class path tuple and `kwargs` is the kwargs dict. +The validator should raise an exception if the object should not be +deserialized, otherwise return `None`. + +Defaults to `default_init_validator` which blocks jinja2 templates. + + +**Returns:** `Any` + +Revived LangChain objects. + +**Raises:** + +- `ValueError`: If an object's class path is not in the `allowed_objects` allowlist. + + + + + + + + +```python +langchain_core.load.load.ALL_SERIALIZABLE_MAPPINGS = {None: SERIALIZABLE_MAPPING, None: OLD_CORE_NAMESPACES_MAPPING, None: _OG_SERIAL... +``` + + + + + + + + + +```python +langchain_core.load.load.AllowedObject = type[Serializable] +``` + + + + + + +Type alias for classes that can be included in the `allowed_objects` parameter. + +Must be a `Serializable` subclass (the class itself, not an instance). + + + + + + + +```python +langchain_core.load.load.DEFAULT_NAMESPACES = ['langchain', 'langchain_core', 'langchain_community', 'langchain_anthropic', 'l... +``` + + + + + + + + + +```python +langchain_core.load.load.DISALLOW_LOAD_FROM_PATH = ['langchain_community', 'langchain'] +``` + + + + + + + + + +```python +langchain_core.load.load.InitValidator = Callable[[tuple[str, ...], dict[str, Any]], None] +``` + + + + + + +Type alias for a callable that validates kwargs during deserialization. + +The callable receives: + +- `class_path`: A tuple of strings identifying the class being instantiated + (e.g., `('langchain', 'schema', 'messages', 'AIMessage')`). +- `kwargs`: The kwargs dict that will be passed to the constructor. + +The validator should raise an exception if the object should not be deserialized. + + + + + + + +```python +langchain_core.load.load._default_class_paths_cache: dict[str, set[tuple[str, ...]]] = {} +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load/mapping.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load/mapping.mdx new file mode 100644 index 0000000..8d839b0 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load/mapping.mdx @@ -0,0 +1,85 @@ +--- +layout: overview +slug: langchain-core/langchain_core/load/mapping +title: langchain_core.load.mapping +--- + +Serialization mapping. + +This file contains a mapping between the `lc_namespace` path for a given +subclass that implements from `Serializable` to the namespace +where that class is actually located. + +This mapping helps maintain the ability to serialize and deserialize +well-known LangChain objects even if they are moved around in the codebase +across different LangChain versions. + +For example, the code for the `AIMessage` class is located in +`langchain_core.messages.ai.AIMessage`. This message is associated with the +`lc_namespace` of `["langchain", "schema", "messages", "AIMessage"]`, +because this code was originally in `langchain.schema.messages.AIMessage`. + +The mapping allows us to deserialize an `AIMessage` created with an older +version of LangChain where the code was in a different location. + +## Module Contents + +### Data + +[`OLD_CORE_NAMESPACES_MAPPING`](#langchain_core-load-mapping-OLD_CORE_NAMESPACES_MAPPING) + +[`SERIALIZABLE_MAPPING`](#langchain_core-load-mapping-SERIALIZABLE_MAPPING) + +[`_JS_SERIALIZABLE_MAPPING`](#langchain_core-load-mapping-_JS_SERIALIZABLE_MAPPING) + +[`_OG_SERIALIZABLE_MAPPING`](#langchain_core-load-mapping-_OG_SERIALIZABLE_MAPPING) + +### API + + + + + +```python +langchain_core.load.mapping.OLD_CORE_NAMESPACES_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {('langchain_core', 'messages', 'ai', 'AIMessage'): ('langchain_core', 'messages... +``` + + + + + + + + + +```python +langchain_core.load.mapping.SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {('langchain', 'schema', 'messages', 'AIMessage'): ('langchain_core', 'messages'... +``` + + + + + + + + + +```python +langchain_core.load.mapping._JS_SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {('langchain_core', 'messages', 'AIMessage'): ('langchain_core', 'messages', 'ai... +``` + + + + + + + + + +```python +langchain_core.load.mapping._OG_SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {('langchain', 'schema', 'AIMessage'): ('langchain_core', 'messages', 'ai', 'AIM... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load/serializable.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load/serializable.mdx new file mode 100644 index 0000000..94cf1be --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/load/serializable.mdx @@ -0,0 +1,524 @@ +--- +layout: overview +slug: langchain-core/langchain_core/load/serializable +title: langchain_core.load.serializable +--- + +Serializable base class. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseSerialized`](#langchain_core-load-serializable-BaseSerialized) | Base class for serialized objects. | +| [`Serializable`](#langchain_core-load-serializable-Serializable) | Serializable base class. | +| [`SerializedConstructor`](#langchain_core-load-serializable-SerializedConstructor) | Serialized constructor. | +| [`SerializedNotImplemented`](#langchain_core-load-serializable-SerializedNotImplemented) | Serialized not implemented. | +| [`SerializedSecret`](#langchain_core-load-serializable-SerializedSecret) | Serialized secret. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_is_field_useful`](#langchain_core-load-serializable-_is_field_useful) | Check if a field is useful as a constructor argument. | +| [`_replace_secrets`](#langchain_core-load-serializable-_replace_secrets) | - | +| [`_try_neq_default`](#langchain_core-load-serializable-_try_neq_default) | - | +| [`to_json_not_implemented`](#langchain_core-load-serializable-to_json_not_implemented) | Serialize a "not implemented" object. | +| [`try_neq_default`](#langchain_core-load-serializable-try_neq_default) | Try to determine if a value is different from the default. | + +### Data + +[`logger`](#langchain_core-load-serializable-logger) + +### API + + + + + +```python +class langchain_core.load.serializable.BaseSerialized +``` + + + + + + +**Bases:** `typing.TypedDict` + +Base class for serialized objects. + + +The graph of the object. + + + +The unique identifier of the object. + + + +The version of the serialization format. + + + +The name of the object. + + + + + + + + +```python +class langchain_core.load.serializable.Serializable( + args: typing.Any = (), + kwargs: typing.Any = {} +) +``` + + + + + + +Abstract + +**Bases:** `BaseModel` + +Serializable base class. + +This class is used to serialize objects to JSON. + +It relies on the following methods and properties: + +- [`is_lc_serializable`][langchain_core.load.serializable.Serializable.is_lc_serializable]: Is this class serializable? + + By design, even if a class inherits from `Serializable`, it is not serializable + by default. This is to prevent accidental serialization of objects that should + not be serialized. +- [`get_lc_namespace`][langchain_core.load.serializable.Serializable.get_lc_namespace]: Get the namespace of the LangChain object. + + During deserialization, this namespace is used to identify + the correct class to instantiate. + + Please see the `Reviver` class in `langchain_core.load.load` for more details. + + During deserialization an additional mapping is handle classes that have moved + or been renamed across package versions. + +- [`lc_secrets`][langchain_core.load.serializable.Serializable.lc_secrets]: A map of constructor argument names to secret ids. +- [`lc_attributes`][langchain_core.load.serializable.Serializable.lc_attributes]: List of additional attribute names that should be included + as part of the serialized representation. + + + +List of attribute names that should be included in the serialized kwargs. + +These attributes must be accepted by the constructor. + +Default is an empty dictionary. + + + +A map of constructor argument names to secret ids. + +For example, `{"openai_api_key": "OPENAI_API_KEY"}` + + + + + + + + +```python +langchain_core.load.serializable.Serializable.__repr_args__() -> typing.Any +``` + + + + + + + + + + + + +```python +langchain_core.load.serializable.Serializable.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +The default implementation splits `cls.__module__` on `'.'`, e.g. +`langchain_openai.chat_models` becomes +`["langchain_openai", "chat_models"]`. This value is used by `lc_id` to +build the serialization identifier. + +New partner packages should **not** override this method. The default +behavior is correct for any class whose module path already reflects +its package name. Some older packages (e.g. `langchain-openai`, +`langchain-anthropic`) override it to return a legacy-style namespace +like `["langchain", "chat_models", "openai"]`, matching the module +paths that existed before those integrations were split out of the +main `langchain` package. Those overrides are kept for +backwards-compatible deserialization; new packages should not copy them. + +Deserialization mapping is handled separately by +`SERIALIZABLE_MAPPING` in `langchain_core.load.mapping`. + +**Returns:** `list[str]` + +The namespace. + + + + + + + +```python +langchain_core.load.serializable.Serializable.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Is this class serializable? + +By design, even if a class inherits from `Serializable`, it is not serializable +by default. This is to prevent accidental serialization of objects that should +not be serialized. + +**Returns:** `bool` + +Whether the class is serializable. Default is `False`. + + + + + + + +```python +langchain_core.load.serializable.Serializable.lc_id() -> list[str] +``` + + + + + + +classmethod + +Return a unique identifier for this class for serialization purposes. + +The unique identifier is a list of strings that describes the path +to the object. + +For example, for the class `langchain.llms.openai.OpenAI`, the id is +`["langchain", "llms", "openai", "OpenAI"]`. + + + + + + + +```python +langchain_core.load.serializable.Serializable.to_json() -> langchain_core.load.serializable.SerializedConstructor | langchain_core.load.serializable.SerializedNotImplemented +``` + + + + + + +Serialize the object to JSON. + +**Returns:** `SerializedConstructor | SerializedNotImplemented` + +A JSON serializable object or a `SerializedNotImplemented` object. + +**Raises:** + +- `ValueError`: If the class has deprecated attributes. + + + + + + + +```python +langchain_core.load.serializable.Serializable.to_json_not_implemented() -> langchain_core.load.serializable.SerializedNotImplemented +``` + + + + + + +Serialize a "not implemented" object. + +**Returns:** `SerializedNotImplemented` + +`SerializedNotImplemented`. + + + + + + + + + +```python +class langchain_core.load.serializable.SerializedConstructor() +``` + + + + + + +**Bases:** [BaseSerialized](#langchain_core-load-serializable-BaseSerialized) + +Serialized constructor. + + + +The constructor arguments. + + + +The type of the object. Must be `'constructor'`. + + + + + + + +```python +class langchain_core.load.serializable.SerializedNotImplemented() +``` + + + + + + +**Bases:** [BaseSerialized](#langchain_core-load-serializable-BaseSerialized) + +Serialized not implemented. + + + +The representation of the object. + + + +The type of the object. Must be `'not_implemented'`. + + + + + + + +```python +class langchain_core.load.serializable.SerializedSecret() +``` + + + + + + +**Bases:** [BaseSerialized](#langchain_core-load-serializable-BaseSerialized) + +Serialized secret. + + + +The type of the object. Must be `'secret'`. + + + + + + + +```python +langchain_core.load.serializable._is_field_useful( + inst: langchain_core.load.serializable.Serializable, + key: str, + value: typing.Any +) -> bool +``` + + + + + + +Check if a field is useful as a constructor argument. + +**Parameters:** + + +The instance. + + + +The key. + + + +The value. + + +**Returns:** `bool` + +Whether the field is useful. If the field is required, it is useful. + + + + + + + + +```python +langchain_core.load.serializable._replace_secrets( + root: dict[typing.Any, typing.Any], + secrets_map: dict[str, str] +) -> dict[typing.Any, typing.Any] +``` + + + + + + + + + + + + + +```python +langchain_core.load.serializable._try_neq_default( + value: typing.Any, + field: pydantic.fields.FieldInfo +) -> bool +``` + + + + + + + + + + + + + +```python +langchain_core.load.serializable.to_json_not_implemented( + obj: object +) -> langchain_core.load.serializable.SerializedNotImplemented +``` + + + + + + +Serialize a "not implemented" object. + +**Parameters:** + + +Object to serialize. + + +**Returns:** `SerializedNotImplemented` + +`SerializedNotImplemented` + + + + + + + + +```python +langchain_core.load.serializable.try_neq_default( + value: typing.Any, + key: str, + model: pydantic.BaseModel +) -> bool +``` + + + + + + +Try to determine if a value is different from the default. + +**Parameters:** + + +The value. + + + +The key. + + + +The Pydantic model. + + +**Returns:** `bool` + +Whether the value is different from the default. + + + + + + + + +```python +langchain_core.load.serializable.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages.mdx new file mode 100644 index 0000000..aee320c --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages.mdx @@ -0,0 +1,99 @@ +--- +layout: overview +slug: langchain-core/langchain_core/messages +title: langchain_core.messages +--- + +**Messages** are objects used in prompts and chat conversations. + +## Subpackages + +- **[`langchain_core.messages.block_translators`](/langchain-core/langchain_core/messages/block_translators)** + +## Submodules + +- **[`langchain_core.messages.ai`](/langchain-core/langchain_core/messages/ai)** +- **[`langchain_core.messages.base`](/langchain-core/langchain_core/messages/base)** +- **[`langchain_core.messages.chat`](/langchain-core/langchain_core/messages/chat)** +- **[`langchain_core.messages.content`](/langchain-core/langchain_core/messages/content)** +- **[`langchain_core.messages.function`](/langchain-core/langchain_core/messages/function)** +- **[`langchain_core.messages.human`](/langchain-core/langchain_core/messages/human)** +- **[`langchain_core.messages.modifier`](/langchain-core/langchain_core/messages/modifier)** +- **[`langchain_core.messages.system`](/langchain-core/langchain_core/messages/system)** +- **[`langchain_core.messages.tool`](/langchain-core/langchain_core/messages/tool)** +- **[`langchain_core.messages.utils`](/langchain-core/langchain_core/messages/utils)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`__dir__`](#langchain_core-messages-__dir__) | - | +| [`__getattr__`](#langchain_core-messages-__getattr__) | - | + +### Data + +[`__all__`](#langchain_core-messages-__all__) + +[`_dynamic_imports`](#langchain_core-messages-_dynamic_imports) + +### API + + + + + +```python +langchain_core.messages.__dir__() -> list[str] +``` + + + + + + + + + + + + + +```python +langchain_core.messages.__getattr__( + attr_name: str +) -> object +``` + + + + + + + + + + + + + +```python +langchain_core.messages.__all__ = ('LC_AUTO_PREFIX', 'LC_ID_PREFIX', 'AIMessage', 'AIMessageChunk', 'Annotation', ... +``` + + + + + + + + + +```python +langchain_core.messages._dynamic_imports = {'AIMessage': 'ai', 'AIMessageChunk': 'ai', 'Annotation': 'content', 'AudioConte... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/ai.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/ai.mdx new file mode 100644 index 0000000..d11d539 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/ai.mdx @@ -0,0 +1,486 @@ +--- +layout: overview +slug: langchain-core/langchain_core/messages/ai +title: langchain_core.messages.ai +--- + +AI message. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AIMessage`](#langchain_core-messages-ai-AIMessage) | Message from an AI. | +| [`AIMessageChunk`](#langchain_core-messages-ai-AIMessageChunk) | Message chunk from an AI (yielded when streaming). | +| [`InputTokenDetails`](#langchain_core-messages-ai-InputTokenDetails) | Breakdown of input token counts. | +| [`OutputTokenDetails`](#langchain_core-messages-ai-OutputTokenDetails) | Breakdown of output token counts. | +| [`UsageMetadata`](#langchain_core-messages-ai-UsageMetadata) | Usage metadata for a message, such as token counts. | + +### Functions + +| Name | Description | +|------|-------------| +| [`add_ai_message_chunks`](#langchain_core-messages-ai-add_ai_message_chunks) | Add multiple `AIMessageChunk`s together. | +| [`add_usage`](#langchain_core-messages-ai-add_usage) | Recursively add two UsageMetadata objects. | +| [`subtract_usage`](#langchain_core-messages-ai-subtract_usage) | Recursively subtract two `UsageMetadata` objects. | + +### Data + +[`logger`](#langchain_core-messages-ai-logger) + +### API + + + + + +```python +class langchain_core.messages.ai.AIMessage( + content: str | list[str | dict] | None = None, + content_blocks: list[langchain_core.messages.content.ContentBlock] | None = None, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [BaseMessage](/langchain-core/langchain_core/messages/base#langchain_core-messages-base-BaseMessage) + +Message from an AI. + +An `AIMessage` is returned from a chat model as a response to a prompt. + +This message represents the output of the model and consists of both +the raw output as returned by the model and standardized fields +(e.g., tool calls, usage metadata) added by the LangChain framework. + + + +Return standard, typed `ContentBlock` dicts from the message. + +If the message has a known model provider, use the provider-specific translator +first before falling back to best-effort parsing. For details, see the property +on `BaseMessage`. + + + +If present, tool calls with parsing errors associated with the message. + + + +Attributes to be serialized. + +Includes all attributes, even if they are derived from other initialization +arguments. + + + +If present, tool calls associated with the message. + + + +The type of the message (used for deserialization). + + + +If present, usage metadata for a message, such as token counts. + +This is a standard representation of token usage that is consistent across models. + + + + + +```python +langchain_core.messages.ai.AIMessage._backwards_compat_tool_calls( + values: dict +) -> typing.Any +``` + + + + + + +classmethod + + + + + + + +```python +langchain_core.messages.ai.AIMessage.pretty_repr( + html: bool = False +) -> str +``` + + + + + + +Return a pretty representation of the message for display. + +**Parameters:** + + +Whether to return an HTML-formatted string. + + +**Returns:** `str` + +A pretty representation of the message. + + + + + + + + + +```python +class langchain_core.messages.ai.AIMessageChunk() +``` + + + + + + +**Bases:** [AIMessage](#langchain_core-messages-ai-AIMessage), [BaseMessageChunk](/langchain-core/langchain_core/messages/base#langchain_core-messages-base-BaseMessageChunk) + +Message chunk from an AI (yielded when streaming). + + + +Optional span represented by an aggregated `AIMessageChunk`. + +If a chunk with `chunk_position="last"` is aggregated into a stream, +`tool_call_chunks` in message content will be parsed into `tool_calls`. + + + +Return standard, typed `ContentBlock` dicts from the message. + + + + + + +If provided, tool call chunks associated with the message. + + + +The type of the message (used for deserialization). + + + + + +```python +langchain_core.messages.ai.AIMessageChunk.__add__( + other: typing.Any +) -> langchain_core.messages.base.BaseMessageChunk +``` + + + + + + + + + + + + +```python +langchain_core.messages.ai.AIMessageChunk.init_server_tool_calls() -> typing_extensions.Self +``` + + + + + + +Initialize server tool calls. + +Parse `server_tool_call_chunks` from +[`ServerToolCallChunk`][langchain.messages.ServerToolCallChunk] objects. + + + + + + + +```python +langchain_core.messages.ai.AIMessageChunk.init_tool_calls() -> typing_extensions.Self +``` + + + + + + +Initialize tool calls from tool call chunks. + +**Returns:** `Self` + +The values with tool calls initialized. + +**Raises:** + +- `ValueError`: If the tool call chunks are malformed. + + + + + + + + + +```python +class langchain_core.messages.ai.InputTokenDetails +``` + + + + + + +**Bases:** `typing.TypedDict` + +Breakdown of input token counts. + +Does *not* need to sum to full input token count. Does *not* need to have all keys. + +May also hold extra provider-specific keys. + +!!! version-added "Added in `langchain-core` 0.3.9" + + +Audio input tokens. + + + +Input tokens that were cached and there was a cache miss. + +Since there was a cache miss, the cache was created from these tokens. + + + +Input tokens that were cached and there was a cache hit. + +Since there was a cache hit, the tokens were read from the cache. More precisely, +the model state given these tokens was read from the cache. + + + + + + + + +```python +class langchain_core.messages.ai.OutputTokenDetails +``` + + + + + + +**Bases:** `typing.TypedDict` + +Breakdown of output token counts. + +Does *not* need to sum to full output token count. Does *not* need to have all keys. + +May also hold extra provider-specific keys. + +!!! version-added "Added in `langchain-core` 0.3.9" + + +Audio output tokens. + + + +Reasoning output tokens. + +Tokens generated by the model in a chain of thought process (i.e. by OpenAI's o1 +models) that are not returned as part of model output. + + + + + + + + +```python +class langchain_core.messages.ai.UsageMetadata +``` + + + + + + +**Bases:** `typing.TypedDict` + +Usage metadata for a message, such as token counts. + +This is a standard representation of token usage that is consistent across models. + +!!! warning "Behavior changed in `langchain-core` 0.3.9" + + Added `input_token_details` and `output_token_details`. + +!!! note "LangSmith SDK" + + The LangSmith SDK also has a `UsageMetadata` class. While the two share fields, + LangSmith's `UsageMetadata` has additional fields to capture cost information + used by the LangSmith platform. + + +Breakdown of input token counts. + +Does *not* need to sum to full input token count. Does *not* need to have all keys. + + + +Count of input (or prompt) tokens. Sum of all input token types. + + + +Breakdown of output token counts. + +Does *not* need to sum to full output token count. Does *not* need to have all keys. + + + +Count of output (or completion) tokens. Sum of all output token types. + + + +Total token count. Sum of `input_tokens` + `output_tokens`. + + + + + + + + +```python +langchain_core.messages.ai.add_ai_message_chunks( + left: langchain_core.messages.ai.AIMessageChunk, + others: langchain_core.messages.ai.AIMessageChunk = () +) -> langchain_core.messages.ai.AIMessageChunk +``` + + + + + + +Add multiple `AIMessageChunk`s together. + +**Parameters:** + + +The first `AIMessageChunk`. + + + +Other `AIMessageChunk`s to add. + + +**Returns:** `AIMessageChunk` + +The resulting `AIMessageChunk`. + + + + + + + + +```python +langchain_core.messages.ai.add_usage( + left: langchain_core.messages.ai.UsageMetadata | None, + right: langchain_core.messages.ai.UsageMetadata | None +) -> langchain_core.messages.ai.UsageMetadata +``` + + + + + + +Recursively add two UsageMetadata objects. + +Args: + left: The first `UsageMetadata` object. + right: The second `UsageMetadata` object. + +**Returns:** `UsageMetadata` + +The sum of the two `UsageMetadata` objects. + + + + + + + + +```python +langchain_core.messages.ai.subtract_usage( + left: langchain_core.messages.ai.UsageMetadata | None, + right: langchain_core.messages.ai.UsageMetadata | None +) -> langchain_core.messages.ai.UsageMetadata +``` + + + + + + +Recursively subtract two `UsageMetadata` objects. + +Token counts cannot be negative so the actual operation is `max(left - right, 0)`. + +Args: + left: The first `UsageMetadata` object. + right: The second `UsageMetadata` object. + +**Returns:** `UsageMetadata` + +The resulting `UsageMetadata` after subtraction. + + + + + + + + +```python +langchain_core.messages.ai.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/base.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/base.mdx new file mode 100644 index 0000000..35988bf --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/base.mdx @@ -0,0 +1,537 @@ +--- +layout: overview +slug: langchain-core/langchain_core/messages/base +title: langchain_core.messages.base +--- + +Base message. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseMessage`](#langchain_core-messages-base-BaseMessage) | Base abstract message class. | +| [`BaseMessageChunk`](#langchain_core-messages-base-BaseMessageChunk) | Message chunk, which can be concatenated with other Message chunks. | +| [`TextAccessor`](#langchain_core-messages-base-TextAccessor) | String-like object that supports both property and method access patterns. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_extract_reasoning_from_additional_kwargs`](#langchain_core-messages-base-_extract_reasoning_from_additional_kwargs) | Extract `reasoning_content` from `additional_kwargs`. | +| [`get_msg_title_repr`](#langchain_core-messages-base-get_msg_title_repr) | Get a title representation for a message. | +| [`merge_content`](#langchain_core-messages-base-merge_content) | Merge multiple message contents. | +| [`message_to_dict`](#langchain_core-messages-base-message_to_dict) | Convert a Message to a dictionary. | +| [`messages_to_dict`](#langchain_core-messages-base-messages_to_dict) | Convert a sequence of Messages to a list of dictionaries. | + +### API + + + + + +```python +class langchain_core.messages.base.BaseMessage( + content: str | list[str | dict] | None = None, + content_blocks: list[langchain_core.messages.content.ContentBlock] | None = None, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [Serializable](/langchain-core/langchain_core/load/serializable#langchain_core-load-serializable-Serializable) + +Base abstract message class. + +Messages are the inputs and outputs of a chat model. + +Examples include [`HumanMessage`][langchain.messages.HumanMessage], +[`AIMessage`][langchain.messages.AIMessage], and +[`SystemMessage`][langchain.messages.SystemMessage]. + + + +Reserved for additional payload data associated with the message. + +For example, for a message from an AI, this could include tool calls as +encoded by the model provider. + + + +The contents of the message. + + + +Load content blocks from the message content. + +!!! version-added "Added in `langchain-core` 1.0.0" + + + +An optional unique identifier for the message. + +This should ideally be provided by the provider/model which created the message. + + + + + + +An optional name for the message. + +This can be used to provide a human-readable name for the message. + +Usage of this field is optional, and whether it's used or not is up to the +model implementation. + + + +Examples: response headers, logprobs, token counts, model name. + + + +Get the text content of the message as a string. + +Can be used as both property (`message.text`) and method (`message.text()`). + +Handles both string and list content types (e.g. for content blocks). Only +extracts blocks with `type: 'text'`; other block types are ignored. + +!!! deprecated + As of `langchain-core` 1.0.0, calling `.text()` as a method is deprecated. + Use `.text` as a property instead. This method will be removed in 2.0.0. + + + +The type of the message. Must be a string that is unique to the message type. + +The purpose of this field is to allow for easy identification of the message type +when deserializing messages. + + + + + +```python +langchain_core.messages.base.BaseMessage.__add__( + other: typing.Any +) -> langchain_core.prompts.chat.ChatPromptTemplate +``` + + + + + + +Concatenate this message with another message. + +**Parameters:** + + +Another message to concatenate with this one. + + +**Returns:** `ChatPromptTemplate` + +A ChatPromptTemplate containing both messages. + + + + + + + +```python +langchain_core.messages.base.BaseMessage.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "schema", "messages"]` + + + + + + + +```python +langchain_core.messages.base.BaseMessage.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +`BaseMessage` is serializable. + +**Returns:** `bool` + +True + + + + + + + +```python +langchain_core.messages.base.BaseMessage.pretty_print() -> None +``` + + + + + + +Print a pretty representation of the message. + + + + + + + +```python +langchain_core.messages.base.BaseMessage.pretty_repr( + html: bool = False +) -> str +``` + + + + + + +Get a pretty representation of the message. + +**Parameters:** + + +Whether to format the message as HTML. If `True`, the message will be +formatted with HTML tags. + + +**Returns:** `str` + +A pretty representation of the message. + + + + + + + + + +```python +class langchain_core.messages.base.BaseMessageChunk() +``` + + + + + + +**Bases:** [BaseMessage](#langchain_core-messages-base-BaseMessage) + +Message chunk, which can be concatenated with other Message chunks. + + + + + + +```python +langchain_core.messages.base.BaseMessageChunk.__add__( + other: typing.Any +) -> langchain_core.messages.base.BaseMessageChunk +``` + + + + + + +Message chunks support concatenation with other message chunks. + +This functionality is useful to combine message chunks yielded from +a streaming model into a complete message. + +**Parameters:** + + +Another message chunk to concatenate with this one. + + +**Returns:** `BaseMessageChunk` + +A new message chunk that is the concatenation of this message chunk + +**Raises:** + +- `TypeError`: If the other object is not a message chunk. + + + + + + + + + +```python +class langchain_core.messages.base.TextAccessor() +``` + + + + + + +**Bases:** `str` + +String-like object that supports both property and method access patterns. + +Exists to maintain backward compatibility while transitioning from method-based to +property-based text access in message objects. In LangChain <v1.0, message text was +accessed via `.text()` method calls. In v1.0=<, the preferred pattern is property +access via `.text`. + +Rather than breaking existing code immediately, `TextAccessor` allows both +patterns: +- Modern property access: `message.text` (returns string directly) +- Legacy method access: `message.text()` (callable, emits deprecation warning) + + + + + + + + +```python +langchain_core.messages.base.TextAccessor.__call__() -> str +``` + + + + + + +Enable method-style text access for backward compatibility. + +This method exists solely to support legacy code that calls `.text()` +as a method. New code should use property access (`.text`) instead. + +!!! deprecated + As of `langchain-core` 1.0.0, calling `.text()` as a method is deprecated. + Use `.text` as a property instead. This method will be removed in 2.0.0. + +**Returns:** `str` + +The string content, identical to property access. + + + + + + + +```python +langchain_core.messages.base.TextAccessor.__new__( + value: str +) -> typing_extensions.Self +``` + + + + + + +Create new TextAccessor instance. + + + + + + + + + +```python +langchain_core.messages.base._extract_reasoning_from_additional_kwargs( + message: langchain_core.messages.base.BaseMessage +) -> langchain_core.messages.content.ReasoningContentBlock | None +``` + + + + + + +Extract `reasoning_content` from `additional_kwargs`. + +Handles reasoning content stored in various formats: +- `additional_kwargs["reasoning_content"]` (string) - Ollama, DeepSeek, XAI, Groq + +**Parameters:** + + +The message to extract reasoning from. + + +**Returns:** `types.ReasoningContentBlock | None` + +A `ReasoningContentBlock` if reasoning content is found, None otherwise. + + + + + + + + +```python +langchain_core.messages.base.get_msg_title_repr( + title: str, + bold: bool = False +) -> str +``` + + + + + + +Get a title representation for a message. + +**Parameters:** + + +The title. + + + +Whether to bold the title. + + +**Returns:** `str` + +The title representation. + + + + + + + + +```python +langchain_core.messages.base.merge_content( + first_content: str | list[str | dict], + contents: str | list[str | dict] = () +) -> str | list[str | dict] +``` + + + + + + +Merge multiple message contents. + +**Parameters:** + + +The first `content`. Can be a string or a list. + + + +The other `content`s. Can be a string or a list. + + +**Returns:** `str | list[str | dict]` + +The merged content. + + + + + + + + +```python +langchain_core.messages.base.message_to_dict( + message: langchain_core.messages.base.BaseMessage +) -> dict +``` + + + + + + +Convert a Message to a dictionary. + +**Parameters:** + + +Message to convert. + + +**Returns:** `dict` + +Message as a dict. The dict will have a `type` key with the message type + + + + + + + + +```python +langchain_core.messages.base.messages_to_dict( + messages: collections.abc.Sequence[langchain_core.messages.base.BaseMessage] +) -> list[dict] +``` + + + + + + +Convert a sequence of Messages to a list of dictionaries. + +**Parameters:** + + +Sequence of messages (as `BaseMessage`s) to convert. + + +**Returns:** `list[dict]` + +List of messages as dicts. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators.mdx new file mode 100644 index 0000000..d0be334 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators.mdx @@ -0,0 +1,159 @@ +--- +layout: overview +slug: langchain-core/langchain_core/messages/block_translators +title: langchain_core.messages.block_translators +--- + +Derivations of standard content blocks from provider content. + +`AIMessage` will first attempt to use a provider-specific translator if +`model_provider` is set in `response_metadata` on the message. Consequently, each +provider translator must handle all possible content response types from the provider, +including text. + +If no provider is set, or if the provider does not have a registered translator, +`AIMessage` will fall back to best-effort parsing of the content into blocks using +the implementation in `BaseMessage`. + +## Submodules + +- **[`langchain_core.messages.block_translators.anthropic`](/langchain-core/langchain_core/messages/block_translators/anthropic)** +- **[`langchain_core.messages.block_translators.bedrock`](/langchain-core/langchain_core/messages/block_translators/bedrock)** +- **[`langchain_core.messages.block_translators.bedrock_converse`](/langchain-core/langchain_core/messages/block_translators/bedrock_converse)** +- **[`langchain_core.messages.block_translators.google_genai`](/langchain-core/langchain_core/messages/block_translators/google_genai)** +- **[`langchain_core.messages.block_translators.google_vertexai`](/langchain-core/langchain_core/messages/block_translators/google_vertexai)** +- **[`langchain_core.messages.block_translators.groq`](/langchain-core/langchain_core/messages/block_translators/groq)** +- **[`langchain_core.messages.block_translators.langchain_v0`](/langchain-core/langchain_core/messages/block_translators/langchain_v0)** +- **[`langchain_core.messages.block_translators.openai`](/langchain-core/langchain_core/messages/block_translators/openai)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_register_translators`](#langchain_core-messages-block_translators-_register_translators) | Register all translators in langchain-core. | +| [`get_translator`](#langchain_core-messages-block_translators-get_translator) | Get the translator functions for a provider. | +| [`register_translator`](#langchain_core-messages-block_translators-register_translator) | Register content translators for a provider in `PROVIDER_TRANSLATORS`. | + +### Data + +[`PROVIDER_TRANSLATORS`](#langchain_core-messages-block_translators-PROVIDER_TRANSLATORS) + +### API + + + + + +```python +langchain_core.messages.block_translators._register_translators() -> None +``` + + + + + + +Register all translators in langchain-core. + +A unit test ensures all modules in `block_translators` are represented here. + +For translators implemented outside langchain-core, they can be registered by +calling `register_translator` from within the integration package. + + + + + + + + +```python +langchain_core.messages.block_translators.get_translator( + provider: str +) -> dict[str, collections.abc.Callable[..., list[langchain_core.messages.content.ContentBlock]]] | None +``` + + + + + + +Get the translator functions for a provider. + +**Parameters:** + + +The model provider name. + + +**Returns:** `dict[str, Callable[..., list[types.ContentBlock]]] | None` + +Dictionary with `'translate_content'` and `'translate_content_chunk'` + + + + + + + + +```python +langchain_core.messages.block_translators.register_translator( + provider: str, + translate_content: collections.abc.Callable[[AIMessage], list[langchain_core.messages.content.ContentBlock]], + translate_content_chunk: collections.abc.Callable[[AIMessageChunk], list[langchain_core.messages.content.ContentBlock]] +) -> None +``` + + + + + + +Register content translators for a provider in `PROVIDER_TRANSLATORS`. + +**Parameters:** + + +The model provider name (e.g. `'openai'`, `'anthropic'`). + + + +Function to translate `AIMessage` content. + + + +Function to translate `AIMessageChunk` content. + + + + + + + + + +```python +langchain_core.messages.block_translators.PROVIDER_TRANSLATORS: dict[str, dict[str, Callable[..., list[ContentBlock]]]] = {} +``` + + + + + + +Map model provider names to translator functions. + +The dictionary maps provider names (e.g. `'openai'`, `'anthropic'`) to another +dictionary with two keys: +- `'translate_content'`: Function to translate `AIMessage` content. +- `'translate_content_chunk'`: Function to translate `AIMessageChunk` content. + +When calling `content_blocks` on an `AIMessage` or `AIMessageChunk`, if +`model_provider` is set in `response_metadata`, the corresponding translator +functions will be used to parse the content into blocks. Otherwise, best-effort parsing +in `BaseMessage` will be used. + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/anthropic.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/anthropic.mdx new file mode 100644 index 0000000..783049d --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/anthropic.mdx @@ -0,0 +1,200 @@ +--- +layout: overview +slug: langchain-core/langchain_core/messages/block_translators/anthropic +title: langchain_core.messages.block_translators.anthropic +--- + +Derivations of standard content blocks from Anthropic content. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_convert_citation_to_v1`](#langchain_core-messages-block_translators-anthropic-_convert_citation_to_v1) | - | +| [`_convert_to_v1_from_anthropic`](#langchain_core-messages-block_translators-anthropic-_convert_to_v1_from_anthropic) | Convert Anthropic message content to v1 format. | +| [`_convert_to_v1_from_anthropic_input`](#langchain_core-messages-block_translators-anthropic-_convert_to_v1_from_anthropic_input) | Convert Anthropic format blocks to v1 format. | +| [`_populate_extras`](#langchain_core-messages-block_translators-anthropic-_populate_extras) | Mutate a block, populating extras. | +| [`_register_anthropic_translator`](#langchain_core-messages-block_translators-anthropic-_register_anthropic_translator) | Register the Anthropic translator with the central registry. | +| [`translate_content`](#langchain_core-messages-block_translators-anthropic-translate_content) | Derive standard content blocks from a message with Anthropic content. | +| [`translate_content_chunk`](#langchain_core-messages-block_translators-anthropic-translate_content_chunk) | Derive standard content blocks from a message chunk with Anthropic content. | + +### API + + + + + +```python +langchain_core.messages.block_translators.anthropic._convert_citation_to_v1( + citation: dict[str, typing.Any] +) -> langchain_core.messages.content.Annotation +``` + + + + + + + + + + + + + +```python +langchain_core.messages.block_translators.anthropic._convert_to_v1_from_anthropic( + message: langchain_core.messages.AIMessage +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Convert Anthropic message content to v1 format. + + + + + + + + +```python +langchain_core.messages.block_translators.anthropic._convert_to_v1_from_anthropic_input( + content: list[langchain_core.messages.content.ContentBlock] +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Convert Anthropic format blocks to v1 format. + +During the `content_blocks` parsing process, we wrap blocks not recognized as a v1 +block as a `'non_standard'` block with the original block stored in the `value` +field. This function attempts to unpack those blocks and convert any blocks that +might be Anthropic format to v1 ContentBlocks. + +If conversion fails, the block is left as a `'non_standard'` block. + +**Parameters:** + + +List of content blocks to process. + + +**Returns:** `list[types.ContentBlock]` + +Updated list with Anthropic blocks converted to v1 format. + + + + + + + + +```python +langchain_core.messages.block_translators.anthropic._populate_extras( + standard_block: langchain_core.messages.content.ContentBlock, + block: dict[str, typing.Any], + known_fields: set[str] +) -> langchain_core.messages.content.ContentBlock +``` + + + + + + +Mutate a block, populating extras. + + + + + + + + +```python +langchain_core.messages.block_translators.anthropic._register_anthropic_translator() -> None +``` + + + + + + +Register the Anthropic translator with the central registry. + +Run automatically when the module is imported. + + + + + + + + +```python +langchain_core.messages.block_translators.anthropic.translate_content( + message: langchain_core.messages.AIMessage +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Derive standard content blocks from a message with Anthropic content. + +**Parameters:** + + +The message to translate. + + +**Returns:** `list[types.ContentBlock]` + +The derived content blocks. + + + + + + + + +```python +langchain_core.messages.block_translators.anthropic.translate_content_chunk( + message: langchain_core.messages.AIMessageChunk +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Derive standard content blocks from a message chunk with Anthropic content. + +**Parameters:** + + +The message chunk to translate. + + +**Returns:** `list[types.ContentBlock]` + +The derived content blocks. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/bedrock.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/bedrock.mdx new file mode 100644 index 0000000..cda61c1 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/bedrock.mdx @@ -0,0 +1,141 @@ +--- +layout: overview +slug: langchain-core/langchain_core/messages/block_translators/bedrock +title: langchain_core.messages.block_translators.bedrock +--- + +Derivations of standard content blocks from Bedrock content. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_convert_to_v1_from_bedrock`](#langchain_core-messages-block_translators-bedrock-_convert_to_v1_from_bedrock) | Convert bedrock message content to v1 format. | +| [`_convert_to_v1_from_bedrock_chunk`](#langchain_core-messages-block_translators-bedrock-_convert_to_v1_from_bedrock_chunk) | Convert bedrock message chunk content to v1 format. | +| [`_register_bedrock_translator`](#langchain_core-messages-block_translators-bedrock-_register_bedrock_translator) | Register the bedrock translator with the central registry. | +| [`translate_content`](#langchain_core-messages-block_translators-bedrock-translate_content) | Derive standard content blocks from a message with Bedrock content. | +| [`translate_content_chunk`](#langchain_core-messages-block_translators-bedrock-translate_content_chunk) | Derive standard content blocks from a message chunk with Bedrock content. | + +### API + + + + + +```python +langchain_core.messages.block_translators.bedrock._convert_to_v1_from_bedrock( + message: langchain_core.messages.AIMessage +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Convert bedrock message content to v1 format. + + + + + + + + +```python +langchain_core.messages.block_translators.bedrock._convert_to_v1_from_bedrock_chunk( + message: langchain_core.messages.AIMessageChunk +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Convert bedrock message chunk content to v1 format. + + + + + + + + +```python +langchain_core.messages.block_translators.bedrock._register_bedrock_translator() -> None +``` + + + + + + +Register the bedrock translator with the central registry. + +Run automatically when the module is imported. + + + + + + + + +```python +langchain_core.messages.block_translators.bedrock.translate_content( + message: langchain_core.messages.AIMessage +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Derive standard content blocks from a message with Bedrock content. + +**Parameters:** + + +The message to translate. + + +**Returns:** `list[types.ContentBlock]` + +The derived content blocks. + + + + + + + + +```python +langchain_core.messages.block_translators.bedrock.translate_content_chunk( + message: langchain_core.messages.AIMessageChunk +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Derive standard content blocks from a message chunk with Bedrock content. + +**Parameters:** + + +The message chunk to translate. + + +**Returns:** `list[types.ContentBlock]` + +The derived content blocks. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/bedrock_converse.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/bedrock_converse.mdx new file mode 100644 index 0000000..ba6f8da --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/bedrock_converse.mdx @@ -0,0 +1,219 @@ +--- +layout: overview +slug: langchain-core/langchain_core/messages/block_translators/bedrock_converse +title: langchain_core.messages.block_translators.bedrock_converse +--- + +Derivations of standard content blocks from Amazon (Bedrock Converse) content. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_bytes_to_b64_str`](#langchain_core-messages-block_translators-bedrock_converse-_bytes_to_b64_str) | - | +| [`_convert_citation_to_v1`](#langchain_core-messages-block_translators-bedrock_converse-_convert_citation_to_v1) | - | +| [`_convert_to_v1_from_converse`](#langchain_core-messages-block_translators-bedrock_converse-_convert_to_v1_from_converse) | Convert Bedrock Converse message content to v1 format. | +| [`_convert_to_v1_from_converse_input`](#langchain_core-messages-block_translators-bedrock_converse-_convert_to_v1_from_converse_input) | Convert Bedrock Converse format blocks to v1 format. | +| [`_populate_extras`](#langchain_core-messages-block_translators-bedrock_converse-_populate_extras) | Mutate a block, populating extras. | +| [`_register_bedrock_converse_translator`](#langchain_core-messages-block_translators-bedrock_converse-_register_bedrock_converse_translator) | Register the Bedrock Converse translator with the central registry. | +| [`translate_content`](#langchain_core-messages-block_translators-bedrock_converse-translate_content) | Derive standard content blocks from a message with Bedrock Converse content. | +| [`translate_content_chunk`](#langchain_core-messages-block_translators-bedrock_converse-translate_content_chunk) | Derive standard content blocks from a chunk with Bedrock Converse content. | + +### API + + + + + +```python +langchain_core.messages.block_translators.bedrock_converse._bytes_to_b64_str( + bytes_: bytes +) -> str +``` + + + + + + + + + + + + + +```python +langchain_core.messages.block_translators.bedrock_converse._convert_citation_to_v1( + citation: dict[str, typing.Any] +) -> langchain_core.messages.content.Annotation +``` + + + + + + + + + + + + + +```python +langchain_core.messages.block_translators.bedrock_converse._convert_to_v1_from_converse( + message: langchain_core.messages.AIMessage +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Convert Bedrock Converse message content to v1 format. + + + + + + + + +```python +langchain_core.messages.block_translators.bedrock_converse._convert_to_v1_from_converse_input( + content: list[langchain_core.messages.content.ContentBlock] +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Convert Bedrock Converse format blocks to v1 format. + +During the `content_blocks` parsing process, we wrap blocks not recognized as a v1 +block as a `'non_standard'` block with the original block stored in the `value` +field. This function attempts to unpack those blocks and convert any blocks that +might be Converse format to v1 ContentBlocks. + +If conversion fails, the block is left as a `'non_standard'` block. + +**Parameters:** + + +List of content blocks to process. + + +**Returns:** `list[types.ContentBlock]` + +Updated list with Converse blocks converted to v1 format. + + + + + + + + +```python +langchain_core.messages.block_translators.bedrock_converse._populate_extras( + standard_block: langchain_core.messages.content.ContentBlock, + block: dict[str, typing.Any], + known_fields: set[str] +) -> langchain_core.messages.content.ContentBlock +``` + + + + + + +Mutate a block, populating extras. + + + + + + + + +```python +langchain_core.messages.block_translators.bedrock_converse._register_bedrock_converse_translator() -> None +``` + + + + + + +Register the Bedrock Converse translator with the central registry. + +Run automatically when the module is imported. + + + + + + + + +```python +langchain_core.messages.block_translators.bedrock_converse.translate_content( + message: langchain_core.messages.AIMessage +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Derive standard content blocks from a message with Bedrock Converse content. + +**Parameters:** + + +The message to translate. + + +**Returns:** `list[types.ContentBlock]` + +The derived content blocks. + + + + + + + + +```python +langchain_core.messages.block_translators.bedrock_converse.translate_content_chunk( + message: langchain_core.messages.AIMessageChunk +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Derive standard content blocks from a chunk with Bedrock Converse content. + +**Parameters:** + + +The message chunk to translate. + + +**Returns:** `list[types.ContentBlock]` + +The derived content blocks. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/google_genai.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/google_genai.mdx new file mode 100644 index 0000000..b7ace43 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/google_genai.mdx @@ -0,0 +1,244 @@ +--- +layout: overview +slug: langchain-core/langchain_core/messages/block_translators/google_genai +title: langchain_core.messages.block_translators.google_genai +--- + +Derivations of standard content blocks from Google (GenAI) content. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_bytes_to_b64_str`](#langchain_core-messages-block_translators-google_genai-_bytes_to_b64_str) | Convert bytes to base64 encoded string. | +| [`_convert_to_v1_from_genai`](#langchain_core-messages-block_translators-google_genai-_convert_to_v1_from_genai) | Convert Google GenAI message content to v1 format. | +| [`_convert_to_v1_from_genai_input`](#langchain_core-messages-block_translators-google_genai-_convert_to_v1_from_genai_input) | Convert Google GenAI format blocks to v1 format. | +| [`_register_google_genai_translator`](#langchain_core-messages-block_translators-google_genai-_register_google_genai_translator) | Register the Google (GenAI) translator with the central registry. | +| [`translate_content`](#langchain_core-messages-block_translators-google_genai-translate_content) | Derive standard content blocks from a message with Google (GenAI) content. | +| [`translate_content_chunk`](#langchain_core-messages-block_translators-google_genai-translate_content_chunk) | Derive standard content blocks from a chunk with Google (GenAI) content. | +| [`translate_grounding_metadata_to_citations`](#langchain_core-messages-block_translators-google_genai-translate_grounding_metadata_to_citations) | Translate Google AI grounding metadata to LangChain Citations. | + +### Data + +[`_HAS_FILETYPE`](#langchain_core-messages-block_translators-google_genai-_HAS_FILETYPE) + +### API + + + + + +```python +langchain_core.messages.block_translators.google_genai._bytes_to_b64_str( + bytes_: bytes +) -> str +``` + + + + + + +Convert bytes to base64 encoded string. + + + + + + + + +```python +langchain_core.messages.block_translators.google_genai._convert_to_v1_from_genai( + message: langchain_core.messages.AIMessage +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Convert Google GenAI message content to v1 format. + +Calling `.content_blocks` on an `AIMessage` where `response_metadata.model_provider` +is set to `'google_genai'` will invoke this function to parse the content into +standard content blocks for returning. + +**Parameters:** + + +The `AIMessage` or `AIMessageChunk` to convert. + + +**Returns:** `list[types.ContentBlock]` + +List of standard content blocks derived from the message content. + + + + + + + + +```python +langchain_core.messages.block_translators.google_genai._convert_to_v1_from_genai_input( + content: list[langchain_core.messages.content.ContentBlock] +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Convert Google GenAI format blocks to v1 format. + +Called when message isn't an `AIMessage` or `model_provider` isn't set on +`response_metadata`. + +During the `content_blocks` parsing process, we wrap blocks not recognized as a v1 +block as a `'non_standard'` block with the original block stored in the `value` +field. This function attempts to unpack those blocks and convert any blocks that +might be GenAI format to v1 ContentBlocks. + +If conversion fails, the block is left as a `'non_standard'` block. + +**Parameters:** + + +List of content blocks to process. + + +**Returns:** `list[types.ContentBlock]` + +Updated list with GenAI blocks converted to v1 format. + + + + + + + + +```python +langchain_core.messages.block_translators.google_genai._register_google_genai_translator() -> None +``` + + + + + + +Register the Google (GenAI) translator with the central registry. + +Run automatically when the module is imported. + + + + + + + + +```python +langchain_core.messages.block_translators.google_genai.translate_content( + message: langchain_core.messages.AIMessage +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Derive standard content blocks from a message with Google (GenAI) content. + +**Parameters:** + + +The message to translate. + + +**Returns:** `list[types.ContentBlock]` + +The derived content blocks. + + + + + + + + +```python +langchain_core.messages.block_translators.google_genai.translate_content_chunk( + message: langchain_core.messages.AIMessageChunk +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Derive standard content blocks from a chunk with Google (GenAI) content. + +**Parameters:** + + +The message chunk to translate. + + +**Returns:** `list[types.ContentBlock]` + +The derived content blocks. + + + + + + + + +```python +langchain_core.messages.block_translators.google_genai.translate_grounding_metadata_to_citations( + grounding_metadata: dict[str, typing.Any] +) -> list[langchain_core.messages.content.Citation] +``` + + + + + + +Translate Google AI grounding metadata to LangChain Citations. + +**Parameters:** + + +Google AI grounding metadata containing web search +queries, grounding chunks, and grounding supports. + + +**Returns:** `list[Citation]` + +List of Citation content blocks derived from the grounding metadata. + + + + + + + + +```python +langchain_core.messages.block_translators.google_genai._HAS_FILETYPE = True +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/google_vertexai.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/google_vertexai.mdx new file mode 100644 index 0000000..16631cc --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/google_vertexai.mdx @@ -0,0 +1,37 @@ +--- +layout: overview +slug: langchain-core/langchain_core/messages/block_translators/google_vertexai +title: langchain_core.messages.block_translators.google_vertexai +--- + +Derivations of standard content blocks from Google (VertexAI) content. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_register_google_vertexai_translator`](#langchain_core-messages-block_translators-google_vertexai-_register_google_vertexai_translator) | Register the Google (VertexAI) translator with the central registry. | + +### API + + + + + +```python +langchain_core.messages.block_translators.google_vertexai._register_google_vertexai_translator() -> None +``` + + + + + + +Register the Google (VertexAI) translator with the central registry. + +Run automatically when the module is imported. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/groq.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/groq.mdx new file mode 100644 index 0000000..4a718f7 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/groq.mdx @@ -0,0 +1,176 @@ +--- +layout: overview +slug: langchain-core/langchain_core/messages/block_translators/groq +title: langchain_core.messages.block_translators.groq +--- + +Derivations of standard content blocks from Groq content. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_convert_to_v1_from_groq`](#langchain_core-messages-block_translators-groq-_convert_to_v1_from_groq) | Convert groq message content to v1 format. | +| [`_parse_code_json`](#langchain_core-messages-block_translators-groq-_parse_code_json) | Extract Python code from Groq built-in tool content. | +| [`_populate_extras`](#langchain_core-messages-block_translators-groq-_populate_extras) | Mutate a block, populating extras. | +| [`_register_groq_translator`](#langchain_core-messages-block_translators-groq-_register_groq_translator) | Register the groq translator with the central registry. | +| [`translate_content`](#langchain_core-messages-block_translators-groq-translate_content) | Derive standard content blocks from a message with groq content. | +| [`translate_content_chunk`](#langchain_core-messages-block_translators-groq-translate_content_chunk) | Derive standard content blocks from a message chunk with groq content. | + +### API + + + + + +```python +langchain_core.messages.block_translators.groq._convert_to_v1_from_groq( + message: langchain_core.messages.AIMessage +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Convert groq message content to v1 format. + + + + + + + + +```python +langchain_core.messages.block_translators.groq._parse_code_json( + s: str +) -> dict +``` + + + + + + +Extract Python code from Groq built-in tool content. + +Extracts the value of the 'code' field from a string of the form: +{"code": some_arbitrary_text_with_unescaped_quotes} + +As Groq may not escape quotes in the executed tools, e.g.: + + +```python +'{"code": "import math; print("The square root of 101 is: "); print(math.sqrt(101))"}' +``` + + + + + + + + + + +```python +langchain_core.messages.block_translators.groq._populate_extras( + standard_block: langchain_core.messages.content.ContentBlock, + block: dict[str, typing.Any], + known_fields: set[str] +) -> langchain_core.messages.content.ContentBlock +``` + + + + + + +Mutate a block, populating extras. + + + + + + + + +```python +langchain_core.messages.block_translators.groq._register_groq_translator() -> None +``` + + + + + + +Register the groq translator with the central registry. + +Run automatically when the module is imported. + + + + + + + + +```python +langchain_core.messages.block_translators.groq.translate_content( + message: langchain_core.messages.AIMessage +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Derive standard content blocks from a message with groq content. + +**Parameters:** + + +The message to translate. + + +**Returns:** `list[types.ContentBlock]` + +The derived content blocks. + + + + + + + + +```python +langchain_core.messages.block_translators.groq.translate_content_chunk( + message: langchain_core.messages.AIMessageChunk +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Derive standard content blocks from a message chunk with groq content. + +**Parameters:** + + +The message chunk to translate. + + +**Returns:** `list[types.ContentBlock]` + +The derived content blocks. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/langchain_v0.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/langchain_v0.mdx new file mode 100644 index 0000000..b0981b5 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/langchain_v0.mdx @@ -0,0 +1,79 @@ +--- +layout: overview +slug: langchain-core/langchain_core/messages/block_translators/langchain_v0 +title: langchain_core.messages.block_translators.langchain_v0 +--- + +Derivations of standard content blocks from LangChain v0 multimodal content. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_convert_legacy_v0_content_block_to_v1`](#langchain_core-messages-block_translators-langchain_v0-_convert_legacy_v0_content_block_to_v1) | Convert a LangChain v0 content block to v1 format. | +| [`_convert_v0_multimodal_input_to_v1`](#langchain_core-messages-block_translators-langchain_v0-_convert_v0_multimodal_input_to_v1) | Convert v0 multimodal blocks to v1 format. | + +### API + + + + + +```python +langchain_core.messages.block_translators.langchain_v0._convert_legacy_v0_content_block_to_v1( + block: dict +) -> langchain_core.messages.content.ContentBlock | dict +``` + + + + + + +Convert a LangChain v0 content block to v1 format. + +Preserves unknown keys as extras to avoid data loss. + +Returns the original block unchanged if it's not in v0 format. + + + + + + + + +```python +langchain_core.messages.block_translators.langchain_v0._convert_v0_multimodal_input_to_v1( + content: list[langchain_core.messages.content.ContentBlock] +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Convert v0 multimodal blocks to v1 format. + +During the `content_blocks` parsing process, we wrap blocks not recognized as a v1 +block as a `'non_standard'` block with the original block stored in the `value` +field. This function attempts to unpack those blocks and convert any v0 format +blocks to v1 format. + +If conversion fails, the block is left as a `'non_standard'` block. + +**Parameters:** + + +List of content blocks to process. + + +**Returns:** `list[types.ContentBlock]` + +v1 content blocks. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/openai.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/openai.mdx new file mode 100644 index 0000000..cc84d5a --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/block_translators/openai.mdx @@ -0,0 +1,408 @@ +--- +layout: overview +slug: langchain-core/langchain_core/messages/block_translators/openai +title: langchain_core.messages.block_translators.openai +--- + +Derivations of standard content blocks from OpenAI content. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_convert_annotation_to_v1`](#langchain_core-messages-block_translators-openai-_convert_annotation_to_v1) | - | +| [`_convert_from_v03_ai_message`](#langchain_core-messages-block_translators-openai-_convert_from_v03_ai_message) | Convert v0 AIMessage into `output_version="responses/v1"` format. | +| [`_convert_from_v1_to_chat_completions`](#langchain_core-messages-block_translators-openai-_convert_from_v1_to_chat_completions) | Convert a v1 message to the Chat Completions format. | +| [`_convert_openai_format_to_data_block`](#langchain_core-messages-block_translators-openai-_convert_openai_format_to_data_block) | Convert OpenAI image/audio/file content block to respective v1 multimodal block. | +| [`_convert_to_v1_from_chat_completions`](#langchain_core-messages-block_translators-openai-_convert_to_v1_from_chat_completions) | Mutate a Chat Completions message to v1 format. | +| [`_convert_to_v1_from_chat_completions_chunk`](#langchain_core-messages-block_translators-openai-_convert_to_v1_from_chat_completions_chunk) | Mutate a Chat Completions chunk to v1 format. | +| [`_convert_to_v1_from_chat_completions_input`](#langchain_core-messages-block_translators-openai-_convert_to_v1_from_chat_completions_input) | Convert OpenAI Chat Completions format blocks to v1 format. | +| [`_convert_to_v1_from_responses`](#langchain_core-messages-block_translators-openai-_convert_to_v1_from_responses) | Convert a Responses message to v1 format. | +| [`_explode_reasoning`](#langchain_core-messages-block_translators-openai-_explode_reasoning) | - | +| [`_register_openai_translator`](#langchain_core-messages-block_translators-openai-_register_openai_translator) | Register the OpenAI translator with the central registry. | +| [`convert_to_openai_data_block`](#langchain_core-messages-block_translators-openai-convert_to_openai_data_block) | Format standard data content block to format expected by OpenAI. | +| [`convert_to_openai_image_block`](#langchain_core-messages-block_translators-openai-convert_to_openai_image_block) | Convert `ImageContentBlock` to format expected by OpenAI Chat Completions. | +| [`translate_content`](#langchain_core-messages-block_translators-openai-translate_content) | Derive standard content blocks from a message with OpenAI content. | +| [`translate_content_chunk`](#langchain_core-messages-block_translators-openai-translate_content_chunk) | Derive standard content blocks from a message chunk with OpenAI content. | + +### Data + +[`_FUNCTION_CALL_IDS_MAP_KEY`](#langchain_core-messages-block_translators-openai-_FUNCTION_CALL_IDS_MAP_KEY) + +### API + + + + + +```python +langchain_core.messages.block_translators.openai._convert_annotation_to_v1( + annotation: dict[str, typing.Any] +) -> langchain_core.messages.content.Annotation +``` + + + + + + + + + + + + + +```python +langchain_core.messages.block_translators.openai._convert_from_v03_ai_message( + message: langchain_core.messages.AIMessage +) -> langchain_core.messages.AIMessage +``` + + + + + + +Convert v0 AIMessage into `output_version="responses/v1"` format. + + + + + + + + +```python +langchain_core.messages.block_translators.openai._convert_from_v1_to_chat_completions( + message: langchain_core.messages.AIMessage +) -> langchain_core.messages.AIMessage +``` + + + + + + +Convert a v1 message to the Chat Completions format. + + + + + + + + +```python +langchain_core.messages.block_translators.openai._convert_openai_format_to_data_block( + block: dict +) -> langchain_core.messages.content.ContentBlock | dict[typing.Any, typing.Any] +``` + + + + + + +Convert OpenAI image/audio/file content block to respective v1 multimodal block. + +We expect that the incoming block is verified to be in OpenAI Chat Completions +format. + +If parsing fails, passes block through unchanged. + +Mappings (Chat Completions to LangChain v1): +- Image -> `ImageContentBlock` +- Audio -> `AudioContentBlock` +- File -> `FileContentBlock` + + + + + + + + +```python +langchain_core.messages.block_translators.openai._convert_to_v1_from_chat_completions( + message: langchain_core.messages.AIMessage +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Mutate a Chat Completions message to v1 format. + + + + + + + + +```python +langchain_core.messages.block_translators.openai._convert_to_v1_from_chat_completions_chunk( + chunk: langchain_core.messages.AIMessageChunk +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Mutate a Chat Completions chunk to v1 format. + + + + + + + + +```python +langchain_core.messages.block_translators.openai._convert_to_v1_from_chat_completions_input( + content: list[langchain_core.messages.content.ContentBlock] +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Convert OpenAI Chat Completions format blocks to v1 format. + +During the `content_blocks` parsing process, we wrap blocks not recognized as a v1 +block as a `'non_standard'` block with the original block stored in the `value` +field. This function attempts to unpack those blocks and convert any blocks that +might be OpenAI format to v1 ContentBlocks. + +If conversion fails, the block is left as a `'non_standard'` block. + +**Parameters:** + + +List of content blocks to process. + + +**Returns:** `list[types.ContentBlock]` + +Updated list with OpenAI blocks converted to v1 format. + + + + + + + + +```python +langchain_core.messages.block_translators.openai._convert_to_v1_from_responses( + message: langchain_core.messages.AIMessage +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Convert a Responses message to v1 format. + + + + + + + + +```python +langchain_core.messages.block_translators.openai._explode_reasoning( + block: dict[str, typing.Any] +) -> collections.abc.Iterator[langchain_core.messages.content.ReasoningContentBlock] +``` + + + + + + + + + + + + + +```python +langchain_core.messages.block_translators.openai._register_openai_translator() -> None +``` + + + + + + +Register the OpenAI translator with the central registry. + +Run automatically when the module is imported. + + + + + + + + +```python +langchain_core.messages.block_translators.openai.convert_to_openai_data_block( + block: dict, + api: typing.Literal['chat/completions', 'responses'] = 'chat/completions' +) -> dict +``` + + + + + + +Format standard data content block to format expected by OpenAI. + +"Standard data content block" can include old-style LangChain v0 blocks +(URLContentBlock, Base64ContentBlock, IDContentBlock) or new ones. + +**Parameters:** + + +The content block to convert. + + + +The OpenAI API being targeted. Either "chat/completions" or "responses". + + +**Returns:** `dict` + +The formatted content block. + +**Raises:** + +- `ValueError`: If required keys are missing. +- `ValueError`: If file URLs are used with Chat Completions API. +- `ValueError`: If block type is unsupported. + + + + + + + + +```python +langchain_core.messages.block_translators.openai.convert_to_openai_image_block( + block: dict[str, typing.Any] +) -> dict +``` + + + + + + +Convert `ImageContentBlock` to format expected by OpenAI Chat Completions. + +**Parameters:** + + +The image content block to convert. + + +**Returns:** `dict` + +The formatted image content block. + +**Raises:** + +- `ValueError`: If required keys are missing. +- `ValueError`: If source type is unsupported. + + + + + + + + +```python +langchain_core.messages.block_translators.openai.translate_content( + message: langchain_core.messages.AIMessage +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Derive standard content blocks from a message with OpenAI content. + +**Parameters:** + + +The message to translate. + + +**Returns:** `list[types.ContentBlock]` + +The derived content blocks. + + + + + + + + +```python +langchain_core.messages.block_translators.openai.translate_content_chunk( + message: langchain_core.messages.AIMessageChunk +) -> list[langchain_core.messages.content.ContentBlock] +``` + + + + + + +Derive standard content blocks from a message chunk with OpenAI content. + +**Parameters:** + + +The message chunk to translate. + + +**Returns:** `list[types.ContentBlock]` + +The derived content blocks. + + + + + + + + +```python +langchain_core.messages.block_translators.openai._FUNCTION_CALL_IDS_MAP_KEY = '__openai_function_call_ids__' +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/chat.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/chat.mdx new file mode 100644 index 0000000..26638c3 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/chat.mdx @@ -0,0 +1,85 @@ +--- +layout: overview +slug: langchain-core/langchain_core/messages/chat +title: langchain_core.messages.chat +--- + +Chat Message. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ChatMessage`](#langchain_core-messages-chat-ChatMessage) | Message that can be assigned an arbitrary speaker (i.e. role). | +| [`ChatMessageChunk`](#langchain_core-messages-chat-ChatMessageChunk) | Chat Message chunk. | + +### API + + + + + +```python +class langchain_core.messages.chat.ChatMessage() +``` + + + + + + +**Bases:** [BaseMessage](/langchain-core/langchain_core/messages/base#langchain_core-messages-base-BaseMessage) + +Message that can be assigned an arbitrary speaker (i.e. role). + + + +The speaker / role of the Message. + + + +The type of the message (used during serialization). + + + + + + + +```python +class langchain_core.messages.chat.ChatMessageChunk() +``` + + + + + + +**Bases:** [ChatMessage](#langchain_core-messages-chat-ChatMessage), [BaseMessageChunk](/langchain-core/langchain_core/messages/base#langchain_core-messages-base-BaseMessageChunk) + +Chat Message chunk. + + + +The type of the message (used during serialization). + + + + + +```python +langchain_core.messages.chat.ChatMessageChunk.__add__( + other: typing.Any +) -> langchain_core.messages.base.BaseMessageChunk +``` + + + + + + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/content.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/content.mdx new file mode 100644 index 0000000..d93b316 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/content.mdx @@ -0,0 +1,1941 @@ +--- +layout: overview +slug: langchain-core/langchain_core/messages/content +title: langchain_core.messages.content +--- + +Standard, multimodal content blocks for Large Language Model I/O. + +This module provides standardized data structures for representing inputs to and outputs +from LLMs. The core abstraction is the **Content Block**, a `TypedDict`. + +**Rationale** + +Different LLM providers use distinct and incompatible API schemas. This module provides +a unified, provider-agnostic format to facilitate these interactions. A message to or +from a model is simply a list of content blocks, allowing for the natural interleaving +of text, images, and other content in a single ordered sequence. + +An adapter for a specific provider is responsible for translating this standard list of +blocks into the format required by its API. + +**Extensibility** + +Data **not yet mapped** to a standard block may be represented using the +`NonStandardContentBlock`, which allows for provider-specific data to be included +without losing the benefits of type checking and validation. + +Furthermore, provider-specific fields **within** a standard block are fully supported +by default in the `extras` field of each block. This allows for additional metadata +to be included without breaking the standard structure. For example, Google's thought +signature: + + + +```python +AIMessage( + content=[ + { + "type": "text", + "text": "J'adore la programmation.", + "extras": {"signature": "EpoWCpc..."}, # Thought signature + } + ], ... +) +``` + + + + +!!! note + + Following widespread adoption of [PEP 728](https://peps.python.org/pep-0728/), we + intend to add `extra_items=Any` as a param to Content Blocks. This will signify to + type checkers that additional provider-specific fields are allowed outside of the + `extras` field, and that will become the new standard approach to adding + provider-specific metadata. + + ??? note + + **Example with PEP 728 provider-specific fields:** + + ```python + # Content block definition + # NOTE: `extra_items=Any` + class TextContentBlock(TypedDict, extra_items=Any): + type: Literal["text"] + id: NotRequired[str] + text: str + annotations: NotRequired[list[Annotation]] + index: NotRequired[int] + ``` + + ```python + from langchain_core.messages.content import TextContentBlock + + # Create a text content block with provider-specific fields + my_block: TextContentBlock = { + # Add required fields + "type": "text", + "text": "Hello, world!", + # Additional fields not specified in the TypedDict + # These are valid with PEP 728 and are typed as Any + "openai_metadata": {"model": "gpt-4", "temperature": 0.7}, + "anthropic_usage": {"input_tokens": 10, "output_tokens": 20}, + "custom_field": "any value", + } + + # Mutating an existing block to add provider-specific fields + openai_data = my_block["openai_metadata"] # Type: Any + ``` + +**Example Usage** + + + +```python +# Direct construction +from langchain_core.messages.content import TextContentBlock, ImageContentBlock + +multimodal_message: AIMessage( + content_blocks=[ + TextContentBlock(type="text", text="What is shown in this image?"), + ImageContentBlock( + type="image", + url="https://www.langchain.com/images/brand/langchain_logo_text_w_white.png", + mime_type="image/png", + ), + ] +) + +# Using factories +from langchain_core.messages.content import create_text_block, create_image_block + +multimodal_message: AIMessage( + content=[ + create_text_block("What is shown in this image?"), + create_image_block( + url="https://www.langchain.com/images/brand/langchain_logo_text_w_white.png", + mime_type="image/png", + ), + ] +) +``` + + + +Factory functions offer benefits such as: + +- Automatic ID generation (when not provided) +- No need to manually specify the `type` field + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AudioContentBlock`](#langchain_core-messages-content-AudioContentBlock) | Audio data. | +| [`Citation`](#langchain_core-messages-content-Citation) | Annotation for citing data from a document. | +| [`FileContentBlock`](#langchain_core-messages-content-FileContentBlock) | File data that doesn't fit into other multimodal block types. | +| [`ImageContentBlock`](#langchain_core-messages-content-ImageContentBlock) | Image data. | +| [`InvalidToolCall`](#langchain_core-messages-content-InvalidToolCall) | Allowance for errors made by LLM. | +| [`NonStandardAnnotation`](#langchain_core-messages-content-NonStandardAnnotation) | Provider-specific annotation format. | +| [`NonStandardContentBlock`](#langchain_core-messages-content-NonStandardContentBlock) | Provider-specific content data. | +| [`PlainTextContentBlock`](#langchain_core-messages-content-PlainTextContentBlock) | Plaintext data (e.g., from a `.txt` or `.md` document). | +| [`ReasoningContentBlock`](#langchain_core-messages-content-ReasoningContentBlock) | Reasoning output from a LLM. | +| [`ServerToolCall`](#langchain_core-messages-content-ServerToolCall) | Tool call that is executed server-side. | +| [`ServerToolCallChunk`](#langchain_core-messages-content-ServerToolCallChunk) | A chunk of a server-side tool call (yielded when streaming). | +| [`ServerToolResult`](#langchain_core-messages-content-ServerToolResult) | Result of a server-side tool call. | +| [`TextContentBlock`](#langchain_core-messages-content-TextContentBlock) | Text output from a LLM. | +| [`ToolCall`](#langchain_core-messages-content-ToolCall) | Represents an AI's request to call a tool. | +| [`ToolCallChunk`](#langchain_core-messages-content-ToolCallChunk) | A chunk of a tool call (yielded when streaming). | +| [`VideoContentBlock`](#langchain_core-messages-content-VideoContentBlock) | Video data. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_data_content_block_types`](#langchain_core-messages-content-_get_data_content_block_types) | Get type literals from DataContentBlock union members dynamically. | +| [`create_audio_block`](#langchain_core-messages-content-create_audio_block) | Create an `AudioContentBlock`. | +| [`create_citation`](#langchain_core-messages-content-create_citation) | Create a `Citation`. | +| [`create_file_block`](#langchain_core-messages-content-create_file_block) | Create a `FileContentBlock`. | +| [`create_image_block`](#langchain_core-messages-content-create_image_block) | Create an `ImageContentBlock`. | +| [`create_non_standard_block`](#langchain_core-messages-content-create_non_standard_block) | Create a `NonStandardContentBlock`. | +| [`create_plaintext_block`](#langchain_core-messages-content-create_plaintext_block) | Create a `PlainTextContentBlock`. | +| [`create_reasoning_block`](#langchain_core-messages-content-create_reasoning_block) | Create a `ReasoningContentBlock`. | +| [`create_text_block`](#langchain_core-messages-content-create_text_block) | Create a `TextContentBlock`. | +| [`create_tool_call`](#langchain_core-messages-content-create_tool_call) | Create a `ToolCall`. | +| [`create_video_block`](#langchain_core-messages-content-create_video_block) | Create a `VideoContentBlock`. | +| [`is_data_content_block`](#langchain_core-messages-content-is_data_content_block) | Check if the provided content block is a data content block. | + +### Data + +[`Annotation`](#langchain_core-messages-content-Annotation) + +[`ContentBlock`](#langchain_core-messages-content-ContentBlock) + +[`DataContentBlock`](#langchain_core-messages-content-DataContentBlock) + +[`KNOWN_BLOCK_TYPES`](#langchain_core-messages-content-KNOWN_BLOCK_TYPES) + +[`ToolContentBlock`](#langchain_core-messages-content-ToolContentBlock) + +### API + + + + + +```python +class langchain_core.messages.content.AudioContentBlock +``` + + + + + + +**Bases:** `typing.TypedDict` + +Audio data. + +!!! note "Factory function" + + `create_audio_block` may also be used as a factory to create an + `AudioContentBlock`. Benefits include: + + * Automatic ID generation (when not provided) + * Required arguments strictly validated at creation time + + +Data as a base64 string. + + + +Provider-specific metadata. This shouldn't be used for the audio data itself. + + + +Reference to the audio file in an external file storage system. + +For example, OpenAI or Anthropic's Files API. + + + +Unique identifier for this content block. + +Either: + +- Generated by the provider +- Generated by LangChain upon creation (`UUID4` prefixed with `'lc_'`)) + + + +Index of block in aggregate response. Used during streaming. + + + +MIME type of the audio. + +Required for base64 data. + +[Examples from IANA](https://www.iana.org/assignments/media-types/media-types.xhtml#audio) + + + +Type of the content block. Used for discrimination. + + + +URL of the audio. + + + + + + + + +```python +class langchain_core.messages.content.Citation +``` + + + + + + +**Bases:** `typing.TypedDict` + +Annotation for citing data from a document. + +!!! note + + `start`/`end` indices refer to the **response text**, + not the source text. This means that the indices are relative to the model's + response, not the original document (as specified in the `url`). + +!!! note "Factory function" + + `create_citation` may also be used as a factory to create a `Citation`. + Benefits include: + + * Automatic ID generation (when not provided) + * Required arguments strictly validated at creation time + + +Excerpt of source text being cited. + + + +End index of the **response text** (`TextContentBlock.text`) + + + +Provider-specific metadata. + + + +Unique identifier for this content block. + +Either: + +- Generated by the provider +- Generated by LangChain upon creation (`UUID4` prefixed with `'lc_'`)) + + + +Start index of the **response text** (`TextContentBlock.text`). + + + +Source document title. + +For example, the page title for a web page or the title of a paper. + + + +Type of the content block. Used for discrimination. + + + +URL of the document source. + + + + + + + + +```python +class langchain_core.messages.content.FileContentBlock +``` + + + + + + +**Bases:** `typing.TypedDict` + +File data that doesn't fit into other multimodal block types. + +This block is intended for files that are not images, audio, or plaintext. For +example, it can be used for PDFs, Word documents, etc. + +If the file is an image, audio, or plaintext, you should use the corresponding +content block type (e.g., `ImageContentBlock`, `AudioContentBlock`, +`PlainTextContentBlock`). + +!!! note "Factory function" + + `create_file_block` may also be used as a factory to create a + `FileContentBlock`. Benefits include: + + * Automatic ID generation (when not provided) + * Required arguments strictly validated at creation time + + +Data as a base64 string. + + + +Provider-specific metadata. This shouldn't be used for the file data itself. + + + +Reference to the file in an external file storage system. + +For example, a file ID from OpenAI's Files API or another cloud storage provider. +This is distinct from `id`, which identifies the content block itself. + + + +Unique identifier for this content block. + +Used for tracking and referencing specific blocks (e.g., during streaming). + +Not to be confused with `file_id`, which references an external file in a +storage system. + +Either: + +- Generated by the provider +- Generated by LangChain upon creation (`UUID4` prefixed with `'lc_'`)) + + + +Index of block in aggregate response. Used during streaming. + + + +MIME type of the file. + +Required for base64 data. + +[Examples from IANA](https://www.iana.org/assignments/media-types/media-types.xhtml) + + + +Type of the content block. Used for discrimination. + + + +URL of the file. + + + + + + + + +```python +class langchain_core.messages.content.ImageContentBlock +``` + + + + + + +**Bases:** `typing.TypedDict` + +Image data. + +!!! note "Factory function" + + `create_image_block` may also be used as a factory to create an + `ImageContentBlock`. Benefits include: + + * Automatic ID generation (when not provided) + * Required arguments strictly validated at creation time + + +Data as a base64 string. + + + +Provider-specific metadata. This shouldn't be used for the image data itself. + + + +Reference to the image in an external file storage system. + +For example, OpenAI or Anthropic's Files API. + + + +Unique identifier for this content block. + +Either: + +- Generated by the provider +- Generated by LangChain upon creation (`UUID4` prefixed with `'lc_'`)) + + + +Index of block in aggregate response. Used during streaming. + + + +MIME type of the image. + +Required for base64 data. + +[Examples from IANA](https://www.iana.org/assignments/media-types/media-types.xhtml#image) + + + +Type of the content block. Used for discrimination. + + + +URL of the image. + + + + + + + + +```python +class langchain_core.messages.content.InvalidToolCall +``` + + + + + + +**Bases:** `typing.TypedDict` + +Allowance for errors made by LLM. + +Here we add an `error` key to surface errors made during generation +(e.g., invalid JSON arguments.) + + +The arguments to the tool call. + + + +An error message associated with the tool call. + + + +Provider-specific metadata. + + + +An identifier associated with the tool call. + +An identifier is needed to associate a tool call request with a tool +call result in events when multiple concurrent tool calls are made. + + + +Index of block in aggregate response. Used during streaming. + + + +The name of the tool to be called. + + + +Used for discrimination. + + + + + + + + +```python +class langchain_core.messages.content.NonStandardAnnotation +``` + + + + + + +**Bases:** `typing.TypedDict` + +Provider-specific annotation format. + + +Unique identifier for this content block. + +Either: + +- Generated by the provider +- Generated by LangChain upon creation (`UUID4` prefixed with `'lc_'`)) + + + +Type of the content block. Used for discrimination. + + + +Provider-specific annotation data. + + + + + + + + +```python +class langchain_core.messages.content.NonStandardContentBlock +``` + + + + + + +**Bases:** `typing.TypedDict` + +Provider-specific content data. + +This block contains data for which there is not yet a standard type. + +The purpose of this block should be to simply hold a provider-specific payload. +If a provider's non-standard output includes reasoning and tool calls, it should be +the adapter's job to parse that payload and emit the corresponding standard +`ReasoningContentBlock` and `ToolCalls`. + +Has no `extras` field, as provider-specific data should be included in the +`value` field. + +!!! note "Factory function" + + `create_non_standard_block` may also be used as a factory to create a + `NonStandardContentBlock`. Benefits include: + + * Automatic ID generation (when not provided) + * Required arguments strictly validated at creation time + + +Unique identifier for this content block. + +Either: + +- Generated by the provider +- Generated by LangChain upon creation (`UUID4` prefixed with `'lc_'`)) + + + +Index of block in aggregate response. Used during streaming. + + + +Type of the content block. Used for discrimination. + + + +Provider-specific content data. + + + + + + + + +```python +class langchain_core.messages.content.PlainTextContentBlock +``` + + + + + + +**Bases:** `typing.TypedDict` + +Plaintext data (e.g., from a `.txt` or `.md` document). + +!!! note + + A `PlainTextContentBlock` existed in `langchain-core<1.0.0`. Although the + name has carried over, the structure has changed significantly. The only shared + keys between the old and new versions are `type` and `text`, though the + `type` value has changed from `'text'` to `'text-plain'`. + +!!! note + + Title and context are optional fields that may be passed to the model. See + Anthropic [example](https://platform.claude.com/docs/en/build-with-claude/citations#citable-vs-non-citable-content). + +!!! note "Factory function" + + `create_plaintext_block` may also be used as a factory to create a + `PlainTextContentBlock`. Benefits include: + + * Automatic ID generation (when not provided) + * Required arguments strictly validated at creation time + + +Data as a base64 string. + + + +Context for the text, e.g., a description or summary of the text's content. + + + +Provider-specific metadata. This shouldn't be used for the data itself. + + + +Reference to the plaintext file in an external file storage system. + +For example, OpenAI or Anthropic's Files API. + + + +Unique identifier for this content block. + +Either: + +- Generated by the provider +- Generated by LangChain upon creation (`UUID4` prefixed with `'lc_'`)) + + + +Index of block in aggregate response. Used during streaming. + + + +MIME type of the file. + +Required for base64 data. + + + +Plaintext content. This is optional if the data is provided as base64. + + + +Title of the text data, e.g., the title of a document. + + + +Type of the content block. Used for discrimination. + + + +URL of the plaintext. + + + + + + + + +```python +class langchain_core.messages.content.ReasoningContentBlock +``` + + + + + + +**Bases:** `typing.TypedDict` + +Reasoning output from a LLM. + +!!! note "Factory function" + + `create_reasoning_block` may also be used as a factory to create a + `ReasoningContentBlock`. Benefits include: + + * Automatic ID generation (when not provided) + * Required arguments strictly validated at creation time + + +Provider-specific metadata. + + + +Unique identifier for this content block. + +Either: + +- Generated by the provider +- Generated by LangChain upon creation (`UUID4` prefixed with `'lc_'`)) + + + +Index of block in aggregate response. Used during streaming. + + + +Reasoning text. + +Either the thought summary or the raw reasoning text itself. + +Often parsed from `<think>` tags in the model's response. + + + +Type of the content block. Used for discrimination. + + + + + + + + +```python +class langchain_core.messages.content.ServerToolCall +``` + + + + + + +**Bases:** `typing.TypedDict` + +Tool call that is executed server-side. + +For example: code execution, web search, etc. + + +The arguments to the tool call. + + + +Provider-specific metadata. + + + +An identifier associated with the tool call. + + + +Index of block in aggregate response. Used during streaming. + + + +The name of the tool to be called. + + + +Used for discrimination. + + + + + + + + +```python +class langchain_core.messages.content.ServerToolCallChunk +``` + + + + + + +**Bases:** `typing.TypedDict` + +A chunk of a server-side tool call (yielded when streaming). + + +JSON substring of the arguments to the tool call. + + + +Provider-specific metadata. + + + +Unique identifier for this server tool call chunk. + +Either: + +- Generated by the provider +- Generated by LangChain upon creation (`UUID4` prefixed with `'lc_'`)) + + + +Index of block in aggregate response. Used during streaming. + + + +The name of the tool to be called. + + + +Used for discrimination. + + + + + + + + +```python +class langchain_core.messages.content.ServerToolResult +``` + + + + + + +**Bases:** `typing.TypedDict` + +Result of a server-side tool call. + + +Provider-specific metadata. + + + +Unique identifier for this server tool result. + +Either: + +- Generated by the provider +- Generated by LangChain upon creation (`UUID4` prefixed with `'lc_'`)) + + + +Index of block in aggregate response. Used during streaming. + + + +Output of the executed tool. + + + +Execution status of the server-side tool. + + + +ID of the corresponding server tool call. + + + +Used for discrimination. + + + + + + + + +```python +class langchain_core.messages.content.TextContentBlock +``` + + + + + + +**Bases:** `typing.TypedDict` + +Text output from a LLM. + +This typically represents the main text content of a message, such as the response +from a language model or the text of a user message. + +!!! note "Factory function" + + `create_text_block` may also be used as a factory to create a + `TextContentBlock`. Benefits include: + + * Automatic ID generation (when not provided) + * Required arguments strictly validated at creation time + + +`Citation`s and other annotations. + + + +Provider-specific metadata. + + + +Unique identifier for this content block. + +Either: + +- Generated by the provider +- Generated by LangChain upon creation (`UUID4` prefixed with `'lc_'`)) + + + +Index of block in aggregate response. Used during streaming. + + + +Block text. + + + +Type of the content block. Used for discrimination. + + + + + + + + +```python +class langchain_core.messages.content.ToolCall +``` + + + + + + +**Bases:** `typing.TypedDict` + +Represents an AI's request to call a tool. + +!!! note "Factory function" + + `create_tool_call` may also be used as a factory to create a + `ToolCall`. Benefits include: + + * Automatic ID generation (when not provided) + * Required arguments strictly validated at creation time + + +The arguments to the tool call. + + + +Provider-specific metadata. + + + +An identifier associated with the tool call. + +An identifier is needed to associate a tool call request with a tool +call result in events when multiple concurrent tool calls are made. + + + +Index of block in aggregate response. Used during streaming. + + + +The name of the tool to be called. + + + +Used for discrimination. + + + + + + + + +```python +class langchain_core.messages.content.ToolCallChunk +``` + + + + + + +**Bases:** `typing.TypedDict` + +A chunk of a tool call (yielded when streaming). + +When merging `ToolCallChunks` (e.g., via `AIMessageChunk.__add__`), +all string attributes are concatenated. Chunks are only merged if their +values of `index` are equal and not `None`. + +Example: + + +```python +left_chunks = [ToolCallChunk(name="foo", args='{"a":', index=0)] +right_chunks = [ToolCallChunk(name=None, args="1}", index=0)] + +( + AIMessageChunk(content="", tool_call_chunks=left_chunks) + + AIMessageChunk(content="", tool_call_chunks=right_chunks) +).tool_call_chunks == [ToolCallChunk(name="foo", args='{"a":1}', index=0)] +``` + + + + +The arguments to the tool call. + + + +Provider-specific metadata. + + + +An identifier associated with the tool call. + +An identifier is needed to associate a tool call request with a tool +call result in events when multiple concurrent tool calls are made. + + + +The index of the tool call in a sequence. + + + +The name of the tool to be called. + + + +Used for serialization. + + + + + + + + +```python +class langchain_core.messages.content.VideoContentBlock +``` + + + + + + +**Bases:** `typing.TypedDict` + +Video data. + +!!! note "Factory function" + + `create_video_block` may also be used as a factory to create a + `VideoContentBlock`. Benefits include: + + * Automatic ID generation (when not provided) + * Required arguments strictly validated at creation time + + +Data as a base64 string. + + + +Provider-specific metadata. This shouldn't be used for the video data itself. + + + +Reference to the video in an external file storage system. + +For example, OpenAI or Anthropic's Files API. + + + +Unique identifier for this content block. + +Either: + +- Generated by the provider +- Generated by LangChain upon creation (`UUID4` prefixed with `'lc_'`)) + + + +Index of block in aggregate response. Used during streaming. + + + +MIME type of the video. + +Required for base64 data. + +[Examples from IANA](https://www.iana.org/assignments/media-types/media-types.xhtml#video) + + + +Type of the content block. Used for discrimination. + + + +URL of the video. + + + + + + + + +```python +langchain_core.messages.content._get_data_content_block_types() -> tuple[str, ...] +``` + + + + + + +Get type literals from DataContentBlock union members dynamically. + +Example: ("image", "video", "audio", "text-plain", "file") + +Note that old style multimodal blocks type literals with new style blocks. +Specifically, "image", "audio", and "file". + +See the docstring of `_normalize_messages` in `language_models._utils` for details. + + + + + + + + +```python +langchain_core.messages.content.create_audio_block( + url: str | None = None, + base64: str | None = None, + file_id: str | None = None, + mime_type: str | None = None, + id: str | None = None, + index: int | str | None = None, + kwargs: typing.Any = {} +) -> langchain_core.messages.content.AudioContentBlock +``` + + + + + + +Create an `AudioContentBlock`. + +!!! note + + The `id` is generated automatically if not provided, using a UUID4 format + prefixed with `'lc_'` to indicate it is a LangChain-generated ID. + +**Parameters:** + + +URL of the audio. + + + +Base64-encoded audio data. + + + +ID of the audio file from a file storage system. + + + +MIME type of the audio. + +Required for base64 data. + + + +Content block identifier. + +Generated automatically if not provided. + + + +Index of block in aggregate response. + +Used during streaming. + + +**Returns:** `AudioContentBlock` + +A properly formatted `AudioContentBlock`. + +**Raises:** + +- `ValueError`: If no audio source is provided or if `base64` is used without +`mime_type`. + + + + + + + + +```python +langchain_core.messages.content.create_citation( + url: str | None = None, + title: str | None = None, + start_index: int | None = None, + end_index: int | None = None, + cited_text: str | None = None, + id: str | None = None, + kwargs: typing.Any = {} +) -> langchain_core.messages.content.Citation +``` + + + + + + +Create a `Citation`. + +!!! note + + The `id` is generated automatically if not provided, using a UUID4 format + prefixed with `'lc_'` to indicate it is a LangChain-generated ID. + +**Parameters:** + + +URL of the document source. + + + +Source document title. + + + +Start index in the response text where citation applies. + + + +End index in the response text where citation applies. + + + +Excerpt of source text being cited. + + + +Content block identifier. + +Generated automatically if not provided. + + +**Returns:** `Citation` + +A properly formatted `Citation`. + + + + + + + + +```python +langchain_core.messages.content.create_file_block( + url: str | None = None, + base64: str | None = None, + file_id: str | None = None, + mime_type: str | None = None, + id: str | None = None, + index: int | str | None = None, + kwargs: typing.Any = {} +) -> langchain_core.messages.content.FileContentBlock +``` + + + + + + +Create a `FileContentBlock`. + +!!! note + + The `id` is generated automatically if not provided, using a UUID4 format + prefixed with `'lc_'` to indicate it is a LangChain-generated ID. + +**Parameters:** + + +URL of the file. + + + +Base64-encoded file data. + + + +ID of the file from a file storage system. + + + +MIME type of the file. + +Required for base64 data. + + + +Content block identifier. + +Generated automatically if not provided. + + + +Index of block in aggregate response. + +Used during streaming. + + +**Returns:** `FileContentBlock` + +A properly formatted `FileContentBlock`. + +**Raises:** + +- `ValueError`: If no file source is provided or if `base64` is used without +`mime_type`. + + + + + + + + +```python +langchain_core.messages.content.create_image_block( + url: str | None = None, + base64: str | None = None, + file_id: str | None = None, + mime_type: str | None = None, + id: str | None = None, + index: int | str | None = None, + kwargs: typing.Any = {} +) -> langchain_core.messages.content.ImageContentBlock +``` + + + + + + +Create an `ImageContentBlock`. + +!!! note + + The `id` is generated automatically if not provided, using a UUID4 format + prefixed with `'lc_'` to indicate it is a LangChain-generated ID. + +**Parameters:** + + +URL of the image. + + + +Base64-encoded image data. + + + +ID of the image file from a file storage system. + + + +MIME type of the image. + +Required for base64 data. + + + +Content block identifier. + +Generated automatically if not provided. + + + +Index of block in aggregate response. + +Used during streaming. + + +**Returns:** `ImageContentBlock` + +A properly formatted `ImageContentBlock`. + +**Raises:** + +- `ValueError`: If no image source is provided or if `base64` is used without +`mime_type`. + + + + + + + + +```python +langchain_core.messages.content.create_non_standard_block( + value: dict[str, typing.Any], + id: str | None = None, + index: int | str | None = None +) -> langchain_core.messages.content.NonStandardContentBlock +``` + + + + + + +Create a `NonStandardContentBlock`. + +!!! note + + The `id` is generated automatically if not provided, using a UUID4 format + prefixed with `'lc_'` to indicate it is a LangChain-generated ID. + +**Parameters:** + + +Provider-specific content data. + + + +Content block identifier. + +Generated automatically if not provided. + + + +Index of block in aggregate response. + +Used during streaming. + + +**Returns:** `NonStandardContentBlock` + +A properly formatted `NonStandardContentBlock`. + + + + + + + + +```python +langchain_core.messages.content.create_plaintext_block( + text: str | None = None, + url: str | None = None, + base64: str | None = None, + file_id: str | None = None, + title: str | None = None, + context: str | None = None, + id: str | None = None, + index: int | str | None = None, + kwargs: typing.Any = {} +) -> langchain_core.messages.content.PlainTextContentBlock +``` + + + + + + +Create a `PlainTextContentBlock`. + +!!! note + + The `id` is generated automatically if not provided, using a UUID4 format + prefixed with `'lc_'` to indicate it is a LangChain-generated ID. + +**Parameters:** + + +The plaintext content. + + + +URL of the plaintext file. + + + +Base64-encoded plaintext data. + + + +ID of the plaintext file from a file storage system. + + + +Title of the text data. + + + +Context or description of the text content. + + + +Content block identifier. + +Generated automatically if not provided. + + + +Index of block in aggregate response. + +Used during streaming. + + +**Returns:** `PlainTextContentBlock` + +A properly formatted `PlainTextContentBlock`. + + + + + + + + +```python +langchain_core.messages.content.create_reasoning_block( + reasoning: str | None = None, + id: str | None = None, + index: int | str | None = None, + kwargs: typing.Any = {} +) -> langchain_core.messages.content.ReasoningContentBlock +``` + + + + + + +Create a `ReasoningContentBlock`. + +!!! note + + The `id` is generated automatically if not provided, using a UUID4 format + prefixed with `'lc_'` to indicate it is a LangChain-generated ID. + +**Parameters:** + + +The reasoning text or thought summary. + + + +Content block identifier. + +Generated automatically if not provided. + + + +Index of block in aggregate response. + +Used during streaming. + + +**Returns:** `ReasoningContentBlock` + +A properly formatted `ReasoningContentBlock`. + + + + + + + + +```python +langchain_core.messages.content.create_text_block( + text: str, + id: str | None = None, + annotations: list[langchain_core.messages.content.Annotation] | None = None, + index: int | str | None = None, + kwargs: typing.Any = {} +) -> langchain_core.messages.content.TextContentBlock +``` + + + + + + +Create a `TextContentBlock`. + +!!! note + + The `id` is generated automatically if not provided, using a UUID4 format + prefixed with `'lc_'` to indicate it is a LangChain-generated ID. + +**Parameters:** + + +The text content of the block. + + + +Content block identifier. + +Generated automatically if not provided. + + + +`Citation`s and other annotations for the text. + + + +Index of block in aggregate response. + +Used during streaming. + + +**Returns:** `TextContentBlock` + +A properly formatted `TextContentBlock`. + + + + + + + + +```python +langchain_core.messages.content.create_tool_call( + name: str, + args: dict[str, typing.Any], + id: str | None = None, + index: int | str | None = None, + kwargs: typing.Any = {} +) -> langchain_core.messages.content.ToolCall +``` + + + + + + +Create a `ToolCall`. + +!!! note + + The `id` is generated automatically if not provided, using a UUID4 format + prefixed with `'lc_'` to indicate it is a LangChain-generated ID. + +**Parameters:** + + +The name of the tool to be called. + + + +The arguments to the tool call. + + + +An identifier for the tool call. + +Generated automatically if not provided. + + + +Index of block in aggregate response. + +Used during streaming. + + +**Returns:** `ToolCall` + +A properly formatted `ToolCall`. + + + + + + + + +```python +langchain_core.messages.content.create_video_block( + url: str | None = None, + base64: str | None = None, + file_id: str | None = None, + mime_type: str | None = None, + id: str | None = None, + index: int | str | None = None, + kwargs: typing.Any = {} +) -> langchain_core.messages.content.VideoContentBlock +``` + + + + + + +Create a `VideoContentBlock`. + +!!! note + + The `id` is generated automatically if not provided, using a UUID4 format + prefixed with `'lc_'` to indicate it is a LangChain-generated ID. + +**Parameters:** + + +URL of the video. + + + +Base64-encoded video data. + + + +ID of the video file from a file storage system. + + + +MIME type of the video. + +Required for base64 data. + + + +Content block identifier. + +Generated automatically if not provided. + + + +Index of block in aggregate response. + +Used during streaming. + + +**Returns:** `VideoContentBlock` + +A properly formatted `VideoContentBlock`. + +**Raises:** + +- `ValueError`: If no video source is provided or if `base64` is used without +`mime_type`. + + + + + + + + +```python +langchain_core.messages.content.is_data_content_block( + block: dict +) -> bool +``` + + + + + + +Check if the provided content block is a data content block. + +Returns True for both v0 (old-style) and v1 (new-style) multimodal data blocks. + +**Parameters:** + + +The content block to check. + + +**Returns:** `bool` + +`True` if the content block is a data content block, `False` otherwise. + + + + + + + + +```python +langchain_core.messages.content.Annotation = Citation | NonStandardAnnotation +``` + + + + + + +A union of all defined `Annotation` types. + + + + + + + +```python +langchain_core.messages.content.ContentBlock = TextContentBlock | InvalidToolCall | ReasoningContentBlock | NonStandardContentB... +``` + + + + + + +A union of all defined `ContentBlock` types and aliases. + + + + + + + +```python +langchain_core.messages.content.DataContentBlock = ImageContentBlock | VideoContentBlock | AudioContentBlock | PlainTextContentBloc... +``` + + + + + + +A union of all defined multimodal data `ContentBlock` types. + + + + + + + +```python +langchain_core.messages.content.KNOWN_BLOCK_TYPES = {'text', 'reasoning', 'tool_call', 'invalid_tool_call', 'tool_call_chunk', 'imag... +``` + + + + + + +These are block types known to `langchain-core >= 1.0.0`. + +If a block has a type not in this set, it is considered to be provider-specific. + + + + + + + +```python +langchain_core.messages.content.ToolContentBlock = ToolCall | ToolCallChunk | ServerToolCall | ServerToolCallChunk | ServerToolResu... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/function.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/function.mdx new file mode 100644 index 0000000..4cae689 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/function.mdx @@ -0,0 +1,92 @@ +--- +layout: overview +slug: langchain-core/langchain_core/messages/function +title: langchain_core.messages.function +--- + +Function Message. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`FunctionMessage`](#langchain_core-messages-function-FunctionMessage) | Message for passing the result of executing a tool back to a model. | +| [`FunctionMessageChunk`](#langchain_core-messages-function-FunctionMessageChunk) | Function Message chunk. | + +### API + + + + + +```python +class langchain_core.messages.function.FunctionMessage() +``` + + + + + + +**Bases:** [BaseMessage](/langchain-core/langchain_core/messages/base#langchain_core-messages-base-BaseMessage) + +Message for passing the result of executing a tool back to a model. + +`FunctionMessage` are an older version of the `ToolMessage` schema, and +do not contain the `tool_call_id` field. + +The `tool_call_id` field is used to associate the tool call request with the +tool call response. Useful in situations where a chat model is able +to request multiple tool calls in parallel. + + + +The name of the function that was executed. + + + +The type of the message (used for serialization). + + + + + + + +```python +class langchain_core.messages.function.FunctionMessageChunk() +``` + + + + + + +**Bases:** [FunctionMessage](#langchain_core-messages-function-FunctionMessage), [BaseMessageChunk](/langchain-core/langchain_core/messages/base#langchain_core-messages-base-BaseMessageChunk) + +Function Message chunk. + + + +The type of the message (used for serialization). + + + + + +```python +langchain_core.messages.function.FunctionMessageChunk.__add__( + other: typing.Any +) -> langchain_core.messages.base.BaseMessageChunk +``` + + + + + + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/human.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/human.mdx new file mode 100644 index 0000000..33c100c --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/human.mdx @@ -0,0 +1,70 @@ +--- +layout: overview +slug: langchain-core/langchain_core/messages/human +title: langchain_core.messages.human +--- + +Human message. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`HumanMessage`](#langchain_core-messages-human-HumanMessage) | Message from the user. | +| [`HumanMessageChunk`](#langchain_core-messages-human-HumanMessageChunk) | Human Message chunk. | + +### API + + + + + +```python +class langchain_core.messages.human.HumanMessage( + content: str | list[str | dict] | None = None, + content_blocks: list[langchain_core.messages.content.ContentBlock] | None = None, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [BaseMessage](/langchain-core/langchain_core/messages/base#langchain_core-messages-base-BaseMessage) + +Message from the user. + +A `HumanMessage` is a message that is passed in from a user to the model. + + + +The type of the message (used for serialization). + + + + + + + +```python +class langchain_core.messages.human.HumanMessageChunk() +``` + + + + + + +**Bases:** [HumanMessage](#langchain_core-messages-human-HumanMessage), [BaseMessageChunk](/langchain-core/langchain_core/messages/base#langchain_core-messages-base-BaseMessageChunk) + +Human Message chunk. + + + +The type of the message (used for serialization). + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/modifier.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/modifier.mdx new file mode 100644 index 0000000..a3e9ed7 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/modifier.mdx @@ -0,0 +1,43 @@ +--- +layout: overview +slug: langchain-core/langchain_core/messages/modifier +title: langchain_core.messages.modifier +--- + +Message responsible for deleting other messages. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`RemoveMessage`](#langchain_core-messages-modifier-RemoveMessage) | Message responsible for deleting other messages. | + +### API + + + + + +```python +class langchain_core.messages.modifier.RemoveMessage( + id: str, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [BaseMessage](/langchain-core/langchain_core/messages/base#langchain_core-messages-base-BaseMessage) + +Message responsible for deleting other messages. + + + +The type of the message (used for serialization). + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/system.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/system.mdx new file mode 100644 index 0000000..244fe57 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/system.mdx @@ -0,0 +1,71 @@ +--- +layout: overview +slug: langchain-core/langchain_core/messages/system +title: langchain_core.messages.system +--- + +System message. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`SystemMessage`](#langchain_core-messages-system-SystemMessage) | Message for priming AI behavior. | +| [`SystemMessageChunk`](#langchain_core-messages-system-SystemMessageChunk) | System Message chunk. | + +### API + + + + + +```python +class langchain_core.messages.system.SystemMessage( + content: str | list[str | dict] | None = None, + content_blocks: list[langchain_core.messages.content.ContentBlock] | None = None, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [BaseMessage](/langchain-core/langchain_core/messages/base#langchain_core-messages-base-BaseMessage) + +Message for priming AI behavior. + +The system message is usually passed in as the first of a sequence +of input messages. + + + +The type of the message (used for serialization). + + + + + + + +```python +class langchain_core.messages.system.SystemMessageChunk() +``` + + + + + + +**Bases:** [SystemMessage](#langchain_core-messages-system-SystemMessage), [BaseMessageChunk](/langchain-core/langchain_core/messages/base#langchain_core-messages-base-BaseMessageChunk) + +System Message chunk. + + + +The type of the message (used for serialization). + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/tool.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/tool.mdx new file mode 100644 index 0000000..6a04dfb --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/tool.mdx @@ -0,0 +1,495 @@ +--- +layout: overview +slug: langchain-core/langchain_core/messages/tool +title: langchain_core.messages.tool +--- + +Messages for tools. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ToolCall`](#langchain_core-messages-tool-ToolCall) | Represents an AI's request to call a tool. | +| [`ToolCallChunk`](#langchain_core-messages-tool-ToolCallChunk) | A chunk of a tool call (yielded when streaming). | +| [`ToolMessage`](#langchain_core-messages-tool-ToolMessage) | Message for passing the result of executing a tool back to a model. | +| [`ToolMessageChunk`](#langchain_core-messages-tool-ToolMessageChunk) | Tool Message chunk. | +| [`ToolOutputMixin`](#langchain_core-messages-tool-ToolOutputMixin) | Mixin for objects that tools can return directly. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_merge_status`](#langchain_core-messages-tool-_merge_status) | - | +| [`default_tool_chunk_parser`](#langchain_core-messages-tool-default_tool_chunk_parser) | Best-effort parsing of tool chunks. | +| [`default_tool_parser`](#langchain_core-messages-tool-default_tool_parser) | Best-effort parsing of tools. | +| [`invalid_tool_call`](#langchain_core-messages-tool-invalid_tool_call) | Create an invalid tool call. | +| [`tool_call`](#langchain_core-messages-tool-tool_call) | Create a tool call. | +| [`tool_call_chunk`](#langchain_core-messages-tool-tool_call_chunk) | Create a tool call chunk. | + +### API + + + + + +```python +class langchain_core.messages.tool.ToolCall +``` + + + + + + +**Bases:** `typing.TypedDict` + +Represents an AI's request to call a tool. + +!!! note "Factory function" + + `tool_call` may also be used as a factory to create a `ToolCall`. Benefits + include: + + * Required arguments strictly validated at creation time + + +The arguments to the tool call as a dictionary. + + + +An identifier associated with the tool call. + +An identifier is needed to associate a tool call request with a tool +call result in events when multiple concurrent tool calls are made. + + + +The name of the tool to be called. + + + +Used for discrimination. + + + + + + + + +```python +class langchain_core.messages.tool.ToolCallChunk +``` + + + + + + +**Bases:** `typing.TypedDict` + +A chunk of a tool call (yielded when streaming). + +When merging `ToolCallChunk` objects (e.g., via `AIMessageChunk.__add__`), all +string attributes are concatenated. Chunks are only merged if their values of +`index` are equal and not `None`. + +Example: + + +```python +left_chunks = [ToolCallChunk(name="foo", args='{"a":', index=0)] +right_chunks = [ToolCallChunk(name=None, args="1}", index=0)] + +( + AIMessageChunk(content="", tool_call_chunks=left_chunks) + + AIMessageChunk(content="", tool_call_chunks=right_chunks) +).tool_call_chunks == [ToolCallChunk(name="foo", args='{"a":1}', index=0)] +``` + + + + +The arguments to the tool call as a JSON-parseable string. + + + +An identifier associated with the tool call. + +An identifier is needed to associate a tool call request with a tool +call result in events when multiple concurrent tool calls are made. + + + +The index of the tool call in a sequence. + +Used for merging chunks. + + + +The name of the tool to be called. + + + +Used for discrimination. + + + + + + + + +```python +class langchain_core.messages.tool.ToolMessage( + content: str | list[str | dict] | None = None, + content_blocks: list[langchain_core.messages.content.ContentBlock] | None = None, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [BaseMessage](/langchain-core/langchain_core/messages/base#langchain_core-messages-base-BaseMessage), [ToolOutputMixin](#langchain_core-messages-tool-ToolOutputMixin) + +Message for passing the result of executing a tool back to a model. + +`ToolMessage` objects contain the result of a tool invocation. Typically, the result +is encoded inside the `content` field. + +`tool_call_id` is used to associate the tool call request with the tool call +response. Useful in situations where a chat model is able to request multiple tool +calls in parallel. + + + +Currently inherited from `BaseMessage`, but not used. + + + +Artifact of the Tool execution which is not meant to be sent to the model. + +Should only be specified if it is different from the message content, e.g. if only +a subset of the full tool output is being passed as message content but the full +output is needed in other parts of the code. + + + +Currently inherited from `BaseMessage`, but not used. + + + +Status of the tool invocation. + + + +Tool call that this message is responding to. + + + +The type of the message (used for serialization). + + + + + +```python +langchain_core.messages.tool.ToolMessage.coerce_args( + values: dict +) -> dict +``` + + + + + + +classmethod + +Coerce the model arguments to the correct types. + +**Parameters:** + + +The model arguments. + + + + + + + + + + +```python +class langchain_core.messages.tool.ToolMessageChunk() +``` + + + + + + +**Bases:** [ToolMessage](#langchain_core-messages-tool-ToolMessage), [BaseMessageChunk](/langchain-core/langchain_core/messages/base#langchain_core-messages-base-BaseMessageChunk) + +Tool Message chunk. + + + + + + + + +```python +langchain_core.messages.tool.ToolMessageChunk.__add__( + other: typing.Any +) -> langchain_core.messages.base.BaseMessageChunk +``` + + + + + + + + + + + + + + +```python +class langchain_core.messages.tool.ToolOutputMixin() +``` + + + + + + +Mixin for objects that tools can return directly. + +If a custom BaseTool is invoked with a `ToolCall` and the output of custom code is +not an instance of `ToolOutputMixin`, the output will automatically be coerced to +a string and wrapped in a `ToolMessage`. + + + + + + + + +```python +langchain_core.messages.tool._merge_status( + left: typing.Literal['success', 'error'], + right: typing.Literal['success', 'error'] +) -> typing.Literal['success', 'error'] +``` + + + + + + + + + + + + + +```python +langchain_core.messages.tool.default_tool_chunk_parser( + raw_tool_calls: list[dict] +) -> list[langchain_core.messages.tool.ToolCallChunk] +``` + + + + + + +Best-effort parsing of tool chunks. + +**Parameters:** + + +List of raw tool call dicts to parse. + + +**Returns:** `list[ToolCallChunk]` + +List of parsed ToolCallChunk objects. + + + + + + + + +```python +langchain_core.messages.tool.default_tool_parser( + raw_tool_calls: list[dict] +) -> tuple[list[langchain_core.messages.tool.ToolCall], list[langchain_core.messages.content.InvalidToolCall]] +``` + + + + + + +Best-effort parsing of tools. + +**Parameters:** + + +List of raw tool call dicts to parse. + + +**Returns:** `tuple[list[ToolCall], list[InvalidToolCall]]` + +A list of tool calls and invalid tool calls. + + + + + + + + +```python +langchain_core.messages.tool.invalid_tool_call( + name: str | None = None, + args: str | None = None, + id: str | None = None, + error: str | None = None +) -> langchain_core.messages.content.InvalidToolCall +``` + + + + + + +Create an invalid tool call. + +**Parameters:** + + +The name of the tool to be called. + + + +The arguments to the tool call as a JSON string. + + + +An identifier associated with the tool call. + + + +An error message associated with the tool call. + + +**Returns:** `InvalidToolCall` + +The created invalid tool call. + + + + + + + + +```python +langchain_core.messages.tool.tool_call( + name: str, + args: dict[str, typing.Any], + id: str | None +) -> langchain_core.messages.tool.ToolCall +``` + + + + + + +Create a tool call. + +**Parameters:** + + +The name of the tool to be called. + + + +The arguments to the tool call as a dictionary. + + + +An identifier associated with the tool call. + + +**Returns:** `ToolCall` + +The created tool call. + + + + + + + + +```python +langchain_core.messages.tool.tool_call_chunk( + name: str | None = None, + args: str | None = None, + id: str | None = None, + index: int | None = None +) -> langchain_core.messages.tool.ToolCallChunk +``` + + + + + + +Create a tool call chunk. + +**Parameters:** + + +The name of the tool to be called. + + + +The arguments to the tool call as a JSON string. + + + +An identifier associated with the tool call. + + + +The index of the tool call in a sequence. + + +**Returns:** `ToolCallChunk` + +The created tool call chunk. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/utils.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/utils.mdx new file mode 100644 index 0000000..f4da485 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/messages/utils.mdx @@ -0,0 +1,1401 @@ +--- +layout: overview +slug: langchain-core/langchain_core/messages/utils +title: langchain_core.messages.utils +--- + +Module contains utility functions for working with messages. + +Some examples of what you can do with these functions include: + +* Convert messages to strings (serialization) +* Convert messages from dicts to Message objects (deserialization) +* Filter messages from a list of messages based on name, type or id etc. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`_RunnableSupportCallable`](#langchain_core-messages-utils-_RunnableSupportCallable) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_approximate_token_counter`](#langchain_core-messages-utils-_approximate_token_counter) | Wrapper for `count_tokens_approximately` that matches expected signature. | +| [`_bytes_to_b64_str`](#langchain_core-messages-utils-_bytes_to_b64_str) | - | +| [`_chunk_to_msg`](#langchain_core-messages-utils-_chunk_to_msg) | - | +| [`_convert_to_message`](#langchain_core-messages-utils-_convert_to_message) | Instantiate a `Message` from a variety of message formats. | +| [`_convert_to_openai_tool_calls`](#langchain_core-messages-utils-_convert_to_openai_tool_calls) | - | +| [`_create_message_from_message_type`](#langchain_core-messages-utils-_create_message_from_message_type) | Create a message from a `Message` type and content string. | +| [`_default_text_splitter`](#langchain_core-messages-utils-_default_text_splitter) | - | +| [`_first_max_tokens`](#langchain_core-messages-utils-_first_max_tokens) | - | +| [`_format_content_block_xml`](#langchain_core-messages-utils-_format_content_block_xml) | Format a content block as XML. | +| [`_get_message_openai_role`](#langchain_core-messages-utils-_get_message_openai_role) | - | +| [`_get_message_type_str`](#langchain_core-messages-utils-_get_message_type_str) | Get the type string for XML message element. | +| [`_get_type`](#langchain_core-messages-utils-_get_type) | Get the type associated with the object for serialization purposes. | +| [`_has_base64_data`](#langchain_core-messages-utils-_has_base64_data) | Check if a content block contains base64 encoded data. | +| [`_is_message_type`](#langchain_core-messages-utils-_is_message_type) | - | +| [`_last_max_tokens`](#langchain_core-messages-utils-_last_max_tokens) | - | +| [`_message_from_dict`](#langchain_core-messages-utils-_message_from_dict) | - | +| [`_msg_to_chunk`](#langchain_core-messages-utils-_msg_to_chunk) | - | +| [`_runnable_support`](#langchain_core-messages-utils-_runnable_support) | - | +| [`_truncate`](#langchain_core-messages-utils-_truncate) | Truncate text to `max_len` characters, adding ellipsis if truncated. | +| [`convert_to_messages`](#langchain_core-messages-utils-convert_to_messages) | Convert a sequence of messages to a list of messages. | +| [`convert_to_openai_messages`](#langchain_core-messages-utils-convert_to_openai_messages) | Convert LangChain messages into OpenAI message dicts. | +| [`count_tokens_approximately`](#langchain_core-messages-utils-count_tokens_approximately) | Approximate the total number of tokens in messages. | +| [`filter_messages`](#langchain_core-messages-utils-filter_messages) | Filter messages based on `name`, `type` or `id`. | +| [`get_buffer_string`](#langchain_core-messages-utils-get_buffer_string) | Convert a sequence of messages to strings and concatenate them into one string. | +| [`merge_message_runs`](#langchain_core-messages-utils-merge_message_runs) | Merge consecutive Messages of the same type. | +| [`message_chunk_to_message`](#langchain_core-messages-utils-message_chunk_to_message) | Convert a message chunk to a `Message`. | +| [`messages_from_dict`](#langchain_core-messages-utils-messages_from_dict) | Convert a sequence of messages from dicts to `Message` objects. | +| [`trim_messages`](#langchain_core-messages-utils-trim_messages) | Trim messages to be below a token count. | + +### Data + +[`AnyMessage`](#langchain_core-messages-utils-AnyMessage) + +[`MessageLikeRepresentation`](#langchain_core-messages-utils-MessageLikeRepresentation) + +[`_CHUNK_MSG_MAP`](#langchain_core-messages-utils-_CHUNK_MSG_MAP) + +[`_HAS_LANGCHAIN_TEXT_SPLITTERS`](#langchain_core-messages-utils-_HAS_LANGCHAIN_TEXT_SPLITTERS) + +[`_MSG_CHUNK_MAP`](#langchain_core-messages-utils-_MSG_CHUNK_MAP) + +[`_MultipleMessages`](#langchain_core-messages-utils-_MultipleMessages) + +[`_P`](#langchain_core-messages-utils-_P) + +[`_R_co`](#langchain_core-messages-utils-_R_co) + +[`_SingleMessage`](#langchain_core-messages-utils-_SingleMessage) + +[`_T`](#langchain_core-messages-utils-_T) + +[`_TOKEN_COUNTER_SHORTCUTS`](#langchain_core-messages-utils-_TOKEN_COUNTER_SHORTCUTS) + +[`_XML_CONTENT_BLOCK_MAX_LEN`](#langchain_core-messages-utils-_XML_CONTENT_BLOCK_MAX_LEN) + +[`logger`](#langchain_core-messages-utils-logger) + +### API + + + + + +```python +class langchain_core.messages.utils._RunnableSupportCallable() +``` + + + + + + +Protocol + +**Bases:** `Protocol[_P, _R_co]` + + + + + +```python +langchain_core.messages.utils._RunnableSupportCallable.__call__( + messages: collections.abc.Sequence[langchain_core.messages.utils.MessageLikeRepresentation] | langchain_core.prompt_values.PromptValue | None = None, + args: langchain_core.messages.utils._P.args = (), + kwargs: langchain_core.messages.utils._P.kwargs = {} +) -> langchain_core.messages.utils._R_co | langchain_core.runnables.base.Runnable[collections.abc.Sequence[langchain_core.messages.utils.MessageLikeRepresentation], langchain_core.messages.utils._R_co] +``` + + + + + + + + + + + + + + +```python +langchain_core.messages.utils._approximate_token_counter( + messages: collections.abc.Sequence[langchain_core.messages.base.BaseMessage] +) -> int +``` + + + + + + +Wrapper for `count_tokens_approximately` that matches expected signature. + + + + + + + + +```python +langchain_core.messages.utils._bytes_to_b64_str( + bytes_: bytes +) -> str +``` + + + + + + + + + + + + + +```python +langchain_core.messages.utils._chunk_to_msg( + chunk: langchain_core.messages.base.BaseMessageChunk +) -> langchain_core.messages.base.BaseMessage +``` + + + + + + + + + + + + + +```python +langchain_core.messages.utils._convert_to_message( + message: langchain_core.messages.utils.MessageLikeRepresentation +) -> langchain_core.messages.base.BaseMessage +``` + + + + + + +Instantiate a `Message` from a variety of message formats. + +The message format can be one of the following: + +- `BaseMessagePromptTemplate` +- `BaseMessage` +- 2-tuple of (role string, template); e.g., (`'human'`, `'{user_input}'`) +- dict: a message dict with role and content keys +- string: shorthand for (`'human'`, template); e.g., `'{user_input}'` + +**Parameters:** + + +a representation of a message in one of the supported formats. + + +**Returns:** `BaseMessage` + +An instance of a message or a message template. + +**Raises:** + +- `NotImplementedError`: if the message type is not supported. +- `ValueError`: if the message dict does not contain the required keys. + + + + + + + + +```python +langchain_core.messages.utils._convert_to_openai_tool_calls( + tool_calls: list[langchain_core.messages.tool.ToolCall] +) -> list[dict] +``` + + + + + + + + + + + + + +```python +langchain_core.messages.utils._create_message_from_message_type( + message_type: str, + content: str, + name: str | None = None, + tool_call_id: str | None = None, + tool_calls: list[dict[str, typing.Any]] | None = None, + id: str | None = None, + additional_kwargs: typing.Any = {} +) -> langchain_core.messages.base.BaseMessage +``` + + + + + + +Create a message from a `Message` type and content string. + +**Parameters:** + + +the type of the message (e.g., `'human'`, `'ai'`, etc.). + + + +the content string. + + + +the name of the message. + + + +the tool call id. + + + +the tool calls. + + + +the id of the message. + + + +additional keyword arguments. + + +**Returns:** `BaseMessage` + +a message of the appropriate type. + +**Raises:** + +- `ValueError`: if the message type is not one of `'human'`, `'user'`, `'ai'`, +`'assistant'`, `'function'`, `'tool'`, `'system'`, or +`'developer'`. + + + + + + + + +```python +langchain_core.messages.utils._default_text_splitter( + text: str +) -> list[str] +``` + + + + + + + + + + + + + +```python +langchain_core.messages.utils._first_max_tokens( + messages: collections.abc.Sequence[langchain_core.messages.base.BaseMessage], + max_tokens: int, + token_counter: collections.abc.Callable[[list[BaseMessage]], int], + text_splitter: collections.abc.Callable[[str], list[str]], + partial_strategy: typing.Literal['first', 'last'] | None = None, + end_on: str | type[langchain_core.messages.base.BaseMessage] | collections.abc.Sequence[str | type[langchain_core.messages.base.BaseMessage]] | None = None +) -> list[langchain_core.messages.base.BaseMessage] +``` + + + + + + + + + + + + + +```python +langchain_core.messages.utils._format_content_block_xml( + block: dict +) -> str | None +``` + + + + + + +Format a content block as XML. + +**Parameters:** + + +A LangChain content block. + + +**Returns:** `str | None` + +XML string representation of the block, or `None` if the block should be +skipped. + + + + + + + + +```python +langchain_core.messages.utils._get_message_openai_role( + message: langchain_core.messages.base.BaseMessage +) -> str +``` + + + + + + + + + + + + + +```python +langchain_core.messages.utils._get_message_type_str( + m: langchain_core.messages.base.BaseMessage, + human_prefix: str, + ai_prefix: str, + system_prefix: str, + function_prefix: str, + tool_prefix: str +) -> str +``` + + + + + + +Get the type string for XML message element. + +**Parameters:** + + +The message to get the type string for. + + + +The prefix to use for `HumanMessage`. + + + +The prefix to use for `AIMessage`. + + + +The prefix to use for `SystemMessage`. + + + +The prefix to use for `FunctionMessage`. + + + +The prefix to use for `ToolMessage`. + + +**Returns:** `str` + +The type string for the message element. + +**Raises:** + +- `ValueError`: If an unsupported message type is encountered. + + + + + + + + +```python +langchain_core.messages.utils._get_type( + v: typing.Any +) -> str +``` + + + + + + +Get the type associated with the object for serialization purposes. + + + + + + + + +```python +langchain_core.messages.utils._has_base64_data( + block: dict +) -> bool +``` + + + + + + +Check if a content block contains base64 encoded data. + +**Parameters:** + + +A content block dictionary. + + +**Returns:** `bool` + +Whether the block contains base64 data. + + + + + + + + +```python +langchain_core.messages.utils._is_message_type( + message: langchain_core.messages.base.BaseMessage, + type_: str | type[langchain_core.messages.base.BaseMessage] | collections.abc.Sequence[str | type[langchain_core.messages.base.BaseMessage]] +) -> bool +``` + + + + + + + + + + + + + +```python +langchain_core.messages.utils._last_max_tokens( + messages: collections.abc.Sequence[langchain_core.messages.base.BaseMessage], + max_tokens: int, + token_counter: collections.abc.Callable[[list[BaseMessage]], int], + text_splitter: collections.abc.Callable[[str], list[str]], + allow_partial: bool = False, + include_system: bool = False, + start_on: str | type[langchain_core.messages.base.BaseMessage] | collections.abc.Sequence[str | type[langchain_core.messages.base.BaseMessage]] | None = None, + end_on: str | type[langchain_core.messages.base.BaseMessage] | collections.abc.Sequence[str | type[langchain_core.messages.base.BaseMessage]] | None = None +) -> list[langchain_core.messages.base.BaseMessage] +``` + + + + + + + + + + + + + +```python +langchain_core.messages.utils._message_from_dict( + message: dict +) -> langchain_core.messages.base.BaseMessage +``` + + + + + + + + + + + + + +```python +langchain_core.messages.utils._msg_to_chunk( + message: langchain_core.messages.base.BaseMessage +) -> langchain_core.messages.base.BaseMessageChunk +``` + + + + + + + + + + + + + +```python +langchain_core.messages.utils._runnable_support( + func: collections.abc.Callable[typing.Concatenate[collections.abc.Sequence[langchain_core.messages.utils.MessageLikeRepresentation] | langchain_core.prompt_values.PromptValue, langchain_core.messages.utils._P], langchain_core.messages.utils._R_co] +) -> langchain_core.messages.utils._RunnableSupportCallable[langchain_core.messages.utils._P, langchain_core.messages.utils._R_co] +``` + + + + + + + + + + + + + +```python +langchain_core.messages.utils._truncate( + text: str, + max_len: int = _XML_CONTENT_BLOCK_MAX_LEN +) -> str +``` + + + + + + +Truncate text to `max_len` characters, adding ellipsis if truncated. + + + + + + + + +```python +langchain_core.messages.utils.convert_to_messages( + messages: collections.abc.Iterable[langchain_core.messages.utils.MessageLikeRepresentation] | langchain_core.prompt_values.PromptValue +) -> list[langchain_core.messages.base.BaseMessage] +``` + + + + + + +Convert a sequence of messages to a list of messages. + +**Parameters:** + + +Sequence of messages to convert. + + +**Returns:** `list[BaseMessage]` + +list of messages (BaseMessages). + + + + + + + + +```python +langchain_core.messages.utils.convert_to_openai_messages( + messages: langchain_core.messages.utils.MessageLikeRepresentation | collections.abc.Sequence[langchain_core.messages.utils.MessageLikeRepresentation], + text_format: typing.Literal['string', 'block'] = 'string', + include_id: bool = False, + pass_through_unknown_blocks: bool = True +) -> dict | list[dict] +``` + + + + + + +Convert LangChain messages into OpenAI message dicts. + +!!! version-added "Added in `langchain-core` 0.3.11" + +**Parameters:** + + +Message-like object or iterable of objects whose contents are +in OpenAI, Anthropic, Bedrock Converse, or VertexAI formats. + + + +How to format string or text block contents: +- `'string'`: + If a message has a string content, this is left as a string. If + a message has content blocks that are all of type `'text'`, these + are joined with a newline to make a single string. If a message has + content blocks and at least one isn't of type `'text'`, then + all blocks are left as dicts. +- `'block'`: + If a message has a string content, this is turned into a list + with a single content block of type `'text'`. If a message has + content blocks these are left as is. + + + +Whether to include message IDs in the openai messages, if they +are present in the source messages. + + + +Whether to include content blocks with unknown +formats in the output. If `False`, an error is raised if an unknown +content block is encountered. + + +**Returns:** `dict | list[dict]` + +The return type depends on the input type: + +**Raises:** + +- `ValueError`: if an unrecognized `text_format` is specified, or if a message +content block is missing expected keys. + + + + + + + + +```python +langchain_core.messages.utils.count_tokens_approximately( + messages: collections.abc.Iterable[langchain_core.messages.utils.MessageLikeRepresentation], + chars_per_token: float = 4.0, + extra_tokens_per_message: float = 3.0, + count_name: bool = True, + tokens_per_image: int = 85, + use_usage_metadata_scaling: bool = False, + tools: list[langchain_core.tools.BaseTool | dict[str, typing.Any]] | None = None +) -> int +``` + + + + + + +Approximate the total number of tokens in messages. + +The token count includes stringified message content, role, and (optionally) name. + +- For AI messages, the token count also includes stringified tool calls. +- For tool messages, the token count also includes the tool call ID. +- For multimodal messages with images, applies a fixed token penalty per image + instead of counting base64-encoded characters. +- If tools are provided, the token count also includes stringified tool schemas. + +!!! version-added "Added in `langchain-core` 0.3.46" + +**Parameters:** + + +List of messages to count tokens for. + + + +Number of characters per token to use for the approximation. +One token corresponds to ~4 chars for common English text. +You can also specify `float` values for more fine-grained control. +[See more here](https://platform.openai.com/tokenizer). + + + +Number of extra tokens to add per message, e.g. +special tokens, including beginning/end of message. +You can also specify `float` values for more fine-grained control. +[See more here](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb). + + + +Whether to include message names in the count. + + + +Fixed token cost per image (default: 85, aligned with +OpenAI's low-resolution image token cost). + + + +If True, and all AI messages have consistent +`response_metadata['model_provider']`, scale the approximate token count +using the **most recent** AI message that has +`usage_metadata['total_tokens']`. The scaling factor is: +`AI_total_tokens / approx_tokens_up_to_that_AI_message` + + + +List of tools to include in the token count. Each tool can be either +a `BaseTool` instance or a dict representing a tool schema. `BaseTool` +instances are converted to OpenAI tool format before counting. + + +**Returns:** `int` + +Approximate number of tokens in the messages (and tools, if provided). + + + + + + + + +```python +langchain_core.messages.utils.filter_messages( + messages: collections.abc.Iterable[langchain_core.messages.utils.MessageLikeRepresentation] | langchain_core.prompt_values.PromptValue, + include_names: collections.abc.Sequence[str] | None = None, + exclude_names: collections.abc.Sequence[str] | None = None, + include_types: collections.abc.Sequence[str | type[langchain_core.messages.base.BaseMessage]] | None = None, + exclude_types: collections.abc.Sequence[str | type[langchain_core.messages.base.BaseMessage]] | None = None, + include_ids: collections.abc.Sequence[str] | None = None, + exclude_ids: collections.abc.Sequence[str] | None = None, + exclude_tool_calls: collections.abc.Sequence[str] | bool | None = None +) -> list[langchain_core.messages.base.BaseMessage] +``` + + + + + + +Filter messages based on `name`, `type` or `id`. + +**Parameters:** + + +Sequence Message-like objects to filter. + + + +Message names to include. + + + +Messages names to exclude. + + + +Message types to include. Can be specified as string names +(e.g. `'system'`, `'human'`, `'ai'`, ...) or as `BaseMessage` +classes (e.g. `SystemMessage`, `HumanMessage`, `AIMessage`, ...). + + + +Message types to exclude. Can be specified as string names +(e.g. `'system'`, `'human'`, `'ai'`, ...) or as `BaseMessage` +classes (e.g. `SystemMessage`, `HumanMessage`, `AIMessage`, ...). + + + +Message IDs to include. + + + +Message IDs to exclude. + + + +Tool call IDs to exclude. +Can be one of the following: +- `True`: All `AIMessage` objects with tool calls and all `ToolMessage` + objects will be excluded. +- a sequence of tool call IDs to exclude: + - `ToolMessage` objects with the corresponding tool call ID will be + excluded. + - The `tool_calls` in the AIMessage will be updated to exclude + matching tool calls. If all `tool_calls` are filtered from an + AIMessage, the whole message is excluded. + + +**Returns:** `list[BaseMessage]` + +A list of Messages that meets at least one of the `incl_*` conditions and none + +**Raises:** + +- `ValueError`: If two incompatible arguments are provided. + + + + + + + + +```python +langchain_core.messages.utils.get_buffer_string( + messages: collections.abc.Sequence[langchain_core.messages.base.BaseMessage], + human_prefix: str = 'Human', + ai_prefix: str = 'AI', + system_prefix: str = 'System', + function_prefix: str = 'Function', + tool_prefix: str = 'Tool', + message_separator: str = '\n', + format: typing.Literal['prefix', 'xml'] = 'prefix' +) -> str +``` + + + + + + +Convert a sequence of messages to strings and concatenate them into one string. + +!!! warning + + If a message is an `AIMessage` and contains both tool calls under `tool_calls` + and a function call under `additional_kwargs["function_call"]`, only the tool + calls will be appended to the string representation. + +!!! note "XML format" + + When using `format='xml'`: + + - All messages use uniform `<message type="role">content</message>` format. + - The `type` attribute uses `human_prefix` (lowercased) for `HumanMessage`, + `ai_prefix` (lowercased) for `AIMessage`, `system_prefix` (lowercased) + for `SystemMessage`, `function_prefix` (lowercased) for `FunctionMessage`, + `tool_prefix` (lowercased) for `ToolMessage`, and the original role + (unchanged) for `ChatMessage`. + - Message content is escaped using `xml.sax.saxutils.escape()`. + - Attribute values are escaped using `xml.sax.saxutils.quoteattr()`. + - AI messages with tool calls use nested structure with `<content>` and + `<tool_call>` elements. + - For multi-modal content (list of content blocks), supported block types + are: `text`, `reasoning`, `image` (URL/file_id only), `image_url` + (OpenAI-style, URL only), `audio` (URL/file_id only), `video` (URL/file_id + only), `text-plain`, `server_tool_call`, and `server_tool_result`. + - Content blocks with base64-encoded data are skipped (including blocks + with `base64` field or `data:` URLs). + - Unknown block types are skipped. + - Plain text document content (`text-plain`), server tool call arguments, + and server tool result outputs are truncated to 500 characters. + +**Parameters:** + + +Messages to be converted to strings. + + + +The prefix to prepend to contents of `HumanMessage`s. + + + +The prefix to prepend to contents of `AIMessage`. + + + +The prefix to prepend to contents of `SystemMessage`s. + + + +The prefix to prepend to contents of `FunctionMessage`s. + + + +The prefix to prepend to contents of `ToolMessage`s. + + + +The separator to use between messages. + + + +The output format. `'prefix'` uses `Role: content` format (default). + +`'xml'` uses XML-style `<message type='role'>` format with proper character +escaping, which is useful when message content may contain role-like +prefixes that could cause ambiguity. + + +**Returns:** `str` + +A single string concatenation of all input messages. + +**Raises:** + +- `ValueError`: If an unsupported message type is encountered. + + + + + + + + +```python +langchain_core.messages.utils.merge_message_runs( + messages: collections.abc.Iterable[langchain_core.messages.utils.MessageLikeRepresentation] | langchain_core.prompt_values.PromptValue, + chunk_separator: str = '\n' +) -> list[langchain_core.messages.base.BaseMessage] +``` + + + + + + +Merge consecutive Messages of the same type. + +!!! note + `ToolMessage` objects are not merged, as each has a distinct tool call id that + can't be merged. + +**Parameters:** + + +Sequence Message-like objects to merge. + + + +Specify the string to be inserted between message chunks. + + +**Returns:** `list[BaseMessage]` + +list of BaseMessages with consecutive runs of message types merged into single + + + + + + + + +```python +langchain_core.messages.utils.message_chunk_to_message( + chunk: langchain_core.messages.base.BaseMessage +) -> langchain_core.messages.base.BaseMessage +``` + + + + + + +Convert a message chunk to a `Message`. + +**Parameters:** + + +Message chunk to convert. + + +**Returns:** `BaseMessage` + +Message. + + + + + + + + +```python +langchain_core.messages.utils.messages_from_dict( + messages: collections.abc.Sequence[dict] +) -> list[langchain_core.messages.base.BaseMessage] +``` + + + + + + +Convert a sequence of messages from dicts to `Message` objects. + +**Parameters:** + + +Sequence of messages (as dicts) to convert. + + +**Returns:** `list[BaseMessage]` + +list of messages (BaseMessages). + + + + + + + + +```python +langchain_core.messages.utils.trim_messages( + messages: collections.abc.Iterable[langchain_core.messages.utils.MessageLikeRepresentation] | langchain_core.prompt_values.PromptValue, + max_tokens: int, + token_counter: collections.abc.Callable[[list[BaseMessage]], int] | collections.abc.Callable[[BaseMessage], int] | langchain_core.language_models.BaseLanguageModel | typing.Literal['approximate'], + strategy: typing.Literal['first', 'last'] = 'last', + allow_partial: bool = False, + end_on: str | type[langchain_core.messages.base.BaseMessage] | collections.abc.Sequence[str | type[langchain_core.messages.base.BaseMessage]] | None = None, + start_on: str | type[langchain_core.messages.base.BaseMessage] | collections.abc.Sequence[str | type[langchain_core.messages.base.BaseMessage]] | None = None, + include_system: bool = False, + text_splitter: collections.abc.Callable[[str], list[str]] | langchain_text_splitters.TextSplitter | None = None +) -> list[langchain_core.messages.base.BaseMessage] +``` + + + + + + +Trim messages to be below a token count. + +`trim_messages` can be used to reduce the size of a chat history to a specified +token or message count. + +In either case, if passing the trimmed chat history back into a chat model +directly, the resulting chat history should usually satisfy the following +properties: + +1. The resulting chat history should be valid. Most chat models expect that chat + history starts with either (1) a `HumanMessage` or (2) a `SystemMessage` + followed by a `HumanMessage`. To achieve this, set `start_on='human'`. + In addition, generally a `ToolMessage` can only appear after an `AIMessage` + that involved a tool call. +2. It includes recent messages and drops old messages in the chat history. + To achieve this set the `strategy='last'`. +3. Usually, the new chat history should include the `SystemMessage` if it + was present in the original chat history since the `SystemMessage` includes + special instructions to the chat model. The `SystemMessage` is almost always + the first message in the history if present. To achieve this set the + `include_system=True`. + +!!! note + The examples below show how to configure `trim_messages` to achieve a behavior + consistent with the above properties. + +**Parameters:** + + +Sequence of Message-like objects to trim. + + + +Max token count of trimmed messages. + + + +Function or llm for counting tokens in a `BaseMessage` or a +list of `BaseMessage`. + +If a `BaseLanguageModel` is passed in then +`BaseLanguageModel.get_num_tokens_from_messages()` will be used. Set to +`len` to count the number of **messages** in the chat history. + +You can also use string shortcuts for convenience: + +- `'approximate'`: Uses `count_tokens_approximately` for fast, approximate + token counts. + +!!! note + + `count_tokens_approximately` (or the shortcut `'approximate'`) is + recommended for using `trim_messages` on the hot path, where exact token + counting is not necessary. + + + +Strategy for trimming. + +- `'first'`: Keep the first `<= n_count` tokens of the messages. +- `'last'`: Keep the last `<= n_count` tokens of the messages. + + + +Whether to split a message if only part of the message can be +included. + +If `strategy='last'` then the last partial contents of a message are +included. If `strategy='first'` then the first partial contents of a +message are included. + + + +The message type to end on. + +If specified then every message after the last occurrence of this type is +ignored. If `strategy='last'` then this is done before we attempt to get the +last `max_tokens`. If `strategy='first'` then this is done after we get the +first `max_tokens`. Can be specified as string names (e.g. `'system'`, +`'human'`, `'ai'`, ...) or as `BaseMessage` classes (e.g. `SystemMessage`, +`HumanMessage`, `AIMessage`, ...). Can be a single type or a list of types. + + + +The message type to start on. + +Should only be specified if `strategy='last'`. If specified then every +message before the first occurrence of this type is ignored. This is done +after we trim the initial messages to the last `max_tokens`. Does not apply +to a `SystemMessage` at index 0 if `include_system=True`. Can be specified +as string names (e.g. `'system'`, `'human'`, `'ai'`, ...) or as +`BaseMessage` classes (e.g. `SystemMessage`, `HumanMessage`, `AIMessage`, +...). Can be a single type or a list of types. + + + +Whether to keep the `SystemMessage` if there is one at index +`0`. + +Should only be specified if `strategy="last"`. + + + +Function or `langchain_text_splitters.TextSplitter` for +splitting the string contents of a message. + +Only used if `allow_partial=True`. If `strategy='last'` then the last split +tokens from a partial message will be included. if `strategy='first'` then +the first split tokens from a partial message will be included. Token +splitter assumes that separators are kept, so that split contents can be +directly concatenated to recreate the original text. Defaults to splitting +on newlines. + + +**Returns:** `list[BaseMessage]` + +List of trimmed `BaseMessage`. + +**Raises:** + +- `ValueError`: if two incompatible arguments are specified or an unrecognized +`strategy` is specified. + + + + + + + + +```python +langchain_core.messages.utils.AnyMessage = Annotated[Annotated[AIMessage, Tag(tag='ai')] | Annotated[HumanMessage, Tag(tag=... +``` + + + + + + +A type representing any defined `Message` or `MessageChunk` type. + + + + + + + +```python +langchain_core.messages.utils.MessageLikeRepresentation = BaseMessage | list[str] | tuple[str, str] | str | dict[str, Any] +``` + + + + + + +A type representing the various ways a message can be represented. + + + + + + + +```python +langchain_core.messages.utils._CHUNK_MSG_MAP = {v: k for k, v in (_MSG_CHUNK_MAP.items())} +``` + + + + + + + + + +```python +langchain_core.messages.utils._HAS_LANGCHAIN_TEXT_SPLITTERS = True +``` + + + + + + + + + +```python +langchain_core.messages.utils._MSG_CHUNK_MAP: dict[type[BaseMessage], type[BaseMessageChunk]] = {HumanMessage: HumanMessageChunk, AIMessage: AIMessageChunk, SystemMessage: Syst... +``` + + + + + + + + + +```python +langchain_core.messages.utils._MultipleMessages = Sequence[_T] +``` + + + + + + + + + +```python +langchain_core.messages.utils._P = ParamSpec('_P') +``` + + + + + + + + + +```python +langchain_core.messages.utils._R_co = TypeVar('_R_co', covariant=True) +``` + + + + + + + + + +```python +langchain_core.messages.utils._SingleMessage = BaseMessage | str | dict[str, Any] +``` + + + + + + + + + +```python +langchain_core.messages.utils._T = TypeVar('_T', bound=_SingleMessage) +``` + + + + + + + + + +```python +langchain_core.messages.utils._TOKEN_COUNTER_SHORTCUTS = {'approximate': _approximate_token_counter} +``` + + + + + + + + + +```python +langchain_core.messages.utils._XML_CONTENT_BLOCK_MAX_LEN = 500 +``` + + + + + + + + + +```python +langchain_core.messages.utils.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers.mdx new file mode 100644 index 0000000..80e42fc --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers.mdx @@ -0,0 +1,110 @@ +--- +layout: overview +slug: langchain-core/langchain_core/output_parsers +title: langchain_core.output_parsers +--- + +`OutputParser` classes parse the output of an LLM call into structured data. + +!!! tip "Structured output" + + Output parsers emerged as an early solution to the challenge of obtaining structured + output from LLMs. + + Today, most LLMs support [structured output](https://docs.langchain.com/oss/python/langchain/models#structured-outputs) + natively. In such cases, using output parsers may be unnecessary, and you should + leverage the model's built-in capabilities for structured output. Refer to the + [documentation of your chosen model](https://docs.langchain.com/oss/python/integrations/providers/overview) + for guidance on how to achieve structured output directly. + + Output parsers remain valuable when working with models that do not support + structured output natively, or when you require additional processing or validation + of the model's output beyond its inherent capabilities. + +## Submodules + +- **[`langchain_core.output_parsers.base`](/langchain-core/langchain_core/output_parsers/base)** +- **[`langchain_core.output_parsers.format_instructions`](/langchain-core/langchain_core/output_parsers/format_instructions)** +- **[`langchain_core.output_parsers.json`](/langchain-core/langchain_core/output_parsers/json)** +- **[`langchain_core.output_parsers.list`](/langchain-core/langchain_core/output_parsers/list)** +- **[`langchain_core.output_parsers.openai_functions`](/langchain-core/langchain_core/output_parsers/openai_functions)** +- **[`langchain_core.output_parsers.openai_tools`](/langchain-core/langchain_core/output_parsers/openai_tools)** +- **[`langchain_core.output_parsers.pydantic`](/langchain-core/langchain_core/output_parsers/pydantic)** +- **[`langchain_core.output_parsers.string`](/langchain-core/langchain_core/output_parsers/string)** +- **[`langchain_core.output_parsers.transform`](/langchain-core/langchain_core/output_parsers/transform)** +- **[`langchain_core.output_parsers.xml`](/langchain-core/langchain_core/output_parsers/xml)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`__dir__`](#langchain_core-output_parsers-__dir__) | - | +| [`__getattr__`](#langchain_core-output_parsers-__getattr__) | - | + +### Data + +[`__all__`](#langchain_core-output_parsers-__all__) + +[`_dynamic_imports`](#langchain_core-output_parsers-_dynamic_imports) + +### API + + + + + +```python +langchain_core.output_parsers.__dir__() -> langchain_core.output_parsers.list[str] +``` + + + + + + + + + + + + + +```python +langchain_core.output_parsers.__getattr__( + attr_name: str +) -> object +``` + + + + + + + + + + + + + +```python +langchain_core.output_parsers.__all__ = ['BaseCumulativeTransformOutputParser', 'BaseGenerationOutputParser', 'BaseLLMOu... +``` + + + + + + + + + +```python +langchain_core.output_parsers._dynamic_imports = {'BaseLLMOutputParser': 'base', 'BaseGenerationOutputParser': 'base', 'BaseOutpu... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/base.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/base.mdx new file mode 100644 index 0000000..b83e49a --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/base.mdx @@ -0,0 +1,516 @@ +--- +layout: overview +slug: langchain-core/langchain_core/output_parsers/base +title: langchain_core.output_parsers.base +--- + +Base parser for language model outputs. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseGenerationOutputParser`](#langchain_core-output_parsers-base-BaseGenerationOutputParser) | Base class to parse the output of an LLM call. | +| [`BaseLLMOutputParser`](#langchain_core-output_parsers-base-BaseLLMOutputParser) | Abstract base class for parsing the outputs of a model. | +| [`BaseOutputParser`](#langchain_core-output_parsers-base-BaseOutputParser) | Base class to parse the output of an LLM call. | + +### Data + +[`OutputParserLike`](#langchain_core-output_parsers-base-OutputParserLike) + +[`T`](#langchain_core-output_parsers-base-T) + +### API + + + + + +```python +class langchain_core.output_parsers.base.BaseGenerationOutputParser() +``` + + + + + + +**Bases:** [BaseLLMOutputParser](#langchain_core-output_parsers-base-BaseLLMOutputParser), [RunnableSerializable[LanguageModelOutput, T]](/langchain-core/langchain_core/runnables/base#langchain_core-runnables-base-RunnableSerializable) + +Base class to parse the output of an LLM call. + + + +Return the input type for the parser. + + + +Return the output type for the parser. + + + + + +```python +langchain_core.output_parsers.base.BaseGenerationOutputParser.ainvoke( + input: str | langchain_core.messages.BaseMessage, + config: langchain_core.runnables.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> langchain_core.output_parsers.base.T +``` + + + + + + +async + + + + + + + +```python +langchain_core.output_parsers.base.BaseGenerationOutputParser.invoke( + input: str | langchain_core.messages.BaseMessage, + config: langchain_core.runnables.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.output_parsers.base.T +``` + + + + + + + + + + + + + + +```python +class langchain_core.output_parsers.base.BaseLLMOutputParser() +``` + + + + + + +Abstract + +**Bases:** `Generic[T]` + +Abstract base class for parsing the outputs of a model. + + + + + + +```python +langchain_core.output_parsers.base.BaseLLMOutputParser.aparse_result( + result: langchain_core.output_parsers.list[langchain_core.outputs.Generation], + partial: bool = False +) -> langchain_core.output_parsers.base.T +``` + + + + + + +async + +Parse a list of candidate model `Generation` objects into a specific format. + +**Parameters:** + + +A list of `Generation` to be parsed. + +The Generations are assumed to be different candidate outputs for a +single model input. + + + +Whether to parse the output as a partial result. + +This is useful for parsers that can parse partial results. + + +**Returns:** `T` + +Structured output. + + + + + + + +```python +langchain_core.output_parsers.base.BaseLLMOutputParser.parse_result( + result: langchain_core.output_parsers.list[langchain_core.outputs.Generation], + partial: bool = False +) -> langchain_core.output_parsers.base.T +``` + + + + + + +abstract + +Parse a list of candidate model `Generation` objects into a specific format. + +**Parameters:** + + +A list of `Generation` to be parsed. + +The `Generation` objects are assumed to be different candidate outputs +for a single model input. + + + +Whether to parse the output as a partial result. + +This is useful for parsers that can parse partial results. + + +**Returns:** `T` + +Structured output. + + + + + + + + + +```python +class langchain_core.output_parsers.base.BaseOutputParser() +``` + + + + + + +**Bases:** [BaseLLMOutputParser](#langchain_core-output_parsers-base-BaseLLMOutputParser), [RunnableSerializable[LanguageModelOutput, T]](/langchain-core/langchain_core/runnables/base#langchain_core-runnables-base-RunnableSerializable) + +Base class to parse the output of an LLM call. + +Output parsers help structure language model responses. + + + +Return the input type for the parser. + + + +Return the output type for the parser. + +This property is inferred from the first type argument of the class. + + + +Return the output parser type for serialization. + + + + + +```python +langchain_core.output_parsers.base.BaseOutputParser.ainvoke( + input: str | langchain_core.messages.BaseMessage, + config: langchain_core.runnables.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> langchain_core.output_parsers.base.T +``` + + + + + + +async + + + + + + + +```python +langchain_core.output_parsers.base.BaseOutputParser.aparse( + text: str +) -> langchain_core.output_parsers.base.T +``` + + + + + + +async + +Async parse a single string model output into some structure. + +**Parameters:** + + +String output of a language model. + + +**Returns:** `T` + +Structured output. + + + + + + + +```python +langchain_core.output_parsers.base.BaseOutputParser.aparse_result( + result: langchain_core.output_parsers.list[langchain_core.outputs.Generation], + partial: bool = False +) -> langchain_core.output_parsers.base.T +``` + + + + + + +async + +Parse a list of candidate model `Generation` objects into a specific format. + +The return value is parsed from only the first `Generation` in the result, which +is assumed to be the highest-likelihood `Generation`. + +**Parameters:** + + +A list of `Generation` to be parsed. + +The `Generation` objects are assumed to be different candidate outputs +for a single model input. + + + +Whether to parse the output as a partial result. + +This is useful for parsers that can parse partial results. + + +**Returns:** `T` + +Structured output. + + + + + + + +```python +langchain_core.output_parsers.base.BaseOutputParser.dict( + kwargs: typing.Any = {} +) -> langchain_core.output_parsers.base.BaseOutputParser.dict +``` + + + + + + +Return dictionary representation of output parser. + + + + + + + +```python +langchain_core.output_parsers.base.BaseOutputParser.get_format_instructions() -> str +``` + + + + + + +Instructions on how the LLM output should be formatted. + + + + + + + +```python +langchain_core.output_parsers.base.BaseOutputParser.invoke( + input: str | langchain_core.messages.BaseMessage, + config: langchain_core.runnables.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.output_parsers.base.T +``` + + + + + + + + + + + + +```python +langchain_core.output_parsers.base.BaseOutputParser.parse( + text: str +) -> langchain_core.output_parsers.base.T +``` + + + + + + +abstract + +Parse a single string model output into some structure. + +**Parameters:** + + +String output of a language model. + + +**Returns:** `T` + +Structured output. + + + + + + + +```python +langchain_core.output_parsers.base.BaseOutputParser.parse_result( + result: langchain_core.output_parsers.list[langchain_core.outputs.Generation], + partial: bool = False +) -> langchain_core.output_parsers.base.T +``` + + + + + + +Parse a list of candidate model `Generation` objects into a specific format. + +The return value is parsed from only the first `Generation` in the result, which +is assumed to be the highest-likelihood `Generation`. + +**Parameters:** + + +A list of `Generation` to be parsed. + +The `Generation` objects are assumed to be different candidate outputs +for a single model input. + + + +Whether to parse the output as a partial result. + +This is useful for parsers that can parse partial results. + + +**Returns:** `T` + +Structured output. + + + + + + + +```python +langchain_core.output_parsers.base.BaseOutputParser.parse_with_prompt( + completion: str, + prompt: langchain_core.prompt_values.PromptValue +) -> typing.Any +``` + + + + + + +Parse the output of an LLM call with the input prompt for context. + +The prompt is largely provided in the event the `OutputParser` wants to retry or +fix the output in some way, and needs information from the prompt to do so. + +**Parameters:** + + +String output of a language model. + + + +Input `PromptValue`. + + +**Returns:** `Any` + +Structured output. + + + + + + + + + +```python +langchain_core.output_parsers.base.OutputParserLike = Runnable[LanguageModelOutput, T] +``` + + + + + + + + + +```python +langchain_core.output_parsers.base.T = TypeVar('T') +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/format_instructions.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/format_instructions.mdx new file mode 100644 index 0000000..e562fd6 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/format_instructions.mdx @@ -0,0 +1,27 @@ +--- +layout: overview +slug: langchain-core/langchain_core/output_parsers/format_instructions +title: langchain_core.output_parsers.format_instructions +--- + +Format instructions. + +## Module Contents + +### Data + +[`JSON_FORMAT_INSTRUCTIONS`](#langchain_core-output_parsers-format_instructions-JSON_FORMAT_INSTRUCTIONS) + +### API + + + + + +```python +langchain_core.output_parsers.format_instructions.JSON_FORMAT_INSTRUCTIONS = 'STRICT OUTPUT FORMAT:\n- Return only the JSON value that conforms to the schema... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/json.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/json.mdx new file mode 100644 index 0000000..2eaeb4d --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/json.mdx @@ -0,0 +1,242 @@ +--- +layout: overview +slug: langchain-core/langchain_core/output_parsers/json +title: langchain_core.output_parsers.json +--- + +Parser for JSON output. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`JsonOutputParser`](#langchain_core-output_parsers-json-JsonOutputParser) | Parse the output of an LLM call to a JSON object. | + +### Data + +[`PydanticBaseModel`](#langchain_core-output_parsers-json-PydanticBaseModel) + +[`SimpleJsonOutputParser`](#langchain_core-output_parsers-json-SimpleJsonOutputParser) + +[`TBaseModel`](#langchain_core-output_parsers-json-TBaseModel) + +[`__all__`](#langchain_core-output_parsers-json-__all__) + +### API + + + + + +```python +class langchain_core.output_parsers.json.JsonOutputParser() +``` + + + + + + +**Bases:** [BaseCumulativeTransformOutputParser[Any]](/langchain-core/langchain_core/output_parsers/transform#langchain_core-output_parsers-transform-BaseCumulativeTransformOutputParser) + +Parse the output of an LLM call to a JSON object. + +Probably the most reliable output parser for getting structured data that does *not* +use function calling. + +When used in streaming mode, it will yield partial JSON objects containing all the +keys that have been returned so far. + +In streaming, if `diff` is set to `True`, yields `JSONPatch` operations describing +the difference between the previous and the current object. + + + + + + +The Pydantic object to use for validation. + +If `None`, no validation is performed. + + + + + +```python +langchain_core.output_parsers.json.JsonOutputParser._diff( + prev: typing.Any | None, + next: typing.Any +) -> typing.Any +``` + + + + + + + + + + + + +```python +langchain_core.output_parsers.json.JsonOutputParser._get_schema( + pydantic_object: type[langchain_core.output_parsers.json.TBaseModel] +) -> dict[str, typing.Any] +``` + + + + + + +staticmethod + + + + + + + +```python +langchain_core.output_parsers.json.JsonOutputParser.get_format_instructions() -> str +``` + + + + + + +Return the format instructions for the JSON output. + +**Returns:** `str` + +The format instructions for the JSON output. + + + + + + + +```python +langchain_core.output_parsers.json.JsonOutputParser.parse( + text: str +) -> typing.Any +``` + + + + + + +Parse the output of an LLM call to a JSON object. + +**Parameters:** + + +The output of the LLM call. + + +**Returns:** `Any` + +The parsed JSON object. + + + + + + + +```python +langchain_core.output_parsers.json.JsonOutputParser.parse_result( + result: langchain_core.output_parsers.list[langchain_core.outputs.Generation], + partial: bool = False +) -> typing.Any +``` + + + + + + +Parse the result of an LLM call to a JSON object. + +**Parameters:** + + +The result of the LLM call. + + + +Whether to parse partial JSON objects. + +If `True`, the output will be a JSON object containing all the keys that +have been returned so far. + +If `False`, the output will be the full JSON object. + + +**Returns:** `Any` + +The parsed JSON object. + +**Raises:** + +- `OutputParserException`: If the output is not valid JSON. + + + + + + + + + +```python +langchain_core.output_parsers.json.PydanticBaseModel = BaseModel | pydantic.BaseModel +``` + + + + + + + + + +```python +langchain_core.output_parsers.json.SimpleJsonOutputParser = JsonOutputParser +``` + + + + + + + + + +```python +langchain_core.output_parsers.json.TBaseModel = TypeVar('TBaseModel', bound=PydanticBaseModel) +``` + + + + + + + + + +```python +langchain_core.output_parsers.json.__all__ = ['JsonOutputParser', 'SimpleJsonOutputParser', 'parse_and_check_json_markdown', ... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/list.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/list.mdx new file mode 100644 index 0000000..9163009 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/list.mdx @@ -0,0 +1,473 @@ +--- +layout: overview +slug: langchain-core/langchain_core/output_parsers/list +title: langchain_core.output_parsers.list +--- + +Parsers for list output. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`CommaSeparatedListOutputParser`](#langchain_core-output_parsers-list-CommaSeparatedListOutputParser) | Parse the output of a model to a comma-separated list. | +| [`ListOutputParser`](#langchain_core-output_parsers-list-ListOutputParser) | Parse the output of a model to a list. | +| [`MarkdownListOutputParser`](#langchain_core-output_parsers-list-MarkdownListOutputParser) | Parse a Markdown list. | +| [`NumberedListOutputParser`](#langchain_core-output_parsers-list-NumberedListOutputParser) | Parse a numbered list. | + +### Functions + +| Name | Description | +|------|-------------| +| [`droplastn`](#langchain_core-output_parsers-list-droplastn) | Drop the last `n` elements of an iterator. | + +### Data + +[`T`](#langchain_core-output_parsers-list-T) + +### API + + + + + +```python +class langchain_core.output_parsers.list.CommaSeparatedListOutputParser() +``` + + + + + + +**Bases:** [ListOutputParser](#langchain_core-output_parsers-list-ListOutputParser) + +Parse the output of a model to a comma-separated list. + + + + + + + + +```python +langchain_core.output_parsers.list.CommaSeparatedListOutputParser.get_format_instructions() -> str +``` + + + + + + +Return the format instructions for the comma-separated list output. + + + + + + + +```python +langchain_core.output_parsers.list.CommaSeparatedListOutputParser.get_lc_namespace() -> langchain_core.output_parsers.list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "output_parsers", "list"]` + + + + + + + +```python +langchain_core.output_parsers.list.CommaSeparatedListOutputParser.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `True` as this class is serializable. + + + + + + + +```python +langchain_core.output_parsers.list.CommaSeparatedListOutputParser.parse( + text: str +) -> langchain_core.output_parsers.list[str] +``` + + + + + + +Parse the output of an LLM call. + +**Parameters:** + + +The output of an LLM call. + + +**Returns:** `list[str]` + +A list of strings. + + + + + + + + + +```python +class langchain_core.output_parsers.list.ListOutputParser() +``` + + + + + + +**Bases:** [BaseTransformOutputParser[list[str]]](/langchain-core/langchain_core/output_parsers/transform#langchain_core-output_parsers-transform-BaseTransformOutputParser) + +Parse the output of a model to a list. + + + + + + + + +```python +langchain_core.output_parsers.list.ListOutputParser._atransform( + input: collections.abc.AsyncIterator[str | langchain_core.messages.BaseMessage] +) -> collections.abc.AsyncIterator[langchain_core.output_parsers.list[str]] +``` + + + + + + +async + + + + + + + +```python +langchain_core.output_parsers.list.ListOutputParser._transform( + input: collections.abc.Iterator[str | langchain_core.messages.BaseMessage] +) -> collections.abc.Iterator[langchain_core.output_parsers.list[str]] +``` + + + + + + + + + + + + +```python +langchain_core.output_parsers.list.ListOutputParser.parse( + text: str +) -> langchain_core.output_parsers.list[str] +``` + + + + + + +abstract + +Parse the output of an LLM call. + +**Parameters:** + + +The output of an LLM call. + + +**Returns:** `list[str]` + +A list of strings. + + + + + + + +```python +langchain_core.output_parsers.list.ListOutputParser.parse_iter( + text: str +) -> collections.abc.Iterator[re.Match] +``` + + + + + + +Parse the output of an LLM call. + +**Parameters:** + + +The output of an LLM call. + + + + + + + + + + +```python +class langchain_core.output_parsers.list.MarkdownListOutputParser() +``` + + + + + + +**Bases:** [ListOutputParser](#langchain_core-output_parsers-list-ListOutputParser) + +Parse a Markdown list. + + + + + + +The pattern to match a Markdown list item. + + + + + +```python +langchain_core.output_parsers.list.MarkdownListOutputParser.get_format_instructions() -> str +``` + + + + + + +Return the format instructions for the Markdown list output. + + + + + + + +```python +langchain_core.output_parsers.list.MarkdownListOutputParser.parse( + text: str +) -> langchain_core.output_parsers.list[str] +``` + + + + + + +Parse the output of an LLM call. + +**Parameters:** + + +The output of an LLM call. + + +**Returns:** `list[str]` + +A list of strings. + + + + + + + +```python +langchain_core.output_parsers.list.MarkdownListOutputParser.parse_iter( + text: str +) -> collections.abc.Iterator[re.Match] +``` + + + + + + + + + + + + + + +```python +class langchain_core.output_parsers.list.NumberedListOutputParser() +``` + + + + + + +**Bases:** [ListOutputParser](#langchain_core-output_parsers-list-ListOutputParser) + +Parse a numbered list. + + + + + + +The pattern to match a numbered list item. + + + + + +```python +langchain_core.output_parsers.list.NumberedListOutputParser.get_format_instructions() -> str +``` + + + + + + + + + + + + +```python +langchain_core.output_parsers.list.NumberedListOutputParser.parse( + text: str +) -> langchain_core.output_parsers.list[str] +``` + + + + + + +Parse the output of an LLM call. + +**Parameters:** + + +The output of an LLM call. + + +**Returns:** `list[str]` + +A list of strings. + + + + + + + +```python +langchain_core.output_parsers.list.NumberedListOutputParser.parse_iter( + text: str +) -> collections.abc.Iterator[re.Match] +``` + + + + + + + + + + + + + + +```python +langchain_core.output_parsers.list.droplastn( + iter: collections.abc.Iterator[langchain_core.output_parsers.list.T], + n: int +) -> collections.abc.Iterator[langchain_core.output_parsers.list.T] +``` + + + + + + +Drop the last `n` elements of an iterator. + +**Parameters:** + + +The iterator to drop elements from. + + + +The number of elements to drop. + + + + + + + + + +```python +langchain_core.output_parsers.list.T = TypeVar('T') +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/openai_functions.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/openai_functions.mdx new file mode 100644 index 0000000..e3f7d2a --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/openai_functions.mdx @@ -0,0 +1,422 @@ +--- +layout: overview +slug: langchain-core/langchain_core/output_parsers/openai_functions +title: langchain_core.output_parsers.openai_functions +--- + +Parsers for OpenAI functions output. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`JsonKeyOutputFunctionsParser`](#langchain_core-output_parsers-openai_functions-JsonKeyOutputFunctionsParser) | Parse an output as the element of the JSON object. | +| [`JsonOutputFunctionsParser`](#langchain_core-output_parsers-openai_functions-JsonOutputFunctionsParser) | Parse an output as the JSON object. | +| [`OutputFunctionsParser`](#langchain_core-output_parsers-openai_functions-OutputFunctionsParser) | Parse an output that is one of sets of values. | +| [`PydanticAttrOutputFunctionsParser`](#langchain_core-output_parsers-openai_functions-PydanticAttrOutputFunctionsParser) | Parse an output as an attribute of a Pydantic object. | +| [`PydanticOutputFunctionsParser`](#langchain_core-output_parsers-openai_functions-PydanticOutputFunctionsParser) | Parse an output as a Pydantic object. | + +### API + + + + + +```python +class langchain_core.output_parsers.openai_functions.JsonKeyOutputFunctionsParser() +``` + + + + + + +**Bases:** [JsonOutputFunctionsParser](#langchain_core-output_parsers-openai_functions-JsonOutputFunctionsParser) + +Parse an output as the element of the JSON object. + + + +The name of the key to return. + + + + + +```python +langchain_core.output_parsers.openai_functions.JsonKeyOutputFunctionsParser.parse_result( + result: langchain_core.output_parsers.list[langchain_core.outputs.Generation], + partial: bool = False +) -> typing.Any +``` + + + + + + +Parse the result of an LLM call to a JSON object. + +**Parameters:** + + +The result of the LLM call. + + + +Whether to parse partial JSON objects. + + +**Returns:** `Any` + +The parsed JSON object. + + + + + + + + + +```python +class langchain_core.output_parsers.openai_functions.JsonOutputFunctionsParser() +``` + + + + + + +**Bases:** [BaseCumulativeTransformOutputParser[Any]](/langchain-core/langchain_core/output_parsers/transform#langchain_core-output_parsers-transform-BaseCumulativeTransformOutputParser) + +Parse an output as the JSON object. + + + + + + +Whether to only return the arguments to the function call. + + + +Whether to allow non-JSON-compliant strings. + +See: https://docs.python.org/3/library/json.html#encoders-and-decoders + +Useful when the parsed output may include unicode characters or new lines. + + + + + +```python +langchain_core.output_parsers.openai_functions.JsonOutputFunctionsParser._diff( + prev: typing.Any | None, + next: typing.Any +) -> typing.Any +``` + + + + + + + + + + + + +```python +langchain_core.output_parsers.openai_functions.JsonOutputFunctionsParser.parse( + text: str +) -> typing.Any +``` + + + + + + +Parse the output of an LLM call to a JSON object. + +**Parameters:** + + +The output of the LLM call. + + +**Returns:** `Any` + +The parsed JSON object. + + + + + + + +```python +langchain_core.output_parsers.openai_functions.JsonOutputFunctionsParser.parse_result( + result: langchain_core.output_parsers.list[langchain_core.outputs.Generation], + partial: bool = False +) -> typing.Any +``` + + + + + + +Parse the result of an LLM call to a JSON object. + +**Parameters:** + + +The result of the LLM call. + + + +Whether to parse partial JSON objects. + + +**Returns:** `Any` + +The parsed JSON object. + +**Raises:** + +- `OutputParserException`: If the output is not valid JSON. + + + + + + + + + +```python +class langchain_core.output_parsers.openai_functions.OutputFunctionsParser() +``` + + + + + + +**Bases:** [BaseGenerationOutputParser[Any]](/langchain-core/langchain_core/output_parsers/base#langchain_core-output_parsers-base-BaseGenerationOutputParser) + +Parse an output that is one of sets of values. + + + +Whether to only return the arguments to the function call. + + + + + +```python +langchain_core.output_parsers.openai_functions.OutputFunctionsParser.parse_result( + result: langchain_core.output_parsers.list[langchain_core.outputs.Generation], + partial: bool = False +) -> typing.Any +``` + + + + + + +Parse the result of an LLM call to a JSON object. + +**Parameters:** + + +The result of the LLM call. + + + +Whether to parse partial JSON objects. + + +**Returns:** `Any` + +The parsed JSON object. + +**Raises:** + +- `OutputParserException`: If the output is not valid JSON. + + + + + + + + + +```python +class langchain_core.output_parsers.openai_functions.PydanticAttrOutputFunctionsParser() +``` + + + + + + +**Bases:** [PydanticOutputFunctionsParser](#langchain_core-output_parsers-openai_functions-PydanticOutputFunctionsParser) + +Parse an output as an attribute of a Pydantic object. + + + +The name of the attribute to return. + + + + + +```python +langchain_core.output_parsers.openai_functions.PydanticAttrOutputFunctionsParser.parse_result( + result: langchain_core.output_parsers.list[langchain_core.outputs.Generation], + partial: bool = False +) -> typing.Any +``` + + + + + + +Parse the result of an LLM call to a JSON object. + +**Parameters:** + + +The result of the LLM call. + + + +Whether to parse partial JSON objects. + + +**Returns:** `Any` + +The parsed JSON object. + + + + + + + + + +```python +class langchain_core.output_parsers.openai_functions.PydanticOutputFunctionsParser() +``` + + + + + + +**Bases:** [OutputFunctionsParser](#langchain_core-output_parsers-openai_functions-OutputFunctionsParser) + +Parse an output as a Pydantic object. + +This parser is used to parse the output of a chat model that uses OpenAI function +format to invoke functions. + +The parser extracts the function call invocation and matches them to the Pydantic +schema provided. + +An exception will be raised if the function call does not match the provided schema. + + + +The Pydantic schema to parse the output with. + +If multiple schemas are provided, then the function name will be used to +determine which schema to use. + + + + + +```python +langchain_core.output_parsers.openai_functions.PydanticOutputFunctionsParser.parse_result( + result: langchain_core.output_parsers.list[langchain_core.outputs.Generation], + partial: bool = False +) -> typing.Any +``` + + + + + + +Parse the result of an LLM call to a JSON object. + +**Parameters:** + + +The result of the LLM call. + + + +Whether to parse partial JSON objects. + + +**Returns:** `Any` + +The parsed JSON object. + +**Raises:** + +- `ValueError`: If the Pydantic schema is not valid. + + + + + + + +```python +langchain_core.output_parsers.openai_functions.PydanticOutputFunctionsParser.validate_schema( + values: dict[str, typing.Any] +) -> typing.Any +``` + + + + + + +classmethod + +Validate the Pydantic schema. + +**Parameters:** + + +The values to validate. + + +**Returns:** `Any` + +The validated values. + +**Raises:** + +- `ValueError`: If the schema is not a Pydantic schema. + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/openai_tools.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/openai_tools.mdx new file mode 100644 index 0000000..2cbb841 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/openai_tools.mdx @@ -0,0 +1,437 @@ +--- +layout: overview +slug: langchain-core/langchain_core/output_parsers/openai_tools +title: langchain_core.output_parsers.openai_tools +--- + +Parse tools for OpenAI tools output. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`JsonOutputKeyToolsParser`](#langchain_core-output_parsers-openai_tools-JsonOutputKeyToolsParser) | Parse tools from OpenAI response. | +| [`JsonOutputToolsParser`](#langchain_core-output_parsers-openai_tools-JsonOutputToolsParser) | Parse tools from OpenAI response. | +| [`PydanticToolsParser`](#langchain_core-output_parsers-openai_tools-PydanticToolsParser) | Parse tools from OpenAI response. | + +### Functions + +| Name | Description | +|------|-------------| +| [`make_invalid_tool_call`](#langchain_core-output_parsers-openai_tools-make_invalid_tool_call) | Create an `InvalidToolCall` from a raw tool call. | +| [`parse_tool_call`](#langchain_core-output_parsers-openai_tools-parse_tool_call) | Parse a single tool call. | +| [`parse_tool_calls`](#langchain_core-output_parsers-openai_tools-parse_tool_calls) | Parse a list of tool calls. | + +### Data + +[`_MAX_TOKENS_ERROR`](#langchain_core-output_parsers-openai_tools-_MAX_TOKENS_ERROR) + +[`logger`](#langchain_core-output_parsers-openai_tools-logger) + +### API + + + + + +```python +class langchain_core.output_parsers.openai_tools.JsonOutputKeyToolsParser() +``` + + + + + + +**Bases:** [JsonOutputToolsParser](#langchain_core-output_parsers-openai_tools-JsonOutputToolsParser) + +Parse tools from OpenAI response. + + + +The type of tools to return. + + + + + +```python +langchain_core.output_parsers.openai_tools.JsonOutputKeyToolsParser.parse_result( + result: langchain_core.output_parsers.list[langchain_core.outputs.Generation], + partial: bool = False +) -> typing.Any +``` + + + + + + +Parse the result of an LLM call to a list of tool calls. + +**Parameters:** + + +The result of the LLM call. + + + +Whether to parse partial JSON. +If `True`, the output will be a JSON object containing + all the keys that have been returned so far. +If `False`, the output will be the full JSON object. + + +**Returns:** `Any` + +The parsed tool calls. + +**Raises:** + +- `OutputParserException`: If the generation is not a chat generation. + + + + + + + + + +```python +class langchain_core.output_parsers.openai_tools.JsonOutputToolsParser() +``` + + + + + + +**Bases:** [BaseCumulativeTransformOutputParser[Any]](/langchain-core/langchain_core/output_parsers/transform#langchain_core-output_parsers-transform-BaseCumulativeTransformOutputParser) + +Parse tools from OpenAI response. + + + +Whether to return only the first tool call. + +If `False`, the result will be a list of tool calls, or an empty list if no tool +calls are found. + +If `True`, and multiple tool calls are found, only the first one will be returned, +and the other tool calls will be ignored. + +If no tool calls are found, `None` will be returned. + + + +Whether to return the tool call id. + + + +Whether to allow non-JSON-compliant strings. + +See: https://docs.python.org/3/library/json.html#encoders-and-decoders + +Useful when the parsed output may include unicode characters or new lines. + + + + + +```python +langchain_core.output_parsers.openai_tools.JsonOutputToolsParser.parse( + text: str +) -> typing.Any +``` + + + + + + +Parse the output of an LLM call to a list of tool calls. + +**Parameters:** + + +The output of the LLM call. + + +**Returns:** `Any` + +The parsed tool calls. + + + + + + + +```python +langchain_core.output_parsers.openai_tools.JsonOutputToolsParser.parse_result( + result: langchain_core.output_parsers.list[langchain_core.outputs.Generation], + partial: bool = False +) -> typing.Any +``` + + + + + + +Parse the result of an LLM call to a list of tool calls. + +**Parameters:** + + +The result of the LLM call. + + + +Whether to parse partial JSON. + +If `True`, the output will be a JSON object containing +all the keys that have been returned so far. + +If `False`, the output will be the full JSON object. + + +**Returns:** `Any` + +The parsed tool calls. + +**Raises:** + +- `OutputParserException`: If the output is not valid JSON. + + + + + + + + + +```python +class langchain_core.output_parsers.openai_tools.PydanticToolsParser() +``` + + + + + + +**Bases:** [JsonOutputToolsParser](#langchain_core-output_parsers-openai_tools-JsonOutputToolsParser) + +Parse tools from OpenAI response. + + + +The tools to parse. + + + + + +```python +langchain_core.output_parsers.openai_tools.PydanticToolsParser.parse_result( + result: langchain_core.output_parsers.list[langchain_core.outputs.Generation], + partial: bool = False +) -> typing.Any +``` + + + + + + +Parse the result of an LLM call to a list of Pydantic objects. + +**Parameters:** + + +The result of the LLM call. + + + +Whether to parse partial JSON. + +If `True`, the output will be a JSON object containing all the keys that +have been returned so far. + +If `False`, the output will be the full JSON object. + + +**Returns:** `Any` + +The parsed Pydantic objects. + +**Raises:** + +- `ValueError`: If the tool call arguments are not a dict. +- `ValidationError`: If the tool call arguments do not conform to the Pydantic +model. + + + + + + + + + +```python +langchain_core.output_parsers.openai_tools.make_invalid_tool_call( + raw_tool_call: dict[str, typing.Any], + error_msg: str | None +) -> langchain_core.messages.InvalidToolCall +``` + + + + + + +Create an `InvalidToolCall` from a raw tool call. + +**Parameters:** + + +The raw tool call. + + + +The error message. + + +**Returns:** `InvalidToolCall` + +An `InvalidToolCall` instance with the error message. + + + + + + + + +```python +langchain_core.output_parsers.openai_tools.parse_tool_call( + raw_tool_call: dict[str, typing.Any], + partial: bool = False, + strict: bool = False, + return_id: bool = True +) -> dict[str, typing.Any] | None +``` + + + + + + +Parse a single tool call. + +**Parameters:** + + +The raw tool call to parse. + + + +Whether to parse partial JSON. + + + +Whether to allow non-JSON-compliant strings. + + + +Whether to return the tool call id. + + +**Returns:** `dict[str, Any] | None` + +The parsed tool call. + +**Raises:** + +- `OutputParserException`: If the tool call is not valid JSON. + + + + + + + + +```python +langchain_core.output_parsers.openai_tools.parse_tool_calls( + raw_tool_calls: langchain_core.output_parsers.list[dict], + partial: bool = False, + strict: bool = False, + return_id: bool = True +) -> langchain_core.output_parsers.list[dict[str, typing.Any]] +``` + + + + + + +Parse a list of tool calls. + +**Parameters:** + + +The raw tool calls to parse. + + + +Whether to parse partial JSON. + + + +Whether to allow non-JSON-compliant strings. + + + +Whether to return the tool call id. + + +**Returns:** `list[dict[str, Any]]` + +The parsed tool calls. + +**Raises:** + +- `OutputParserException`: If any of the tool calls are not valid JSON. + + + + + + + + +```python +langchain_core.output_parsers.openai_tools._MAX_TOKENS_ERROR = 'Output parser received a `max_tokens` stop reason. The output is likely incompl... +``` + + + + + + + + + +```python +langchain_core.output_parsers.openai_tools.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/pydantic.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/pydantic.mdx new file mode 100644 index 0000000..ed3b542 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/pydantic.mdx @@ -0,0 +1,204 @@ +--- +layout: overview +slug: langchain-core/langchain_core/output_parsers/pydantic +title: langchain_core.output_parsers.pydantic +--- + +Output parsers using Pydantic. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`PydanticOutputParser`](#langchain_core-output_parsers-pydantic-PydanticOutputParser) | Parse an output using a Pydantic model. | + +### Data + +[`_PYDANTIC_FORMAT_INSTRUCTIONS`](#langchain_core-output_parsers-pydantic-_PYDANTIC_FORMAT_INSTRUCTIONS) + +[`__all__`](#langchain_core-output_parsers-pydantic-__all__) + +### API + + + + + +```python +class langchain_core.output_parsers.pydantic.PydanticOutputParser() +``` + + + + + + +**Bases:** [JsonOutputParser](/langchain-core/langchain_core/output_parsers/json#langchain_core-output_parsers-json-JsonOutputParser), `Generic[TBaseModel]` + +Parse an output using a Pydantic model. + + + +Return the Pydantic model. + + + + + + +The Pydantic model to parse. + + + + + +```python +langchain_core.output_parsers.pydantic.PydanticOutputParser._parse_obj( + obj: dict +) -> langchain_core.utils.pydantic.TBaseModel +``` + + + + + + + + + + + + +```python +langchain_core.output_parsers.pydantic.PydanticOutputParser._parser_exception( + e: Exception, + json_object: dict +) -> langchain_core.exceptions.OutputParserException +``` + + + + + + + + + + + + +```python +langchain_core.output_parsers.pydantic.PydanticOutputParser.get_format_instructions() -> str +``` + + + + + + +Return the format instructions for the JSON output. + +**Returns:** `str` + +The format instructions for the JSON output. + + + + + + + +```python +langchain_core.output_parsers.pydantic.PydanticOutputParser.parse( + text: str +) -> langchain_core.utils.pydantic.TBaseModel +``` + + + + + + +Parse the output of an LLM call to a Pydantic object. + +**Parameters:** + + +The output of the LLM call. + + +**Returns:** `TBaseModel` + +The parsed Pydantic object. + + + + + + + +```python +langchain_core.output_parsers.pydantic.PydanticOutputParser.parse_result( + result: langchain_core.output_parsers.list[langchain_core.outputs.Generation], + partial: bool = False +) -> langchain_core.utils.pydantic.TBaseModel | None +``` + + + + + + +Parse the result of an LLM call to a Pydantic object. + +**Parameters:** + + +The result of the LLM call. + + + +Whether to parse partial JSON objects. + +If `True`, the output will be a JSON object containing all the keys that +have been returned so far. + + +**Returns:** `TBaseModel | None` + +The parsed Pydantic object. + +**Raises:** + +- `OutputParserException`: If the result is not valid JSON or does not conform +to the Pydantic model. + + + + + + + + + +```python +langchain_core.output_parsers.pydantic._PYDANTIC_FORMAT_INSTRUCTIONS = 'The output should be formatted as a JSON instance that conforms to the JSON sch... +``` + + + + + + + + + +```python +langchain_core.output_parsers.pydantic.__all__ = ['PydanticBaseModel', 'PydanticOutputParser', 'TBaseModel'] +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/string.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/string.mdx new file mode 100644 index 0000000..74ea66c --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/string.mdx @@ -0,0 +1,111 @@ +--- +layout: overview +slug: langchain-core/langchain_core/output_parsers/string +title: langchain_core.output_parsers.string +--- + +String output parser. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`StrOutputParser`](#langchain_core-output_parsers-string-StrOutputParser) | Extract text content from model outputs as a string. | + +### API + + + + + +```python +class langchain_core.output_parsers.string.StrOutputParser() +``` + + + + + + +**Bases:** [BaseTransformOutputParser[str]](/langchain-core/langchain_core/output_parsers/transform#langchain_core-output_parsers-transform-BaseTransformOutputParser) + +Extract text content from model outputs as a string. + +Converts model outputs (such as `AIMessage` or `AIMessageChunk` objects) into plain +text strings. It's the simplest output parser and is useful when you need string +responses for downstream processing, display, or storage. + +Supports streaming, yielding text chunks as they're generated by the model. + + + +Return the output parser type for serialization. + + + + + +```python +langchain_core.output_parsers.string.StrOutputParser.get_lc_namespace() -> langchain_core.output_parsers.list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "schema", "output_parser"]` + + + + + + + +```python +langchain_core.output_parsers.string.StrOutputParser.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +`StrOutputParser` is serializable. + +**Returns:** `bool` + +`True` + + + + + + + +```python +langchain_core.output_parsers.string.StrOutputParser.parse( + text: str +) -> str +``` + + + + + + +Returns the input text with no changes. + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/transform.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/transform.mdx new file mode 100644 index 0000000..cb80843 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/transform.mdx @@ -0,0 +1,242 @@ +--- +layout: overview +slug: langchain-core/langchain_core/output_parsers/transform +title: langchain_core.output_parsers.transform +--- + +Base classes for output parsers that can handle streaming input. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseCumulativeTransformOutputParser`](#langchain_core-output_parsers-transform-BaseCumulativeTransformOutputParser) | Base class for an output parser that can handle streaming input. | +| [`BaseTransformOutputParser`](#langchain_core-output_parsers-transform-BaseTransformOutputParser) | Base class for an output parser that can handle streaming input. | + +### API + + + + + +```python +class langchain_core.output_parsers.transform.BaseCumulativeTransformOutputParser() +``` + + + + + + +**Bases:** [BaseTransformOutputParser[T]](#langchain_core-output_parsers-transform-BaseTransformOutputParser) + +Base class for an output parser that can handle streaming input. + + + +In streaming mode, whether to yield diffs between the previous and current parsed +output, or just the current parsed output. + + + + + +```python +langchain_core.output_parsers.transform.BaseCumulativeTransformOutputParser._atransform( + input: collections.abc.AsyncIterator[str | langchain_core.messages.BaseMessage] +) -> collections.abc.AsyncIterator[langchain_core.output_parsers.base.T] +``` + + + + + + +async + + + + + + + +```python +langchain_core.output_parsers.transform.BaseCumulativeTransformOutputParser._diff( + prev: langchain_core.output_parsers.base.T | None, + next: langchain_core.output_parsers.base.T +) -> langchain_core.output_parsers.base.T +``` + + + + + + +Convert parsed outputs into a diff format. + +The semantics of this are up to the output parser. + +**Parameters:** + + +The previous parsed output. + + + +The current parsed output. + + +**Returns:** `T` + +The diff between the previous and current parsed output. + + + + + + + +```python +langchain_core.output_parsers.transform.BaseCumulativeTransformOutputParser._transform( + input: collections.abc.Iterator[str | langchain_core.messages.BaseMessage] +) -> collections.abc.Iterator[typing.Any] +``` + + + + + + + + + + + + + + +```python +class langchain_core.output_parsers.transform.BaseTransformOutputParser() +``` + + + + + + +**Bases:** [BaseOutputParser[T]](/langchain-core/langchain_core/output_parsers/base#langchain_core-output_parsers-base-BaseOutputParser) + +Base class for an output parser that can handle streaming input. + + + + + + +```python +langchain_core.output_parsers.transform.BaseTransformOutputParser._atransform( + input: collections.abc.AsyncIterator[str | langchain_core.messages.BaseMessage] +) -> collections.abc.AsyncIterator[langchain_core.output_parsers.base.T] +``` + + + + + + +async + + + + + + + +```python +langchain_core.output_parsers.transform.BaseTransformOutputParser._transform( + input: collections.abc.Iterator[str | langchain_core.messages.BaseMessage] +) -> collections.abc.Iterator[langchain_core.output_parsers.base.T] +``` + + + + + + + + + + + + +```python +langchain_core.output_parsers.transform.BaseTransformOutputParser.atransform( + input: collections.abc.AsyncIterator[str | langchain_core.messages.BaseMessage], + config: langchain_core.runnables.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[langchain_core.output_parsers.base.T] +``` + + + + + + +async + +Async transform the input into the output format. + +**Parameters:** + + +The input to transform. + + + +The configuration to use for the transformation. + + + +Additional keyword arguments. + + + + + + + + +```python +langchain_core.output_parsers.transform.BaseTransformOutputParser.transform( + input: collections.abc.Iterator[str | langchain_core.messages.BaseMessage], + config: langchain_core.runnables.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.Iterator[langchain_core.output_parsers.base.T] +``` + + + + + + +Transform the input into the output format. + +**Parameters:** + + +The input to transform. + + + +The configuration to use for the transformation. + + + +Additional keyword arguments. + + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/xml.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/xml.mdx new file mode 100644 index 0000000..9403a7c --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/output_parsers/xml.mdx @@ -0,0 +1,356 @@ +--- +layout: overview +slug: langchain-core/langchain_core/output_parsers/xml +title: langchain_core.output_parsers.xml +--- + +Output parser for XML format. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`XMLOutputParser`](#langchain_core-output_parsers-xml-XMLOutputParser) | Parse an output using xml format. | +| [`_StreamingParser`](#langchain_core-output_parsers-xml-_StreamingParser) | Streaming parser for XML. | + +### Functions + +| Name | Description | +|------|-------------| +| [`nested_element`](#langchain_core-output_parsers-xml-nested_element) | Get nested element from path. | + +### Data + +[`XML_FORMAT_INSTRUCTIONS`](#langchain_core-output_parsers-xml-XML_FORMAT_INSTRUCTIONS) + +[`_HAS_DEFUSEDXML`](#langchain_core-output_parsers-xml-_HAS_DEFUSEDXML) + +### API + + + + + +```python +class langchain_core.output_parsers.xml.XMLOutputParser() +``` + + + + + + +**Bases:** [BaseTransformOutputParser](/langchain-core/langchain_core/output_parsers/transform#langchain_core-output_parsers-transform-BaseTransformOutputParser) + +Parse an output using xml format. + +Returns a dictionary of tags. + + + + + + + + + +Parser to use for XML parsing. + +Can be either `'defusedxml'` or `'xml'`. + +- `'defusedxml'` is the default parser and is used to prevent XML vulnerabilities + present in some distributions of Python's standard library xml. `defusedxml` is + a wrapper around the standard library parser that sets up the parser with secure + defaults. +- `'xml'` is the standard library parser. + +!!! warning + + Use `xml` only if you are sure that your distribution of the standard library is + not vulnerable to XML vulnerabilities. + +Review the following resources for more information: + +* https://docs.python.org/3/library/xml.html#xml-vulnerabilities +* https://github.com/tiran/defusedxml + +The standard library relies on [`libexpat`](https://github.com/libexpat/libexpat) +for parsing XML. + + + +Tags to tell the LLM to expect in the XML output. + + Note this may not be perfect depending on the LLM implementation. + + For example, with `tags=["foo", "bar", "baz"]`: + + 1. A well-formatted XML instance: + `'<foo> + <bar> + <baz></baz> + </bar> +</foo>'` + + 2. A badly-formatted XML instance (missing closing tag for 'bar'): + `'<foo> + <bar> + </foo>'` + + 3. A badly-formatted XML instance (unexpected 'tag' element): + `'<foo> + <tag> + </tag> +</foo>'` + + + + + +```python +langchain_core.output_parsers.xml.XMLOutputParser._atransform( + input: collections.abc.AsyncIterator[str | langchain_core.messages.BaseMessage] +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.AddableDict] +``` + + + + + + +async + + + + + + + +```python +langchain_core.output_parsers.xml.XMLOutputParser._root_to_dict( + root: xml.etree.ElementTree.Element +) -> dict[str, str | langchain_core.output_parsers.list[typing.Any]] +``` + + + + + + +Converts xml tree to python dictionary. + + + + + + + +```python +langchain_core.output_parsers.xml.XMLOutputParser._transform( + input: collections.abc.Iterator[str | langchain_core.messages.BaseMessage] +) -> collections.abc.Iterator[langchain_core.runnables.utils.AddableDict] +``` + + + + + + + + + + + + +```python +langchain_core.output_parsers.xml.XMLOutputParser.get_format_instructions() -> str +``` + + + + + + +Return the format instructions for the XML output. + + + + + + + +```python +langchain_core.output_parsers.xml.XMLOutputParser.parse( + text: str +) -> dict[str, str | langchain_core.output_parsers.list[typing.Any]] +``` + + + + + + +Parse the output of an LLM call. + +**Parameters:** + + +The output of an LLM call. + + +**Returns:** `dict[str, str | list[Any]]` + +A `dict` representing the parsed XML. + +**Raises:** + +- `OutputParserException`: If the XML is not well-formed. +- `ImportError`: If defus`edxml is not installed and the `defusedxml` parser is +requested. + + + + + + + + + +```python +class langchain_core.output_parsers.xml._StreamingParser( + parser: typing.Literal['defusedxml', 'xml'] +) +``` + + + + + + +Streaming parser for XML. + +This implementation is pulled into a class to avoid implementation drift between +`transform` and `atransform` of the `XMLOutputParser`. + + + + + + + + + + + + + + + + + +```python +langchain_core.output_parsers.xml._StreamingParser.close() -> None +``` + + + + + + +Close the parser. + +This should be called after all chunks have been parsed. + + + + + + + +```python +langchain_core.output_parsers.xml._StreamingParser.parse( + chunk: str | langchain_core.messages.BaseMessage +) -> collections.abc.Iterator[langchain_core.runnables.utils.AddableDict] +``` + + + + + + +Parse a chunk of text. + +**Parameters:** + + +A chunk of text to parse. This can be a `str` or a `BaseMessage`. + + +**Raises:** + +- `xml.etree.ElementTree.ParseError`: If the XML is not well-formed. + + + + + + + + + +```python +langchain_core.output_parsers.xml.nested_element( + path: langchain_core.output_parsers.list[str], + elem: xml.etree.ElementTree.Element +) -> typing.Any +``` + + + + + + +Get nested element from path. + +**Parameters:** + + +The path to the element. + + + +The element to extract. + + +**Returns:** `Any` + +The nested element. + + + + + + + + +```python +langchain_core.output_parsers.xml.XML_FORMAT_INSTRUCTIONS = 'The output should be formatted as a XML file.\n1. Output should conform to the ... +``` + + + + + + + + + +```python +langchain_core.output_parsers.xml._HAS_DEFUSEDXML = True +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs.mdx new file mode 100644 index 0000000..8c03123 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs.mdx @@ -0,0 +1,109 @@ +--- +layout: overview +slug: langchain-core/langchain_core/outputs +title: langchain_core.outputs +--- + +Output classes. + +Used to represent the output of a language model call and the output of a chat. + +The top container for information is the `LLMResult` object. `LLMResult` is used by both +chat models and LLMs. This object contains the output of the language model and any +additional information that the model provider wants to return. + +When invoking models via the standard runnable methods (e.g. invoke, batch, etc.): + +- Chat models will return `AIMessage` objects. +- LLMs will return regular text strings. + +In addition, users can access the raw output of either LLMs or chat models via +callbacks. The `on_chat_model_end` and `on_llm_end` callbacks will return an `LLMResult` +object containing the generated outputs and any additional information returned by the +model provider. + +In general, if information is already available in the AIMessage object, it is +recommended to access it from there rather than from the `LLMResult` object. + +## Submodules + +- **[`langchain_core.outputs.chat_generation`](/langchain-core/langchain_core/outputs/chat_generation)** +- **[`langchain_core.outputs.chat_result`](/langchain-core/langchain_core/outputs/chat_result)** +- **[`langchain_core.outputs.generation`](/langchain-core/langchain_core/outputs/generation)** +- **[`langchain_core.outputs.llm_result`](/langchain-core/langchain_core/outputs/llm_result)** +- **[`langchain_core.outputs.run_info`](/langchain-core/langchain_core/outputs/run_info)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`__dir__`](#langchain_core-outputs-__dir__) | - | +| [`__getattr__`](#langchain_core-outputs-__getattr__) | - | + +### Data + +[`__all__`](#langchain_core-outputs-__all__) + +[`_dynamic_imports`](#langchain_core-outputs-_dynamic_imports) + +### API + + + + + +```python +langchain_core.outputs.__dir__() -> list[str] +``` + + + + + + + + + + + + + +```python +langchain_core.outputs.__getattr__( + attr_name: str +) -> object +``` + + + + + + + + + + + + + +```python +langchain_core.outputs.__all__ = ('ChatGeneration', 'ChatGenerationChunk', 'ChatResult', 'Generation', 'Generatio... +``` + + + + + + + + + +```python +langchain_core.outputs._dynamic_imports = {'ChatGeneration': 'chat_generation', 'ChatGenerationChunk': 'chat_generation', ... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs/chat_generation.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs/chat_generation.mdx new file mode 100644 index 0000000..2778370 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs/chat_generation.mdx @@ -0,0 +1,192 @@ +--- +layout: overview +slug: langchain-core/langchain_core/outputs/chat_generation +title: langchain_core.outputs.chat_generation +--- + +Chat generation output classes. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ChatGeneration`](#langchain_core-outputs-chat_generation-ChatGeneration) | A single chat generation output. | +| [`ChatGenerationChunk`](#langchain_core-outputs-chat_generation-ChatGenerationChunk) | `ChatGeneration` chunk. | + +### Functions + +| Name | Description | +|------|-------------| +| [`merge_chat_generation_chunks`](#langchain_core-outputs-chat_generation-merge_chat_generation_chunks) | Merge a list of `ChatGenerationChunk`s into a single `ChatGenerationChunk`. | + +### API + + + + + +```python +class langchain_core.outputs.chat_generation.ChatGeneration() +``` + + + + + + +**Bases:** [Generation](/langchain-core/langchain_core/outputs/generation#langchain_core-outputs-generation-Generation) + +A single chat generation output. + +A subclass of `Generation` that represents the response from a chat model that +generates chat messages. + +The `message` attribute is a structured representation of the chat message. Most of +the time, the message will be of type `AIMessage`. + +Users working with chat models will usually access information via either +`AIMessage` (returned from runnable interfaces) or `LLMResult` (available via +callbacks). + + + +The message output by the chat model. + + + +The text contents of the output message. + +!!! warning "SHOULD NOT BE SET DIRECTLY!" + + + +Type is used exclusively for serialization purposes. + + + + + +```python +langchain_core.outputs.chat_generation.ChatGeneration.set_text() -> typing_extensions.Self +``` + + + + + + +Set the text attribute to be the contents of the message. + +**Parameters:** + + +The values of the object. + + +**Returns:** `Self` + +The values of the object with the text attribute set. + +**Raises:** + +- `ValueError`: If the message is not a string or a list. + + + + + + + + + +```python +class langchain_core.outputs.chat_generation.ChatGenerationChunk() +``` + + + + + + +**Bases:** [ChatGeneration](#langchain_core-outputs-chat_generation-ChatGeneration) + +`ChatGeneration` chunk. + +`ChatGeneration` chunks can be concatenated with other `ChatGeneration` chunks. + + + +The message chunk output by the chat model. + + + +Type is used exclusively for serialization purposes. + + + + + +```python +langchain_core.outputs.chat_generation.ChatGenerationChunk.__add__( + other: langchain_core.outputs.chat_generation.ChatGenerationChunk | list[langchain_core.outputs.chat_generation.ChatGenerationChunk] +) -> langchain_core.outputs.chat_generation.ChatGenerationChunk +``` + + + + + + +Concatenate two `ChatGenerationChunk`s. + +**Parameters:** + + +The other `ChatGenerationChunk` or list of `ChatGenerationChunk` to +concatenate. + + +**Returns:** `ChatGenerationChunk` + +A new `ChatGenerationChunk` concatenated from self and other. + +**Raises:** + +- `TypeError`: If other is not a `ChatGenerationChunk` or list of +`ChatGenerationChunk`. + + + + + + + + + +```python +langchain_core.outputs.chat_generation.merge_chat_generation_chunks( + chunks: list[langchain_core.outputs.chat_generation.ChatGenerationChunk] +) -> langchain_core.outputs.chat_generation.ChatGenerationChunk | None +``` + + + + + + +Merge a list of `ChatGenerationChunk`s into a single `ChatGenerationChunk`. + +**Parameters:** + + +A list of `ChatGenerationChunk` to merge. + + +**Returns:** `ChatGenerationChunk | None` + +A merged `ChatGenerationChunk`, or `None` if the input list is empty. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs/chat_result.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs/chat_result.mdx new file mode 100644 index 0000000..45bcf92 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs/chat_result.mdx @@ -0,0 +1,62 @@ +--- +layout: overview +slug: langchain-core/langchain_core/outputs/chat_result +title: langchain_core.outputs.chat_result +--- + +Chat result schema. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ChatResult`](#langchain_core-outputs-chat_result-ChatResult) | Use to represent the result of a chat model call with a single prompt. | + +### API + + + + + +```python +class langchain_core.outputs.chat_result.ChatResult() +``` + + + + + + +**Bases:** `BaseModel` + +Use to represent the result of a chat model call with a single prompt. + +This container is used internally by some implementations of chat model, it will +eventually be mapped to a more general `LLMResult` object, and then projected into +an `AIMessage` object. + +LangChain users working with chat models will usually access information via +`AIMessage` (returned from runnable interfaces) or `LLMResult` (available via +callbacks). Please refer the `AIMessage` and `LLMResult` schema documentation for +more information. + + + +List of the chat generations. + +Generations is a list to allow for multiple candidate generations for a single +input prompt. + + + +For arbitrary LLM provider specific output. + +This dictionary is a free-form dictionary that can contain any information that the +provider wants to return. It is not standardized and is provider-specific. + +Users should generally avoid relying on this field and instead rely on accessing +relevant information from standardized fields present in `AIMessage`. + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs/generation.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs/generation.mdx new file mode 100644 index 0000000..050fd11 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs/generation.mdx @@ -0,0 +1,158 @@ +--- +layout: overview +slug: langchain-core/langchain_core/outputs/generation +title: langchain_core.outputs.generation +--- + +Generation output schema. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`Generation`](#langchain_core-outputs-generation-Generation) | A single text generation output. | +| [`GenerationChunk`](#langchain_core-outputs-generation-GenerationChunk) | `GenerationChunk`, which can be concatenated with other `Generation` chunks. | + +### API + + + + + +```python +class langchain_core.outputs.generation.Generation() +``` + + + + + + +**Bases:** [Serializable](/langchain-core/langchain_core/load/serializable#langchain_core-load-serializable-Serializable) + +A single text generation output. + +Generation represents the response from an "old-fashioned" LLM (string-in, +string-out) that generates regular text (not chat messages). + +This model is used internally by chat model and will eventually be mapped to a more +general `LLMResult` object, and then projected into an `AIMessage` object. + +LangChain users working with chat models will usually access information via +`AIMessage` (returned from runnable interfaces) or `LLMResult` (available via +callbacks). Please refer to `AIMessage` and `LLMResult` for more information. + + + +Raw response from the provider. + +May include things like the reason for finishing or token log probabilities. + + + +Generated text output. + + + +Type is used exclusively for serialization purposes. + +Set to `'Generation'` for this class. + + + + + +```python +langchain_core.outputs.generation.Generation.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "schema", "output"]` + + + + + + + +```python +langchain_core.outputs.generation.Generation.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `True` as this class is serializable. + + + + + + + + + +```python +class langchain_core.outputs.generation.GenerationChunk() +``` + + + + + + +**Bases:** [Generation](#langchain_core-outputs-generation-Generation) + +`GenerationChunk`, which can be concatenated with other `Generation` chunks. + + + + + + +```python +langchain_core.outputs.generation.GenerationChunk.__add__( + other: langchain_core.outputs.generation.GenerationChunk +) -> langchain_core.outputs.generation.GenerationChunk +``` + + + + + + +Concatenate two `GenerationChunk` objects. + +**Parameters:** + + +Another `GenerationChunk` to concatenate with. + + +**Returns:** `GenerationChunk` + +A new `GenerationChunk` concatenated from self and other. + +**Raises:** + +- `TypeError`: If other is not a `GenerationChunk`. + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs/llm_result.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs/llm_result.mdx new file mode 100644 index 0000000..73f85c4 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs/llm_result.mdx @@ -0,0 +1,131 @@ +--- +layout: overview +slug: langchain-core/langchain_core/outputs/llm_result +title: langchain_core.outputs.llm_result +--- + +`LLMResult` class. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`LLMResult`](#langchain_core-outputs-llm_result-LLMResult) | A container for results of an LLM call. | + +### API + + + + + +```python +class langchain_core.outputs.llm_result.LLMResult() +``` + + + + + + +**Bases:** `BaseModel` + +A container for results of an LLM call. + +Both chat models and LLMs generate an `LLMResult` object. This object contains the +generated outputs and any additional information that the model provider wants to +return. + + + +Generated outputs. + +The first dimension of the list represents completions for different input prompts. + +The second dimension of the list represents different candidate generations for a +given prompt. + +- When returned from **an LLM**, the type is `list[list[Generation]]`. +- When returned from a **chat model**, the type is `list[list[ChatGeneration]]`. + +`ChatGeneration` is a subclass of `Generation` that has a field for a structured +chat message. + + + +For arbitrary LLM provider specific output. + +This dictionary is a free-form dictionary that can contain any information that the +provider wants to return. It is not standardized and is provider-specific. + +Users should generally avoid relying on this field and instead rely on accessing +relevant information from standardized fields present in AIMessage. + + + +List of metadata info for model call for each input. + +See `langchain_core.outputs.run_info.RunInfo` for details. + + + +Type is used exclusively for serialization purposes. + + + + + +```python +langchain_core.outputs.llm_result.LLMResult.__eq__( + other: object +) -> bool +``` + + + + + + +Check for `LLMResult` equality by ignoring any metadata related to runs. + +**Parameters:** + + +Another `LLMResult` object to compare against. + + +**Returns:** `bool` + +`True` if the generations and `llm_output` are equal, `False` otherwise. + + + + + + + +```python +langchain_core.outputs.llm_result.LLMResult.flatten() -> list[langchain_core.outputs.llm_result.LLMResult] +``` + + + + + + +Flatten generations into a single list. + +Unpack `list[list[Generation]] -> list[LLMResult]` where each returned +`LLMResult` contains only a single `Generation`. If token usage information is +available, it is kept only for the `LLMResult` corresponding to the top-choice +`Generation`, to avoid over-counting of token usage downstream. + +**Returns:** `list[LLMResult]` + +List of `LLMResult` objects where each returned `LLMResult` contains a +single `Generation`. + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs/run_info.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs/run_info.mdx new file mode 100644 index 0000000..c1ef755 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/outputs/run_info.mdx @@ -0,0 +1,47 @@ +--- +layout: overview +slug: langchain-core/langchain_core/outputs/run_info +title: langchain_core.outputs.run_info +--- + +`RunInfo` class. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`RunInfo`](#langchain_core-outputs-run_info-RunInfo) | Class that contains metadata for a single execution of a chain or model. | + +### API + + + + + +```python +class langchain_core.outputs.run_info.RunInfo() +``` + + + + + + +**Bases:** `BaseModel` + +Class that contains metadata for a single execution of a chain or model. + +Defined for backwards compatibility with older versions of `langchain_core`. + +!!! warning "This model will likely be deprecated in the future." + +Users can acquire the `run_id` information from callbacks or via `run_id` +information present in the `astream_event` API (depending on the use case). + + + +A unique identifier for the model or chain run. + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompt_values.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompt_values.mdx new file mode 100644 index 0000000..35092b2 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompt_values.mdx @@ -0,0 +1,416 @@ +--- +layout: overview +slug: langchain-core/langchain_core/prompt_values +title: langchain_core.prompt_values +--- + +**Prompt values** for language model prompts. + +Prompt values are used to represent different pieces of prompts. They can be used to +represent text, images, or chat message pieces. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ChatPromptValue`](#langchain_core-prompt_values-ChatPromptValue) | Chat prompt value. | +| [`ChatPromptValueConcrete`](#langchain_core-prompt_values-ChatPromptValueConcrete) | Chat prompt value which explicitly lists out the message types it accepts. | +| [`ImagePromptValue`](#langchain_core-prompt_values-ImagePromptValue) | Image prompt value. | +| [`ImageURL`](#langchain_core-prompt_values-ImageURL) | Image URL for multimodal model inputs (OpenAI format). | +| [`PromptValue`](#langchain_core-prompt_values-PromptValue) | Base abstract class for inputs to any language model. | +| [`StringPromptValue`](#langchain_core-prompt_values-StringPromptValue) | String prompt value. | + +### API + + + + + +```python +class langchain_core.prompt_values.ChatPromptValue() +``` + + + + + + +**Bases:** [PromptValue](#langchain_core-prompt_values-PromptValue) + +Chat prompt value. + +A type of a prompt value that is built from messages. + + + +List of messages. + + + + + +```python +langchain_core.prompt_values.ChatPromptValue.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "prompts", "chat"]` + + + + + + + +```python +langchain_core.prompt_values.ChatPromptValue.to_messages() -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +Return prompt as a list of messages. + + + + + + + +```python +langchain_core.prompt_values.ChatPromptValue.to_string() -> str +``` + + + + + + +Return prompt as string. + + + + + + + + + +```python +class langchain_core.prompt_values.ChatPromptValueConcrete() +``` + + + + + + +**Bases:** [ChatPromptValue](#langchain_core-prompt_values-ChatPromptValue) + +Chat prompt value which explicitly lists out the message types it accepts. + +For use in external schemas. + + + +Sequence of messages. + + + + + + + + + + +```python +class langchain_core.prompt_values.ImagePromptValue() +``` + + + + + + +**Bases:** [PromptValue](#langchain_core-prompt_values-PromptValue) + +Image prompt value. + + + +Image URL. + + + + + + + + +```python +langchain_core.prompt_values.ImagePromptValue.to_messages() -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +Return prompt (image URL) as messages. + + + + + + + +```python +langchain_core.prompt_values.ImagePromptValue.to_string() -> str +``` + + + + + + +Return prompt (image URL) as string. + + + + + + + + + +```python +class langchain_core.prompt_values.ImageURL +``` + + + + + + +**Bases:** `typing.TypedDict` + +Image URL for multimodal model inputs (OpenAI format). + +Represents the inner `image_url` object in OpenAI's Chat Completion API format. This +is used by `ImagePromptTemplate` and `ChatPromptTemplate`. + + +Specifies the detail level of the image. + +Defaults to ``'auto'`` if not specified. Higher detail levels consume +more tokens but provide better image understanding. + + + +URL of the image or base64-encoded image data. + + + + + + + + +```python +class langchain_core.prompt_values.PromptValue() +``` + + + + + + +Abstract + +**Bases:** [Serializable](/langchain-core/langchain_core/load/serializable#langchain_core-load-serializable-Serializable) + +Base abstract class for inputs to any language model. + +`PromptValues` can be converted to both LLM (pure text-generation) inputs and +chat model inputs. + + + + + + +```python +langchain_core.prompt_values.PromptValue.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "schema", "prompt"]` + + + + + + + +```python +langchain_core.prompt_values.PromptValue.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `True` as this class is serializable. + + + + + + + +```python +langchain_core.prompt_values.PromptValue.to_messages() -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +abstract + +Return prompt as a list of messages. + + + + + + + +```python +langchain_core.prompt_values.PromptValue.to_string() -> str +``` + + + + + + +abstract + +Return prompt value as string. + + + + + + + + + +```python +class langchain_core.prompt_values.StringPromptValue() +``` + + + + + + +**Bases:** [PromptValue](#langchain_core-prompt_values-PromptValue) + +String prompt value. + + + +Prompt text. + + + + + + + + +```python +langchain_core.prompt_values.StringPromptValue.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "prompts", "base"]` + + + + + + + +```python +langchain_core.prompt_values.StringPromptValue.to_messages() -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +Return prompt as messages. + + + + + + + +```python +langchain_core.prompt_values.StringPromptValue.to_string() -> str +``` + + + + + + +Return prompt as string. + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts.mdx new file mode 100644 index 0000000..fa86d9b --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts.mdx @@ -0,0 +1,99 @@ +--- +layout: overview +slug: langchain-core/langchain_core/prompts +title: langchain_core.prompts +--- + +A prompt is the input to the model. + +Prompt is often constructed from multiple components and prompt values. Prompt classes +and functions make constructing and working with prompts easy. + +## Submodules + +- **[`langchain_core.prompts.base`](/langchain-core/langchain_core/prompts/base)** +- **[`langchain_core.prompts.chat`](/langchain-core/langchain_core/prompts/chat)** +- **[`langchain_core.prompts.dict`](/langchain-core/langchain_core/prompts/dict)** +- **[`langchain_core.prompts.few_shot`](/langchain-core/langchain_core/prompts/few_shot)** +- **[`langchain_core.prompts.few_shot_with_templates`](/langchain-core/langchain_core/prompts/few_shot_with_templates)** +- **[`langchain_core.prompts.image`](/langchain-core/langchain_core/prompts/image)** +- **[`langchain_core.prompts.loading`](/langchain-core/langchain_core/prompts/loading)** +- **[`langchain_core.prompts.message`](/langchain-core/langchain_core/prompts/message)** +- **[`langchain_core.prompts.prompt`](/langchain-core/langchain_core/prompts/prompt)** +- **[`langchain_core.prompts.string`](/langchain-core/langchain_core/prompts/string)** +- **[`langchain_core.prompts.structured`](/langchain-core/langchain_core/prompts/structured)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`__dir__`](#langchain_core-prompts-__dir__) | - | +| [`__getattr__`](#langchain_core-prompts-__getattr__) | - | + +### Data + +[`__all__`](#langchain_core-prompts-__all__) + +[`_dynamic_imports`](#langchain_core-prompts-_dynamic_imports) + +### API + + + + + +```python +langchain_core.prompts.__dir__() -> list[str] +``` + + + + + + + + + + + + + +```python +langchain_core.prompts.__getattr__( + attr_name: str +) -> object +``` + + + + + + + + + + + + + +```python +langchain_core.prompts.__all__ = ('AIMessagePromptTemplate', 'BaseChatPromptTemplate', 'BasePromptTemplate', 'Cha... +``` + + + + + + + + + +```python +langchain_core.prompts._dynamic_imports = {'BasePromptTemplate': 'base', 'format_document': 'base', 'aformat_document': 'b... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/base.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/base.mdx new file mode 100644 index 0000000..01b11c6 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/base.mdx @@ -0,0 +1,669 @@ +--- +layout: overview +slug: langchain-core/langchain_core/prompts/base +title: langchain_core.prompts.base +--- + +Base class for prompt templates. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BasePromptTemplate`](#langchain_core-prompts-base-BasePromptTemplate) | Base class for all prompt templates, returning a prompt. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_document_info`](#langchain_core-prompts-base-_get_document_info) | - | +| [`aformat_document`](#langchain_core-prompts-base-aformat_document) | Async format a document into a string based on a prompt template. | +| [`format_document`](#langchain_core-prompts-base-format_document) | Format a document into a string based on a prompt template. | + +### Data + +[`FormatOutputType`](#langchain_core-prompts-base-FormatOutputType) + +### API + + + + + +```python +class langchain_core.prompts.base.BasePromptTemplate() +``` + + + + + + +Abstract + +**Bases:** [RunnableSerializable[dict, PromptValue]](/langchain-core/langchain_core/runnables/base#langchain_core-runnables-base-RunnableSerializable), `Generic[FormatOutputType]` + +Base class for all prompt templates, returning a prompt. + + + +Return the output type of the prompt. + + + +Return the prompt type key. + + + + + + +A dictionary of the types of the variables the prompt template expects. + +If not provided, all variables are assumed to be strings. + + + +A list of the names of the variables whose values are required as inputs to the +prompt. + + + +Metadata to be used for tracing. + + + + + + +A list of the names of the variables for placeholder or `MessagePlaceholder` that +are optional. + +These variables are auto inferred from the prompt and user need not provide them. + + + +How to parse the output of calling an LLM on this formatted prompt. + + + +A dictionary of the partial variables the prompt template carries. + +Partial variables populate the template so that you don't need to pass them in every +time you call the prompt. + + + +Tags to be used for tracing. + + + + + +```python +langchain_core.prompts.base.BasePromptTemplate._aformat_prompt_with_error_handling( + inner_input: langchain_core.prompts.base.BasePromptTemplate.dict +) -> langchain_core.prompt_values.PromptValue +``` + + + + + + +async + + + + + + + +```python +langchain_core.prompts.base.BasePromptTemplate._format_prompt_with_error_handling( + inner_input: langchain_core.prompts.base.BasePromptTemplate.dict +) -> langchain_core.prompt_values.PromptValue +``` + + + + + + + + + + + + +```python +langchain_core.prompts.base.BasePromptTemplate._merge_partial_and_user_variables( + kwargs: typing.Any = {} +) -> langchain_core.prompts.base.BasePromptTemplate.dict[str, typing.Any] +``` + + + + + + + + + + + + +```python +langchain_core.prompts.base.BasePromptTemplate._validate_input( + inner_input: typing.Any +) -> langchain_core.prompts.base.BasePromptTemplate.dict +``` + + + + + + + + + + + + +```python +langchain_core.prompts.base.BasePromptTemplate.aformat( + kwargs: typing.Any = {} +) -> langchain_core.prompts.base.FormatOutputType +``` + + + + + + +async + +Async format the prompt with the inputs. + +**Parameters:** + + +Any arguments to be passed to the prompt template. + + +**Returns:** `FormatOutputType` + +A formatted string. + + + + + + + +```python +langchain_core.prompts.base.BasePromptTemplate.aformat_prompt( + kwargs: typing.Any = {} +) -> langchain_core.prompt_values.PromptValue +``` + + + + + + +async + +Async create `PromptValue`. + +**Parameters:** + + +Any arguments to be passed to the prompt template. + + +**Returns:** `PromptValue` + +The output of the prompt. + + + + + + + +```python +langchain_core.prompts.base.BasePromptTemplate.ainvoke( + input: langchain_core.prompts.base.BasePromptTemplate.dict, + config: langchain_core.runnables.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.prompt_values.PromptValue +``` + + + + + + +async + +Async invoke the prompt. + +**Parameters:** + + +Input to the prompt. + + + +Configuration for the prompt. + + +**Returns:** `PromptValue` + +The output of the prompt. + + + + + + + +```python +langchain_core.prompts.base.BasePromptTemplate.dict( + kwargs: typing.Any = {} +) -> langchain_core.prompts.base.BasePromptTemplate.dict +``` + + + + + + +Return dictionary representation of prompt. + +**Parameters:** + + +Any additional arguments to pass to the dictionary. + + +**Returns:** `dict` + +Dictionary representation of the prompt. + + + + + + + +```python +langchain_core.prompts.base.BasePromptTemplate.format( + kwargs: typing.Any = {} +) -> langchain_core.prompts.base.FormatOutputType +``` + + + + + + +abstract + +Format the prompt with the inputs. + +**Parameters:** + + +Any arguments to be passed to the prompt template. + + +**Returns:** `FormatOutputType` + +A formatted string. + + + + + + + +```python +langchain_core.prompts.base.BasePromptTemplate.format_prompt( + kwargs: typing.Any = {} +) -> langchain_core.prompt_values.PromptValue +``` + + + + + + +abstract + +Create `PromptValue`. + +**Parameters:** + + +Any arguments to be passed to the prompt template. + + +**Returns:** `PromptValue` + +The output of the prompt. + + + + + + + +```python +langchain_core.prompts.base.BasePromptTemplate.get_input_schema( + config: langchain_core.runnables.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + +Get the input schema for the prompt. + +**Parameters:** + + +Configuration for the prompt. + + +**Returns:** `type[BaseModel]` + +The input schema for the prompt. + + + + + + + +```python +langchain_core.prompts.base.BasePromptTemplate.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "schema", "prompt_template"]` + + + + + + + +```python +langchain_core.prompts.base.BasePromptTemplate.invoke( + input: langchain_core.prompts.base.BasePromptTemplate.dict, + config: langchain_core.runnables.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.prompt_values.PromptValue +``` + + + + + + +Invoke the prompt. + +**Parameters:** + + +Input to the prompt. + + + +Configuration for the prompt. + + +**Returns:** `PromptValue` + +The output of the prompt. + + + + + + + +```python +langchain_core.prompts.base.BasePromptTemplate.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `True` as this class is serializable. + + + + + + + +```python +langchain_core.prompts.base.BasePromptTemplate.partial( + kwargs: str | collections.abc.Callable[[], str] = {} +) -> langchain_core.prompts.base.BasePromptTemplate +``` + + + + + + +Return a partial of the prompt template. + +**Parameters:** + + +Partial variables to set. + + +**Returns:** `BasePromptTemplate` + +A partial of the prompt template. + + + + + + + +```python +langchain_core.prompts.base.BasePromptTemplate.save( + file_path: pathlib.Path | str +) -> None +``` + + + + + + +Save the prompt. + +**Parameters:** + + +Path to directory to save prompt to. + + +**Raises:** + +- `ValueError`: If the prompt has partial variables. +- `ValueError`: If the file path is not json or yaml. +- `NotImplementedError`: If the prompt type is not implemented. + + + + + + + +```python +langchain_core.prompts.base.BasePromptTemplate.validate_variable_names() -> typing_extensions.Self +``` + + + + + + +Validate variable names do not include restricted names. + + + + + + + + + +```python +langchain_core.prompts.base._get_document_info( + doc: langchain_core.documents.Document, + prompt: langchain_core.prompts.base.BasePromptTemplate[str] +) -> langchain_core.prompts.dict +``` + + + + + + + + + + + + + +```python +langchain_core.prompts.base.aformat_document( + doc: langchain_core.documents.Document, + prompt: langchain_core.prompts.base.BasePromptTemplate[str] +) -> str +``` + + + + + + +async + +Async format a document into a string based on a prompt template. + +First, this pulls information from the document from two sources: + +1. `page_content`: This takes the information from the `document.page_content` and + assigns it to a variable named `page_content`. +2. `metadata`: This takes information from `document.metadata` and assigns it to + variables of the same name. + +Those variables are then passed into the `prompt` to produce a formatted string. + +**Parameters:** + + +`Document`, the `page_content` and `metadata` will be used to create the +final string. + + + +`BasePromptTemplate`, will be used to format the `page_content` and +`metadata` into the final string. + + +**Returns:** `str` + +String of the document formatted. + + + + + + + + +```python +langchain_core.prompts.base.format_document( + doc: langchain_core.documents.Document, + prompt: langchain_core.prompts.base.BasePromptTemplate[str] +) -> str +``` + + + + + + +Format a document into a string based on a prompt template. + +First, this pulls information from the document from two sources: + +1. `page_content`: This takes the information from the `document.page_content` and + assigns it to a variable named `page_content`. +2. `metadata`: This takes information from `document.metadata` and assigns it to + variables of the same name. + +Those variables are then passed into the `prompt` to produce a formatted string. + +**Parameters:** + + +`Document`, the `page_content` and `metadata` will be used to create the +final string. + + + +`BasePromptTemplate`, will be used to format the `page_content` and +`metadata` into the final string. + + +**Returns:** `str` + +String of the document formatted. + + + + + + + + +```python +langchain_core.prompts.base.FormatOutputType = TypeVar('FormatOutputType') +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/chat.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/chat.mdx new file mode 100644 index 0000000..cb7ae8c --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/chat.mdx @@ -0,0 +1,1931 @@ +--- +layout: overview +slug: langchain-core/langchain_core/prompts/chat +title: langchain_core.prompts.chat +--- + +Chat prompt template. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AIMessagePromptTemplate`](#langchain_core-prompts-chat-AIMessagePromptTemplate) | AI message prompt template. | +| [`BaseChatPromptTemplate`](#langchain_core-prompts-chat-BaseChatPromptTemplate) | Base class for chat prompt templates. | +| [`BaseStringMessagePromptTemplate`](#langchain_core-prompts-chat-BaseStringMessagePromptTemplate) | Base class for message prompt templates that use a string prompt template. | +| [`ChatMessagePromptTemplate`](#langchain_core-prompts-chat-ChatMessagePromptTemplate) | Chat message prompt template. | +| [`ChatPromptTemplate`](#langchain_core-prompts-chat-ChatPromptTemplate) | Prompt template for chat models. | +| [`HumanMessagePromptTemplate`](#langchain_core-prompts-chat-HumanMessagePromptTemplate) | Human message prompt template. | +| [`MessagesPlaceholder`](#langchain_core-prompts-chat-MessagesPlaceholder) | Prompt template that assumes variable is already list of messages. | +| [`SystemMessagePromptTemplate`](#langchain_core-prompts-chat-SystemMessagePromptTemplate) | System message prompt template. | +| [`_ImageTemplateParam`](#langchain_core-prompts-chat-_ImageTemplateParam) | - | +| [`_StringImageMessagePromptTemplate`](#langchain_core-prompts-chat-_StringImageMessagePromptTemplate) | Human message prompt template. This is a message sent from the user. | +| [`_TextTemplateParam`](#langchain_core-prompts-chat-_TextTemplateParam) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_convert_to_message_template`](#langchain_core-prompts-chat-_convert_to_message_template) | Instantiate a message from a variety of message formats. | +| [`_create_template_from_message_type`](#langchain_core-prompts-chat-_create_template_from_message_type) | Create a message prompt template from a message type and template string. | + +### Data + +[`MessageLike`](#langchain_core-prompts-chat-MessageLike) + +[`MessageLikeRepresentation`](#langchain_core-prompts-chat-MessageLikeRepresentation) + +[`MessagePromptTemplateT`](#langchain_core-prompts-chat-MessagePromptTemplateT) + +[`_convert_to_message`](#langchain_core-prompts-chat-_convert_to_message) + +### API + + + + + +```python +class langchain_core.prompts.chat.AIMessagePromptTemplate() +``` + + + + + + +**Bases:** [_StringImageMessagePromptTemplate](#langchain_core-prompts-chat-_StringImageMessagePromptTemplate) + +AI message prompt template. + +This is a message sent from the AI. + + + + + + + + + + +```python +class langchain_core.prompts.chat.BaseChatPromptTemplate() +``` + + + + + + +Abstract + +**Bases:** [BasePromptTemplate](/langchain-core/langchain_core/prompts/base#langchain_core-prompts-base-BasePromptTemplate) + +Base class for chat prompt templates. + + + + + + + + +```python +langchain_core.prompts.chat.BaseChatPromptTemplate.aformat( + kwargs: typing.Any = {} +) -> str +``` + + + + + + +async + +Async format the chat template into a string. + +**Parameters:** + + +Keyword arguments to use for filling in template variables in all +the template messages in this chat template. + + +**Returns:** `str` + +Formatted string. + + + + + + + +```python +langchain_core.prompts.chat.BaseChatPromptTemplate.aformat_messages( + kwargs: typing.Any = {} +) -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +async + +Async format kwargs into a list of messages. + +**Returns:** `list[BaseMessage]` + +List of `BaseMessage` objects. + + + + + + + +```python +langchain_core.prompts.chat.BaseChatPromptTemplate.aformat_prompt( + kwargs: typing.Any = {} +) -> langchain_core.prompt_values.ChatPromptValue +``` + + + + + + +async + +Async format prompt. + +Should return a `ChatPromptValue`. + +**Parameters:** + + +Keyword arguments to use for formatting. + + + + + + + + +```python +langchain_core.prompts.chat.BaseChatPromptTemplate.format( + kwargs: typing.Any = {} +) -> str +``` + + + + + + +Format the chat template into a string. + +**Parameters:** + + +Keyword arguments to use for filling in template variables in all +the template messages in this chat template. + + +**Returns:** `str` + +Formatted string. + + + + + + + +```python +langchain_core.prompts.chat.BaseChatPromptTemplate.format_messages( + kwargs: typing.Any = {} +) -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +abstract + +Format kwargs into a list of messages. + +**Returns:** `list[BaseMessage]` + +List of `BaseMessage` objects. + + + + + + + +```python +langchain_core.prompts.chat.BaseChatPromptTemplate.format_prompt( + kwargs: typing.Any = {} +) -> langchain_core.prompt_values.ChatPromptValue +``` + + + + + + +Format prompt. + +Should return a `ChatPromptValue`. + +**Parameters:** + + +Keyword arguments to use for formatting. + + + + + + + + +```python +langchain_core.prompts.chat.BaseChatPromptTemplate.pretty_print() -> None +``` + + + + + + +Print a human-readable representation. + + + + + + + +```python +langchain_core.prompts.chat.BaseChatPromptTemplate.pretty_repr( + html: bool = False +) -> str +``` + + + + + + +Human-readable representation. + +**Parameters:** + + +Whether to format as HTML. + + +**Returns:** `str` + +Human-readable representation. + + + + + + + + + +```python +class langchain_core.prompts.chat.BaseStringMessagePromptTemplate() +``` + + + + + + +Abstract + +**Bases:** [BaseMessagePromptTemplate](/langchain-core/langchain_core/prompts/message#langchain_core-prompts-message-BaseMessagePromptTemplate) + +Base class for message prompt templates that use a string prompt template. + + + +Additional keyword arguments to pass to the prompt template. + + + +Input variables for this prompt template. + + + +String prompt template. + + + + + +```python +langchain_core.prompts.chat.BaseStringMessagePromptTemplate.aformat( + kwargs: typing.Any = {} +) -> langchain_core.messages.BaseMessage +``` + + + + + + +async + +Async format the prompt template. + +**Parameters:** + + +Keyword arguments to use for formatting. + + +**Returns:** `BaseMessage` + +Formatted message. + + + + + + + +```python +langchain_core.prompts.chat.BaseStringMessagePromptTemplate.aformat_messages( + kwargs: typing.Any = {} +) -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +async + +Async format messages from kwargs. + +**Parameters:** + + +Keyword arguments to use for formatting. + + +**Returns:** `list[BaseMessage]` + +List of `BaseMessage` objects. + + + + + + + +```python +langchain_core.prompts.chat.BaseStringMessagePromptTemplate.format( + kwargs: typing.Any = {} +) -> langchain_core.messages.BaseMessage +``` + + + + + + +abstract + +Format the prompt template. + +**Parameters:** + + +Keyword arguments to use for formatting. + + +**Returns:** `BaseMessage` + +Formatted message. + + + + + + + +```python +langchain_core.prompts.chat.BaseStringMessagePromptTemplate.format_messages( + kwargs: typing.Any = {} +) -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +Format messages from kwargs. + +**Parameters:** + + +Keyword arguments to use for formatting. + + +**Returns:** `list[BaseMessage]` + +List of `BaseMessage` objects. + + + + + + + +```python +langchain_core.prompts.chat.BaseStringMessagePromptTemplate.from_template( + template: str, + template_format: langchain_core.prompts.string.PromptTemplateFormat = 'f-string', + partial_variables: langchain_core.prompts.dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> typing_extensions.Self +``` + + + + + + +classmethod + +Create a class from a string template. + +**Parameters:** + + +a template. + + + +format of the template. + + + +A dictionary of variables that can be used to partially +fill in the template. + +For example, if the template is `"{variable1} {variable2}"`, and +`partial_variables` is `{"variable1": "foo"}`, then the final prompt +will be `"foo {variable2}"`. + + + +Keyword arguments to pass to the constructor. + + +**Returns:** `Self` + +A new instance of this class. + + + + + + + +```python +langchain_core.prompts.chat.BaseStringMessagePromptTemplate.from_template_file( + template_file: str | pathlib.Path, + kwargs: typing.Any = {} +) -> typing_extensions.Self +``` + + + + + + +classmethod + +Create a class from a template file. + +**Parameters:** + + +path to a template file. + + + +Keyword arguments to pass to the constructor. + + +**Returns:** `Self` + +A new instance of this class. + + + + + + + +```python +langchain_core.prompts.chat.BaseStringMessagePromptTemplate.pretty_repr( + html: bool = False +) -> str +``` + + + + + + +Human-readable representation. + +**Parameters:** + + +Whether to format as HTML. + + +**Returns:** `str` + +Human-readable representation. + + + + + + + + + +```python +class langchain_core.prompts.chat.ChatMessagePromptTemplate() +``` + + + + + + +**Bases:** [BaseStringMessagePromptTemplate](#langchain_core-prompts-chat-BaseStringMessagePromptTemplate) + +Chat message prompt template. + + + +Role of the message. + + + + + +```python +langchain_core.prompts.chat.ChatMessagePromptTemplate.aformat( + kwargs: typing.Any = {} +) -> langchain_core.messages.BaseMessage +``` + + + + + + +async + +Async format the prompt template. + +**Parameters:** + + +Keyword arguments to use for formatting. + + +**Returns:** `BaseMessage` + +Formatted message. + + + + + + + +```python +langchain_core.prompts.chat.ChatMessagePromptTemplate.format( + kwargs: typing.Any = {} +) -> langchain_core.messages.BaseMessage +``` + + + + + + +Format the prompt template. + +**Parameters:** + + +Keyword arguments to use for formatting. + + +**Returns:** `BaseMessage` + +Formatted message. + + + + + + + + + +```python +class langchain_core.prompts.chat.ChatPromptTemplate( + messages: collections.abc.Sequence[langchain_core.prompts.chat.MessageLikeRepresentation], + template_format: langchain_core.prompts.string.PromptTemplateFormat = 'f-string', + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [BaseChatPromptTemplate](#langchain_core-prompts-chat-BaseChatPromptTemplate) + +Prompt template for chat models. + +Use to create flexible templated prompts for chat models. + +!!! example + + ```python + from langchain_core.prompts import ChatPromptTemplate + + template = ChatPromptTemplate( + [ + ("system", "You are a helpful AI bot. Your name is {name}."), + ("human", "Hello, how are you doing?"), + ("ai", "I'm doing well, thanks!"), + ("human", "{user_input}"), + ] + ) + + prompt_value = template.invoke( + { + "name": "Bob", + "user_input": "What is your name?", + } + ) + # Output: + # ChatPromptValue( + # messages=[ + # SystemMessage(content='You are a helpful AI bot. Your name is Bob.'), + # HumanMessage(content='Hello, how are you doing?'), + # AIMessage(content="I'm doing well, thanks!"), + # HumanMessage(content='What is your name?') + # ] + # ) + ``` + +!!! note "Messages Placeholder" + + ```python + # In addition to Human/AI/Tool/Function messages, + # you can initialize the template with a MessagesPlaceholder + # either using the class directly or with the shorthand tuple syntax: + + template = ChatPromptTemplate( + [ + ("system", "You are a helpful AI bot."), + # Means the template will receive an optional list of messages under + # the "conversation" key + ("placeholder", "{conversation}"), + # Equivalently: + # MessagesPlaceholder(variable_name="conversation", optional=True) + ] + ) + + prompt_value = template.invoke( + { + "conversation": [ + ("human", "Hi!"), + ("ai", "How can I assist you today?"), + ("human", "Can you make me an ice cream sundae?"), + ("ai", "No."), + ] + } + ) + + # Output: + # ChatPromptValue( + # messages=[ + # SystemMessage(content='You are a helpful AI bot.'), + # HumanMessage(content='Hi!'), + # AIMessage(content='How can I assist you today?'), + # HumanMessage(content='Can you make me an ice cream sundae?'), + # AIMessage(content='No.'), + # ] + # ) + ``` + +!!! note "Single-variable template" + + If your prompt has only a single input variable (i.e., one instance of + `'{variable_nams}'`), and you invoke the template with a non-dict object, the + prompt template will inject the provided argument into that variable location. + + ```python + from langchain_core.prompts import ChatPromptTemplate + + template = ChatPromptTemplate( + [ + ("system", "You are a helpful AI bot. Your name is Carl."), + ("human", "{user_input}"), + ] + ) + + prompt_value = template.invoke("Hello, there!") + # Equivalent to + # prompt_value = template.invoke({"user_input": "Hello, there!"}) + + # Output: + # ChatPromptValue( + # messages=[ + # SystemMessage(content='You are a helpful AI bot. Your name is Carl.'), + # HumanMessage(content='Hello, there!'), + # ] + # ) + ``` + + + +Name of prompt type. Used for serialization. + + + +List of messages consisting of either message prompt templates or messages. + + + +Whether or not to try validating the template. + + + + + +```python +langchain_core.prompts.chat.ChatPromptTemplate.__add__( + other: typing.Any +) -> langchain_core.prompts.chat.ChatPromptTemplate +``` + + + + + + +Combine two prompt templates. + +**Parameters:** + + +Another prompt template. + + +**Returns:** `ChatPromptTemplate` + +Combined prompt template. + + + + + + + +```python +langchain_core.prompts.chat.ChatPromptTemplate.__getitem__( + index: int | slice +) -> langchain_core.prompts.chat.MessageLike | langchain_core.prompts.chat.ChatPromptTemplate +``` + + + + + + +Use to index into the chat template. + +**Returns:** `MessageLike | ChatPromptTemplate` + +If index is an int, returns the message at that index. + + + + + + + +```python +langchain_core.prompts.chat.ChatPromptTemplate.__len__() -> int +``` + + + + + + +Return the length of the chat template. + + + + + + + +```python +langchain_core.prompts.chat.ChatPromptTemplate.aformat_messages( + kwargs: typing.Any = {} +) -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +async + +Async format the chat template into a list of finalized messages. + +**Parameters:** + + +Keyword arguments to use for filling in template variables +in all the template messages in this chat template. + + +**Returns:** `list[BaseMessage]` + +List of formatted messages. + +**Raises:** + +- `ValueError`: If unexpected input. + + + + + + + +```python +langchain_core.prompts.chat.ChatPromptTemplate.append( + message: langchain_core.prompts.chat.MessageLikeRepresentation +) -> None +``` + + + + + + +Append a message to the end of the chat template. + +**Parameters:** + + +representation of a message to append. + + + + + + + + +```python +langchain_core.prompts.chat.ChatPromptTemplate.extend( + messages: collections.abc.Sequence[langchain_core.prompts.chat.MessageLikeRepresentation] +) -> None +``` + + + + + + +Extend the chat template with a sequence of messages. + +**Parameters:** + + +Sequence of message representations to append. + + + + + + + + +```python +langchain_core.prompts.chat.ChatPromptTemplate.format_messages( + kwargs: typing.Any = {} +) -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +Format the chat template into a list of finalized messages. + +**Parameters:** + + +Keyword arguments to use for filling in template variables +in all the template messages in this chat template. + + +**Returns:** `list[BaseMessage]` + +List of formatted messages. + +**Raises:** + +- `ValueError`: If messages are of unexpected types. + + + + + + + +```python +langchain_core.prompts.chat.ChatPromptTemplate.from_messages( + messages: collections.abc.Sequence[langchain_core.prompts.chat.MessageLikeRepresentation], + template_format: langchain_core.prompts.string.PromptTemplateFormat = 'f-string' +) -> langchain_core.prompts.chat.ChatPromptTemplate +``` + + + + + + +classmethod + +Create a chat prompt template from a variety of message formats. + +Args: + messages: Sequence of message representations. + + A message can be represented using the following formats: + + 1. `BaseMessagePromptTemplate` + 2. `BaseMessage` + 3. 2-tuple of `(message type, template)`; e.g., + `('human', '{user_input}')` + 4. 2-tuple of `(message class, template)` + 5. A string which is shorthand for `('human', template)`; e.g., + `'{user_input}'` + template_format: Format of the template. + +**Returns:** `ChatPromptTemplate` + +A chat prompt template. + +**Examples:** + + + +```python +Instantiation from a list of message templates: + +```python +template = ChatPromptTemplate.from_messages( + [ + ("human", "Hello, how are you?"), + ("ai", "I'm doing well, thanks!"), + ("human", "That's good to hear."), + ] +) +``` + +Instantiation from mixed message formats: + +```python +template = ChatPromptTemplate.from_messages( + [ + SystemMessage(content="hello"), + ("human", "Hello, how are you?"), + ] +) +``` + + + + + + + + + +```python +langchain_core.prompts.chat.ChatPromptTemplate.from_template( + template: str, + kwargs: typing.Any = {} +) -> langchain_core.prompts.chat.ChatPromptTemplate +``` + + + + + + +classmethod + +Create a chat prompt template from a template string. + +Creates a chat template consisting of a single message assumed to be from the +human. + +**Parameters:** + + +Template string + + + +Keyword arguments to pass to the constructor. + + +**Returns:** `ChatPromptTemplate` + +A new instance of this class. + + + + + + + +```python +langchain_core.prompts.chat.ChatPromptTemplate.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "prompts", "chat"]` + + + + + + + +```python +langchain_core.prompts.chat.ChatPromptTemplate.partial( + kwargs: typing.Any = {} +) -> langchain_core.prompts.chat.ChatPromptTemplate +``` + + + + + + +Get a new `ChatPromptTemplate` with some input variables already filled in. + +**Parameters:** + + +Keyword arguments to use for filling in template variables. + +Ought to be a subset of the input variables. + + +**Returns:** `ChatPromptTemplate` + +A new `ChatPromptTemplate`. + + + + + + + +```python +langchain_core.prompts.chat.ChatPromptTemplate.pretty_repr( + html: bool = False +) -> str +``` + + + + + + +Human-readable representation. + +**Parameters:** + + +Whether to format as HTML. + + +**Returns:** `str` + +Human-readable representation. + + + + + + + +```python +langchain_core.prompts.chat.ChatPromptTemplate.save( + file_path: pathlib.Path | str +) -> None +``` + + + + + + +Save prompt to file. + +**Parameters:** + + +path to file. + + + + + + + + +```python +langchain_core.prompts.chat.ChatPromptTemplate.validate_input_variables( + values: langchain_core.prompts.dict +) -> typing.Any +``` + + + + + + +classmethod + +Validate input variables. + +If `input_variables` is not set, it will be set to the union of all input +variables in the messages. + +**Parameters:** + + +values to validate. + + +**Returns:** `Any` + +Validated values. + +**Raises:** + +- `ValueError`: If input variables do not match. + + + + + + + + + +```python +class langchain_core.prompts.chat.HumanMessagePromptTemplate() +``` + + + + + + +**Bases:** [_StringImageMessagePromptTemplate](#langchain_core-prompts-chat-_StringImageMessagePromptTemplate) + +Human message prompt template. + +This is a message sent from the user. + + + + + + + + + + +```python +class langchain_core.prompts.chat.MessagesPlaceholder( + variable_name: str, + optional: bool = False, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [BaseMessagePromptTemplate](/langchain-core/langchain_core/prompts/message#langchain_core-prompts-message-BaseMessagePromptTemplate) + +Prompt template that assumes variable is already list of messages. + +A placeholder which can be used to pass in a list of messages. + +!!! example "Direct usage" + + ```python + from langchain_core.prompts import MessagesPlaceholder + + prompt = MessagesPlaceholder("history") + prompt.format_messages() # raises KeyError + + prompt = MessagesPlaceholder("history", optional=True) + prompt.format_messages() # returns empty list [] + + prompt.format_messages( + history=[ + ("system", "You are an AI assistant."), + ("human", "Hello!"), + ] + ) + # -> [ + # SystemMessage(content="You are an AI assistant."), + # HumanMessage(content="Hello!"), + # ] + ``` + +!!! example "Building a prompt with chat history" + + ```python + from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder + + prompt = ChatPromptTemplate.from_messages( + [ + ("system", "You are a helpful assistant."), + MessagesPlaceholder("history"), + ("human", "{question}"), + ] + ) + prompt.invoke( + { + "history": [("human", "what's 5 + 2"), ("ai", "5 + 2 is 7")], + "question": "now multiply that by 4", + } + ) + # -> ChatPromptValue(messages=[ + # SystemMessage(content="You are a helpful assistant."), + # HumanMessage(content="what's 5 + 2"), + # AIMessage(content="5 + 2 is 7"), + # HumanMessage(content="now multiply that by 4"), + # ]) + ``` + +!!! example "Limiting the number of messages" + + ```python + from langchain_core.prompts import MessagesPlaceholder + + prompt = MessagesPlaceholder("history", n_messages=1) + + prompt.format_messages( + history=[ + ("system", "You are an AI assistant."), + ("human", "Hello!"), + ] + ) + # -> [ + # HumanMessage(content="Hello!"), + # ] + ``` + + + +Input variables for this prompt template. + + + +Maximum number of messages to include. + +If `None`, then will include all. + + + +Whether `format_messages` must be provided. + +If `True` `format_messages` can be called with no arguments and will return an empty +list. + +If `False` then a named argument with name `variable_name` must be passed in, even +if the value is an empty list. + + + +Name of variable to use as messages. + + + + + +```python +langchain_core.prompts.chat.MessagesPlaceholder.format_messages( + kwargs: typing.Any = {} +) -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +Format messages from kwargs. + +**Parameters:** + + +Keyword arguments to use for formatting. + + +**Returns:** `list[BaseMessage]` + +List of `BaseMessage` objects. + +**Raises:** + +- `ValueError`: If variable is not a list of messages. + + + + + + + +```python +langchain_core.prompts.chat.MessagesPlaceholder.pretty_repr( + html: bool = False +) -> str +``` + + + + + + +Human-readable representation. + +**Parameters:** + + +Whether to format as HTML. + + +**Returns:** `str` + +Human-readable representation. + + + + + + + + + +```python +class langchain_core.prompts.chat.SystemMessagePromptTemplate() +``` + + + + + + +**Bases:** [_StringImageMessagePromptTemplate](#langchain_core-prompts-chat-_StringImageMessagePromptTemplate) + +System message prompt template. + +This is a message that is not sent to the user. + + + + + + + + + + +```python +class langchain_core.prompts.chat._ImageTemplateParam +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +class langchain_core.prompts.chat._StringImageMessagePromptTemplate() +``` + + + + + + +**Bases:** [BaseMessagePromptTemplate](/langchain-core/langchain_core/prompts/message#langchain_core-prompts-message-BaseMessagePromptTemplate) + +Human message prompt template. This is a message sent from the user. + + + + + + +Additional keyword arguments to pass to the prompt template. + + + +Input variables for this prompt template. + + + +Prompt template. + + + + + +```python +langchain_core.prompts.chat._StringImageMessagePromptTemplate.aformat( + kwargs: typing.Any = {} +) -> langchain_core.messages.BaseMessage +``` + + + + + + +async + +Async format the prompt template. + +**Parameters:** + + +Keyword arguments to use for formatting. + + +**Returns:** `BaseMessage` + +Formatted message. + + + + + + + +```python +langchain_core.prompts.chat._StringImageMessagePromptTemplate.aformat_messages( + kwargs: typing.Any = {} +) -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +async + +Async format messages from kwargs. + +**Parameters:** + + +Keyword arguments to use for formatting. + + +**Returns:** `list[BaseMessage]` + +List of `BaseMessage` objects. + + + + + + + +```python +langchain_core.prompts.chat._StringImageMessagePromptTemplate.format( + kwargs: typing.Any = {} +) -> langchain_core.messages.BaseMessage +``` + + + + + + +Format the prompt template. + +**Parameters:** + + +Keyword arguments to use for formatting. + + +**Returns:** `BaseMessage` + +Formatted message. + + + + + + + +```python +langchain_core.prompts.chat._StringImageMessagePromptTemplate.format_messages( + kwargs: typing.Any = {} +) -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +Format messages from kwargs. + +**Parameters:** + + +Keyword arguments to use for formatting. + + +**Returns:** `list[BaseMessage]` + +List of `BaseMessage` objects. + + + + + + + +```python +langchain_core.prompts.chat._StringImageMessagePromptTemplate.from_template( + template: str | list[str | langchain_core.prompts.chat._TextTemplateParam | langchain_core.prompts.chat._ImageTemplateParam | langchain_core.prompts.dict[str, typing.Any]], + template_format: langchain_core.prompts.string.PromptTemplateFormat = 'f-string', + partial_variables: langchain_core.prompts.dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> typing_extensions.Self +``` + + + + + + +classmethod + +Create a class from a string template. + +**Parameters:** + + +a template. + + + +format of the template. + +Options are: `'f-string'`, `'mustache'`, `'jinja2'`. + + + +A dictionary of variables that can be used too partially. + + + +Keyword arguments to pass to the constructor. + + +**Returns:** `Self` + +A new instance of this class. + +**Raises:** + +- `ValueError`: If the template is not a string or list of strings. + + + + + + + +```python +langchain_core.prompts.chat._StringImageMessagePromptTemplate.from_template_file( + template_file: str | pathlib.Path, + input_variables: list[str], + kwargs: typing.Any = {} +) -> typing_extensions.Self +``` + + + + + + +classmethod + +Create a class from a template file. + +**Parameters:** + + +path to a template file. + + + +list of input variables. + + + +Keyword arguments to pass to the constructor. + + +**Returns:** `Self` + +A new instance of this class. + + + + + + + +```python +langchain_core.prompts.chat._StringImageMessagePromptTemplate.pretty_repr( + html: bool = False +) -> str +``` + + + + + + +Human-readable representation. + +**Parameters:** + + +Whether to format as HTML. + + +**Returns:** `str` + +Human-readable representation. + + + + + + + + + +```python +class langchain_core.prompts.chat._TextTemplateParam +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +langchain_core.prompts.chat._convert_to_message_template( + message: langchain_core.prompts.chat.MessageLikeRepresentation, + template_format: langchain_core.prompts.string.PromptTemplateFormat = 'f-string' +) -> langchain_core.messages.BaseMessage | langchain_core.prompts.message.BaseMessagePromptTemplate | langchain_core.prompts.chat.BaseChatPromptTemplate +``` + + + + + + +Instantiate a message from a variety of message formats. + +A message can be represented using the following formats: + +1. `BaseMessagePromptTemplate` +2. `BaseMessage` +3. 2-tuple of `(message type, template)`; e.g., `('human', '{user_input}')` +4. 2-tuple of `(message class, template)` +5. A string which is shorthand for `('human', template)`; e.g., `'{user_input}'` + +**Parameters:** + + +A representation of a message in one of the supported formats. + + + +Format of the template. + + +**Returns:** `BaseMessage | BaseMessagePromptTemplate | BaseChatPromptTemplate` + +An instance of a message or a message template. + +**Raises:** + +- `ValueError`: If unexpected message type. +- `ValueError`: If 2-tuple does not have 2 elements. + + + + + + + + +```python +langchain_core.prompts.chat._create_template_from_message_type( + message_type: str, + template: str | list, + template_format: langchain_core.prompts.string.PromptTemplateFormat = 'f-string' +) -> langchain_core.prompts.message.BaseMessagePromptTemplate +``` + + + + + + +Create a message prompt template from a message type and template string. + +**Parameters:** + + +The type of the message template (e.g., `'human'`, `'ai'`, etc.) + + + +The template string. + + + +Format of the template. + + +**Returns:** `BaseMessagePromptTemplate` + +A message prompt template of the appropriate type. + +**Raises:** + +- `ValueError`: If unexpected message type. + + + + + + + + +```python +langchain_core.prompts.chat.MessageLike = BaseMessagePromptTemplate | BaseMessage | BaseChatPromptTemplate +``` + + + + + + + + + +```python +langchain_core.prompts.chat.MessageLikeRepresentation = MessageLike | tuple[str | type, str | Sequence[dict] | Sequence[object]] | str |... +``` + + + + + + + + + +```python +langchain_core.prompts.chat.MessagePromptTemplateT = TypeVar('MessagePromptTemplateT', bound='BaseStringMessagePromptTemplate') +``` + + + + + + +Type variable for message prompt templates. + + + + + + + +```python +langchain_core.prompts.chat._convert_to_message = _convert_to_message_template +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/dict.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/dict.mdx new file mode 100644 index 0000000..a790420 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/dict.mdx @@ -0,0 +1,240 @@ +--- +layout: overview +slug: langchain-core/langchain_core/prompts/dict +title: langchain_core.prompts.dict +--- + +Dictionary prompt template. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`DictPromptTemplate`](#langchain_core-prompts-dict-DictPromptTemplate) | Template represented by a dictionary. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_input_variables`](#langchain_core-prompts-dict-_get_input_variables) | - | +| [`_insert_input_variables`](#langchain_core-prompts-dict-_insert_input_variables) | - | + +### API + + + + + +```python +class langchain_core.prompts.dict.DictPromptTemplate() +``` + + + + + + +**Bases:** [RunnableSerializable[dict, dict]](/langchain-core/langchain_core/runnables/base#langchain_core-runnables-base-RunnableSerializable) + +Template represented by a dictionary. + +Recognizes variables in f-string or mustache formatted string dict values. + +Does NOT recognize variables in dict keys. Applies recursively. + + + + + + + + + +Template input variables. + + + + + + + + + + + +```python +langchain_core.prompts.dict.DictPromptTemplate.aformat( + kwargs: typing.Any = {} +) -> langchain_core.prompts.dict[str, typing.Any] +``` + + + + + + +async + +Format the prompt with the inputs. + +**Returns:** `dict[str, Any]` + +A formatted dict. + + + + + + + +```python +langchain_core.prompts.dict.DictPromptTemplate.format( + kwargs: typing.Any = {} +) -> langchain_core.prompts.dict[str, typing.Any] +``` + + + + + + +Format the prompt with the inputs. + +**Returns:** `dict[str, Any]` + +A formatted dict. + + + + + + + +```python +langchain_core.prompts.dict.DictPromptTemplate.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain_core", "prompts", "dict"]` + + + + + + + +```python +langchain_core.prompts.dict.DictPromptTemplate.invoke( + input: langchain_core.prompts.dict, + config: langchain_core.runnables.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.prompts.dict +``` + + + + + + + + + + + + +```python +langchain_core.prompts.dict.DictPromptTemplate.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `True` as this class is serializable. + + + + + + + +```python +langchain_core.prompts.dict.DictPromptTemplate.pretty_repr( + html: bool = False +) -> str +``` + + + + + + +Human-readable representation. + +**Parameters:** + + +Whether to format as HTML. + + +**Returns:** `str` + +Human-readable representation. + + + + + + + + + +```python +langchain_core.prompts.dict._get_input_variables( + template: langchain_core.prompts.dict, + template_format: typing.Literal['f-string', 'mustache'] +) -> list[str] +``` + + + + + + + + + + + + + +```python +langchain_core.prompts.dict._insert_input_variables( + template: langchain_core.prompts.dict[str, typing.Any], + inputs: langchain_core.prompts.dict[str, typing.Any], + template_format: typing.Literal['f-string', 'mustache'] +) -> langchain_core.prompts.dict[str, typing.Any] +``` + + + + + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/few_shot.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/few_shot.mdx new file mode 100644 index 0000000..f2741a5 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/few_shot.mdx @@ -0,0 +1,671 @@ +--- +layout: overview +slug: langchain-core/langchain_core/prompts/few_shot +title: langchain_core.prompts.few_shot +--- + +Prompt template that contains few shot examples. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`FewShotChatMessagePromptTemplate`](#langchain_core-prompts-few_shot-FewShotChatMessagePromptTemplate) | Chat prompt template that supports few-shot examples. | +| [`FewShotPromptTemplate`](#langchain_core-prompts-few_shot-FewShotPromptTemplate) | Prompt template that contains few shot examples. | +| [`_FewShotPromptTemplateMixin`](#langchain_core-prompts-few_shot-_FewShotPromptTemplateMixin) | Prompt template that contains few shot examples. | + +### API + + + + + +```python +class langchain_core.prompts.few_shot.FewShotChatMessagePromptTemplate() +``` + + + + + + +**Bases:** [BaseChatPromptTemplate](/langchain-core/langchain_core/prompts/chat#langchain_core-prompts-chat-BaseChatPromptTemplate), [_FewShotPromptTemplateMixin](#langchain_core-prompts-few_shot-_FewShotPromptTemplateMixin) + +Chat prompt template that supports few-shot examples. + +The high level structure of produced by this prompt template is a list of messages +consisting of prefix message(s), example message(s), and suffix message(s). + +This structure enables creating a conversation with intermediate examples like: + + + +```python +System: You are a helpful AI Assistant + +Human: What is 2+2? + +AI: 4 + +Human: What is 2+3? + +AI: 5 + +Human: What is 4+4? +``` + + + +This prompt template can be used to generate a fixed list of examples or else to +dynamically select examples based on the input. + +**Examples:** + + + +```python +Prompt template with a fixed list of examples (matching the sample +conversation above): + +```python +from langchain_core.prompts import ( + FewShotChatMessagePromptTemplate, + ChatPromptTemplate, +) + +examples = [ + {"input": "2+2", "output": "4"}, + {"input": "2+3", "output": "5"}, +] + +example_prompt = ChatPromptTemplate.from_messages( + [ + ("human", "What is {input}?"), + ("ai", "{output}"), + ] +) + +few_shot_prompt = FewShotChatMessagePromptTemplate( + examples=examples, + # This is a prompt template used to format each individual example. + example_prompt=example_prompt, +) + +final_prompt = ChatPromptTemplate.from_messages( + [ + ("system", "You are a helpful AI Assistant"), + few_shot_prompt, + ("human", "{input}"), + ] +) +final_prompt.format(input="What is 4+4?") +``` + +Prompt template with dynamically selected examples: + +```python +from langchain_core.prompts import SemanticSimilarityExampleSelector +from langchain_core.embeddings import OpenAIEmbeddings +from langchain_core.vectorstores import Chroma + +examples = [ + {"input": "2+2", "output": "4"}, + {"input": "2+3", "output": "5"}, + {"input": "2+4", "output": "6"}, + # ... +] + +to_vectorize = [" ".join(example.values()) for example in examples] +embeddings = OpenAIEmbeddings() +vectorstore = Chroma.from_texts(to_vectorize, embeddings, metadatas=examples) +example_selector = SemanticSimilarityExampleSelector(vectorstore=vectorstore) + +from langchain_core import SystemMessage +from langchain_core.prompts import HumanMessagePromptTemplate +from langchain_core.prompts.few_shot import FewShotChatMessagePromptTemplate + +few_shot_prompt = FewShotChatMessagePromptTemplate( + # Which variable(s) will be passed to the example selector. + input_variables=["input"], + example_selector=example_selector, + # Define how each example will be formatted. + # In this case, each example will become 2 messages: + # 1 human, and 1 AI + example_prompt=( + HumanMessagePromptTemplate.from_template("{input}") + + AIMessagePromptTemplate.from_template("{output}") + ), +) +# Define the overall prompt. +final_prompt = ( + SystemMessagePromptTemplate.from_template("You are a helpful AI Assistant") + + few_shot_prompt + + HumanMessagePromptTemplate.from_template("{input}") +) +# Show the prompt +print(final_prompt.format_messages(input="What's 3+3?")) # noqa: T201 + +# Use within an LLM +from langchain_core.chat_models import ChatAnthropic + +chain = final_prompt | ChatAnthropic(model="claude-3-haiku-20240307") +chain.invoke({"input": "What's 3+3?"}) +``` + + + + + +The class to format each example. + + + +A list of the names of the variables the prompt template will use to pass to +the `example_selector`, if provided. + + + + + + + + +```python +langchain_core.prompts.few_shot.FewShotChatMessagePromptTemplate.aformat( + kwargs: typing.Any = {} +) -> str +``` + + + + + + +async + +Async format the prompt with inputs generating a string. + +Use this method to generate a string representation of a prompt consisting of +chat messages. + +Useful for feeding into a string-based completion language model or debugging. + +**Parameters:** + + +Keyword arguments to use for formatting. + + +**Returns:** `str` + +A string representation of the prompt + + + + + + + +```python +langchain_core.prompts.few_shot.FewShotChatMessagePromptTemplate.aformat_messages( + kwargs: typing.Any = {} +) -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +async + +Async format kwargs into a list of messages. + +**Parameters:** + + +Keyword arguments to use for filling in templates in messages. + + +**Returns:** `list[BaseMessage]` + +A list of formatted messages with all template variables filled in. + + + + + + + +```python +langchain_core.prompts.few_shot.FewShotChatMessagePromptTemplate.format( + kwargs: typing.Any = {} +) -> str +``` + + + + + + +Format the prompt with inputs generating a string. + +Use this method to generate a string representation of a prompt consisting of +chat messages. + +Useful for feeding into a string-based completion language model or debugging. + +**Parameters:** + + +Keyword arguments to use for formatting. + + +**Returns:** `str` + +A string representation of the prompt + + + + + + + +```python +langchain_core.prompts.few_shot.FewShotChatMessagePromptTemplate.format_messages( + kwargs: typing.Any = {} +) -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +Format kwargs into a list of messages. + +**Parameters:** + + +Keyword arguments to use for filling in templates in messages. + + +**Returns:** `list[BaseMessage]` + +A list of formatted messages with all template variables filled in. + + + + + + + +```python +langchain_core.prompts.few_shot.FewShotChatMessagePromptTemplate.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `False` as this class is not serializable. + + + + + + + +```python +langchain_core.prompts.few_shot.FewShotChatMessagePromptTemplate.pretty_repr( + html: bool = False +) -> str +``` + + + + + + +Return a pretty representation of the prompt template. + +**Parameters:** + + +Whether or not to return an HTML formatted string. + + +**Returns:** `str` + +A pretty representation of the prompt template. + + + + + + + + + +```python +class langchain_core.prompts.few_shot.FewShotPromptTemplate( + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [_FewShotPromptTemplateMixin](#langchain_core-prompts-few_shot-_FewShotPromptTemplateMixin), [StringPromptTemplate](/langchain-core/langchain_core/prompts/string#langchain_core-prompts-string-StringPromptTemplate) + +Prompt template that contains few shot examples. + + + +Return the prompt type key. + + + +`PromptTemplate` used to format an individual example. + + + +String separator used to join the prefix, the examples, and suffix. + + + + + + +A prompt template string to put before the examples. + + + +A prompt template string to put after the examples. + + + +The format of the prompt template. + +Options are: `'f-string'`, `'jinja2'`. + + + +Whether or not to try validating the template. + + + + + +```python +langchain_core.prompts.few_shot.FewShotPromptTemplate.aformat( + kwargs: typing.Any = {} +) -> str +``` + + + + + + +async + +Async format the prompt with inputs generating a string. + +Use this method to generate a string representation of a prompt. + +**Parameters:** + + +Keyword arguments to use for formatting. + + +**Returns:** `str` + +A string representation of the prompt. + + + + + + + +```python +langchain_core.prompts.few_shot.FewShotPromptTemplate.format( + kwargs: typing.Any = {} +) -> str +``` + + + + + + +Format the prompt with inputs generating a string. + +Use this method to generate a string representation of a prompt. + +**Parameters:** + + +Keyword arguments to use for formatting. + + +**Returns:** `str` + +A string representation of the prompt. + + + + + + + +```python +langchain_core.prompts.few_shot.FewShotPromptTemplate.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `False` as this class is not serializable. + + + + + + + +```python +langchain_core.prompts.few_shot.FewShotPromptTemplate.save( + file_path: pathlib.Path | str +) -> None +``` + + + + + + +Save the prompt template to a file. + +**Parameters:** + + +The path to save the prompt template to. + + +**Raises:** + +- `ValueError`: If `example_selector` is provided. + + + + + + + +```python +langchain_core.prompts.few_shot.FewShotPromptTemplate.template_is_valid() -> typing_extensions.Self +``` + + + + + + +Check that prefix, suffix, and input variables are consistent. + + + + + + + + + +```python +class langchain_core.prompts.few_shot._FewShotPromptTemplateMixin() +``` + + + + + + +**Bases:** `BaseModel` + +Prompt template that contains few shot examples. + + + +`ExampleSelector` to choose the examples to format into the prompt. + +Either this or `examples` should be provided. + + + +Examples to format into the prompt. + +Either this or `example_selector` should be provided. + + + + + + + + +```python +langchain_core.prompts.few_shot._FewShotPromptTemplateMixin._aget_examples( + kwargs: typing.Any = {} +) -> list[langchain_core.prompts.dict] +``` + + + + + + +async + +Async get the examples to use for formatting the prompt. + +**Parameters:** + + +Keyword arguments to be passed to the example selector. + + +**Returns:** `list[dict]` + +List of examples. + +**Raises:** + +- `ValueError`: If neither `examples` nor `example_selector` are provided. + + + + + + + +```python +langchain_core.prompts.few_shot._FewShotPromptTemplateMixin._get_examples( + kwargs: typing.Any = {} +) -> list[langchain_core.prompts.dict] +``` + + + + + + +Get the examples to use for formatting the prompt. + +**Parameters:** + + +Keyword arguments to be passed to the example selector. + + +**Returns:** `list[dict]` + +List of examples. + +**Raises:** + +- `ValueError`: If neither `examples` nor `example_selector` are provided. + + + + + + + +```python +langchain_core.prompts.few_shot._FewShotPromptTemplateMixin.check_examples_and_selector( + values: langchain_core.prompts.dict +) -> typing.Any +``` + + + + + + +classmethod + +Check that one and only one of `examples`/`example_selector` are provided. + +**Parameters:** + + +The values to check. + + +**Returns:** `Any` + +The values if they are valid. + +**Raises:** + +- `ValueError`: If neither or both `examples` and `example_selector` are +provided. +- `ValueError`: If both `examples` and `example_selector` are provided. + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/few_shot_with_templates.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/few_shot_with_templates.mdx new file mode 100644 index 0000000..3946733 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/few_shot_with_templates.mdx @@ -0,0 +1,267 @@ +--- +layout: overview +slug: langchain-core/langchain_core/prompts/few_shot_with_templates +title: langchain_core.prompts.few_shot_with_templates +--- + +Prompt template that contains few shot examples. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`FewShotPromptWithTemplates`](#langchain_core-prompts-few_shot_with_templates-FewShotPromptWithTemplates) | Prompt template that contains few shot examples. | + +### API + + + + + +```python +class langchain_core.prompts.few_shot_with_templates.FewShotPromptWithTemplates() +``` + + + + + + +**Bases:** [StringPromptTemplate](/langchain-core/langchain_core/prompts/string#langchain_core-prompts-string-StringPromptTemplate) + +Prompt template that contains few shot examples. + + + +Return the prompt type key. + + + +`PromptTemplate` used to format an individual example. + + + +`ExampleSelector` to choose the examples to format into the prompt. + +Either this or `examples` should be provided. + + + +String separator used to join the prefix, the examples, and suffix. + + + +Examples to format into the prompt. + +Either this or `example_selector` should be provided. + + + + + + +A `PromptTemplate` to put before the examples. + + + +A `PromptTemplate` to put after the examples. + + + +The format of the prompt template. + +Options are: `'f-string'`, `'jinja2'`, `'mustache'`. + + + +Whether or not to try validating the template. + + + + + +```python +langchain_core.prompts.few_shot_with_templates.FewShotPromptWithTemplates._aget_examples( + kwargs: typing.Any = {} +) -> list[langchain_core.prompts.dict] +``` + + + + + + +async + + + + + + + +```python +langchain_core.prompts.few_shot_with_templates.FewShotPromptWithTemplates._get_examples( + kwargs: typing.Any = {} +) -> list[langchain_core.prompts.dict] +``` + + + + + + + + + + + + +```python +langchain_core.prompts.few_shot_with_templates.FewShotPromptWithTemplates.aformat( + kwargs: typing.Any = {} +) -> str +``` + + + + + + +async + +Async format the prompt with the inputs. + +**Parameters:** + + +Any arguments to be passed to the prompt template. + + +**Returns:** `str` + +A formatted string. + + + + + + + +```python +langchain_core.prompts.few_shot_with_templates.FewShotPromptWithTemplates.check_examples_and_selector( + values: langchain_core.prompts.dict +) -> typing.Any +``` + + + + + + +classmethod + +Check that one and only one of examples/example_selector are provided. + + + + + + + +```python +langchain_core.prompts.few_shot_with_templates.FewShotPromptWithTemplates.format( + kwargs: typing.Any = {} +) -> str +``` + + + + + + +Format the prompt with the inputs. + +**Parameters:** + + +Any arguments to be passed to the prompt template. + + +**Returns:** `str` + +A formatted string. + + + + + + + +```python +langchain_core.prompts.few_shot_with_templates.FewShotPromptWithTemplates.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "prompts", "few_shot_with_templates"]` + + + + + + + +```python +langchain_core.prompts.few_shot_with_templates.FewShotPromptWithTemplates.save( + file_path: pathlib.Path | str +) -> None +``` + + + + + + +Save the prompt to a file. + +**Parameters:** + + +The path to save the prompt to. + + +**Raises:** + +- `ValueError`: If `example_selector` is provided. + + + + + + + +```python +langchain_core.prompts.few_shot_with_templates.FewShotPromptWithTemplates.template_is_valid() -> typing_extensions.Self +``` + + + + + + +Check that prefix, suffix, and input variables are consistent. + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/image.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/image.mdx new file mode 100644 index 0000000..247e2b4 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/image.mdx @@ -0,0 +1,230 @@ +--- +layout: overview +slug: langchain-core/langchain_core/prompts/image +title: langchain_core.prompts.image +--- + +Image prompt template for a multimodal model. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ImagePromptTemplate`](#langchain_core-prompts-image-ImagePromptTemplate) | Image prompt template for a multimodal model. | + +### API + + + + + +```python +class langchain_core.prompts.image.ImagePromptTemplate( + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [BasePromptTemplate[ImageURL]](/langchain-core/langchain_core/prompts/base#langchain_core-prompts-base-BasePromptTemplate) + +Image prompt template for a multimodal model. + + + +Return the prompt type key. + + + +Template for the prompt. + + + +The format of the prompt template. + +Options are: `'f-string'`, `'mustache'`, `'jinja2'`. + + + + + +```python +langchain_core.prompts.image.ImagePromptTemplate.aformat( + kwargs: typing.Any = {} +) -> langchain_core.prompt_values.ImageURL +``` + + + + + + +async + +Async format the prompt with the inputs. + +**Parameters:** + + +Any arguments to be passed to the prompt template. + + +**Returns:** `ImageURL` + +A formatted string. + + + + + + + +```python +langchain_core.prompts.image.ImagePromptTemplate.aformat_prompt( + kwargs: typing.Any = {} +) -> langchain_core.prompt_values.PromptValue +``` + + + + + + +async + +Async format the prompt with the inputs. + +**Parameters:** + + +Any arguments to be passed to the prompt template. + + +**Returns:** `PromptValue` + +A formatted string. + + + + + + + +```python +langchain_core.prompts.image.ImagePromptTemplate.format( + kwargs: typing.Any = {} +) -> langchain_core.prompt_values.ImageURL +``` + + + + + + +Format the prompt with the inputs. + +**Parameters:** + + +Any arguments to be passed to the prompt template. + + +**Returns:** `ImageURL` + +A formatted string. + +**Raises:** + +- `ValueError`: If the url is not provided. +- `ValueError`: If the url is not a string. +- `ValueError`: If `'path'` is provided in the template or kwargs. + + + + + + + +```python +langchain_core.prompts.image.ImagePromptTemplate.format_prompt( + kwargs: typing.Any = {} +) -> langchain_core.prompt_values.PromptValue +``` + + + + + + +Format the prompt with the inputs. + +**Parameters:** + + +Any arguments to be passed to the prompt template. + + +**Returns:** `PromptValue` + +A formatted string. + + + + + + + +```python +langchain_core.prompts.image.ImagePromptTemplate.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "prompts", "image"]` + + + + + + + +```python +langchain_core.prompts.image.ImagePromptTemplate.pretty_repr( + html: bool = False +) -> str +``` + + + + + + +Return a pretty representation of the prompt. + +**Parameters:** + + +Whether to return an html formatted string. + + +**Returns:** `str` + +A pretty representation of the prompt. + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/loading.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/loading.mdx new file mode 100644 index 0000000..4f5b347 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/loading.mdx @@ -0,0 +1,284 @@ +--- +layout: overview +slug: langchain-core/langchain_core/prompts/loading +title: langchain_core.prompts.loading +--- + +Load prompts. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_load_chat_prompt`](#langchain_core-prompts-loading-_load_chat_prompt) | Load chat prompt from config. | +| [`_load_examples`](#langchain_core-prompts-loading-_load_examples) | Load examples if necessary. | +| [`_load_few_shot_prompt`](#langchain_core-prompts-loading-_load_few_shot_prompt) | Load the "few shot" prompt from the config. | +| [`_load_output_parser`](#langchain_core-prompts-loading-_load_output_parser) | Load output parser. | +| [`_load_prompt`](#langchain_core-prompts-loading-_load_prompt) | Load the prompt template from config. | +| [`_load_prompt_from_file`](#langchain_core-prompts-loading-_load_prompt_from_file) | Load prompt from file. | +| [`_load_template`](#langchain_core-prompts-loading-_load_template) | Load template from the path if applicable. | +| [`load_prompt`](#langchain_core-prompts-loading-load_prompt) | Unified method for loading a prompt from LangChainHub or local filesystem. | +| [`load_prompt_from_config`](#langchain_core-prompts-loading-load_prompt_from_config) | Load prompt from config dict. | + +### Data + +[`URL_BASE`](#langchain_core-prompts-loading-URL_BASE) + +[`logger`](#langchain_core-prompts-loading-logger) + +[`type_to_loader_dict`](#langchain_core-prompts-loading-type_to_loader_dict) + +### API + + + + + +```python +langchain_core.prompts.loading._load_chat_prompt( + config: langchain_core.prompts.dict +) -> langchain_core.prompts.chat.ChatPromptTemplate +``` + + + + + + +Load chat prompt from config. + + + + + + + + +```python +langchain_core.prompts.loading._load_examples( + config: langchain_core.prompts.dict +) -> langchain_core.prompts.dict +``` + + + + + + +Load examples if necessary. + + + + + + + + +```python +langchain_core.prompts.loading._load_few_shot_prompt( + config: langchain_core.prompts.dict +) -> langchain_core.prompts.few_shot.FewShotPromptTemplate +``` + + + + + + +Load the "few shot" prompt from the config. + + + + + + + + +```python +langchain_core.prompts.loading._load_output_parser( + config: langchain_core.prompts.dict +) -> langchain_core.prompts.dict +``` + + + + + + +Load output parser. + + + + + + + + +```python +langchain_core.prompts.loading._load_prompt( + config: langchain_core.prompts.dict +) -> langchain_core.prompts.prompt.PromptTemplate +``` + + + + + + +Load the prompt template from config. + + + + + + + + +```python +langchain_core.prompts.loading._load_prompt_from_file( + file: str | pathlib.Path, + encoding: str | None = None +) -> langchain_core.prompts.base.BasePromptTemplate +``` + + + + + + +Load prompt from file. + + + + + + + + +```python +langchain_core.prompts.loading._load_template( + var_name: str, + config: langchain_core.prompts.dict +) -> langchain_core.prompts.dict +``` + + + + + + +Load template from the path if applicable. + + + + + + + + +```python +langchain_core.prompts.loading.load_prompt( + path: str | pathlib.Path, + encoding: str | None = None +) -> langchain_core.prompts.base.BasePromptTemplate +``` + + + + + + +Unified method for loading a prompt from LangChainHub or local filesystem. + +**Parameters:** + + +Path to the prompt file. + + + +Encoding of the file. + + +**Returns:** `BasePromptTemplate` + +A `PromptTemplate` object. + +**Raises:** + +- `RuntimeError`: If the path is a LangChainHub path. + + + + + + + + +```python +langchain_core.prompts.loading.load_prompt_from_config( + config: langchain_core.prompts.dict +) -> langchain_core.prompts.base.BasePromptTemplate +``` + + + + + + +Load prompt from config dict. + +**Parameters:** + + +Dict containing the prompt configuration. + + +**Returns:** `BasePromptTemplate` + +A `PromptTemplate` object. + +**Raises:** + +- `ValueError`: If the prompt type is not supported. + + + + + + + + +```python +langchain_core.prompts.loading.URL_BASE = 'https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/' +``` + + + + + + + + + +```python +langchain_core.prompts.loading.logger = logging.getLogger(__name__) +``` + + + + + + + + + +```python +langchain_core.prompts.loading.type_to_loader_dict: dict[str, Callable[[dict], BasePromptTemplate]] = {'prompt': _load_prompt, 'few_shot': _load_few_shot_prompt, 'chat': _load_chat_p... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/message.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/message.mdx new file mode 100644 index 0000000..7dc2084 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/message.mdx @@ -0,0 +1,223 @@ +--- +layout: overview +slug: langchain-core/langchain_core/prompts/message +title: langchain_core.prompts.message +--- + +Message prompt templates. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseMessagePromptTemplate`](#langchain_core-prompts-message-BaseMessagePromptTemplate) | Base class for message prompt templates. | + +### API + + + + + +```python +class langchain_core.prompts.message.BaseMessagePromptTemplate() +``` + + + + + + +Abstract + +**Bases:** [Serializable](/langchain-core/langchain_core/load/serializable#langchain_core-load-serializable-Serializable) + +Base class for message prompt templates. + + + +Input variables for this prompt template. + + + + + +```python +langchain_core.prompts.message.BaseMessagePromptTemplate.__add__( + other: typing.Any +) -> langchain_core.prompts.chat.ChatPromptTemplate +``` + + + + + + +Combine two prompt templates. + +**Parameters:** + + +Another prompt template. + + +**Returns:** `ChatPromptTemplate` + +Combined prompt template. + + + + + + + +```python +langchain_core.prompts.message.BaseMessagePromptTemplate.aformat_messages( + kwargs: typing.Any = {} +) -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +async + +Async format messages from kwargs. + +**Parameters:** + + +Keyword arguments to use for formatting. + + +**Returns:** `list[BaseMessage]` + +List of `BaseMessage` objects. + + + + + + + +```python +langchain_core.prompts.message.BaseMessagePromptTemplate.format_messages( + kwargs: typing.Any = {} +) -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +abstract + +Format messages from kwargs. + +Should return a list of `BaseMessage` objects. + +**Parameters:** + + +Keyword arguments to use for formatting. + + +**Returns:** `list[BaseMessage]` + +List of `BaseMessage` objects. + + + + + + + +```python +langchain_core.prompts.message.BaseMessagePromptTemplate.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "prompts", "chat"]` + + + + + + + +```python +langchain_core.prompts.message.BaseMessagePromptTemplate.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `True` as this class is serializable. + + + + + + + +```python +langchain_core.prompts.message.BaseMessagePromptTemplate.pretty_print() -> None +``` + + + + + + +Print a human-readable representation. + + + + + + + +```python +langchain_core.prompts.message.BaseMessagePromptTemplate.pretty_repr( + html: bool = False +) -> str +``` + + + + + + +Human-readable representation. + +**Parameters:** + + +Whether to format as HTML. + + +**Returns:** `str` + +Human-readable representation. + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/prompt.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/prompt.mdx new file mode 100644 index 0000000..34091d0 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/prompt.mdx @@ -0,0 +1,374 @@ +--- +layout: overview +slug: langchain-core/langchain_core/prompts/prompt +title: langchain_core.prompts.prompt +--- + +Prompt schema definition. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`PromptTemplate`](#langchain_core-prompts-prompt-PromptTemplate) | Prompt template for a language model. | + +### API + + + + + +```python +class langchain_core.prompts.prompt.PromptTemplate() +``` + + + + + + +**Bases:** [StringPromptTemplate](/langchain-core/langchain_core/prompts/string#langchain_core-prompts-string-StringPromptTemplate) + +Prompt template for a language model. + +A prompt template consists of a string template. It accepts a set of parameters +from the user that can be used to generate a prompt for a language model. + +The template can be formatted using either f-strings (default), jinja2, or mustache +syntax. + +!!! warning "Security" + + Prefer using `template_format='f-string'` instead of `template_format='jinja2'`, + or make sure to NEVER accept jinja2 templates from untrusted sources as they may + lead to arbitrary Python code execution. + + As of LangChain 0.0.329, Jinja2 templates will be rendered using Jinja2's + SandboxedEnvironment by default. This sand-boxing should be treated as a + best-effort approach rather than a guarantee of security, as it is an opt-out + rather than opt-in approach. + + Despite the sandboxing, we recommend to never use jinja2 templates from + untrusted sources. + + + +Return the prompt type key. + + + + + + +The prompt template. + + + +The format of the prompt template. + +Options are: `'f-string'`, `'mustache'`, `'jinja2'`. + + + +Whether or not to try validating the template. + + + + + +```python +langchain_core.prompts.prompt.PromptTemplate.__add__( + other: typing.Any +) -> langchain_core.prompts.prompt.PromptTemplate +``` + + + + + + +Override the `+` operator to allow for combining prompt templates. + +**Returns:** `PromptTemplate` + +A new `PromptTemplate` that is the combination of the two. + +**Raises:** + +- `ValueError`: If the template formats are not f-string or if there are +conflicting partial variables. +- `NotImplementedError`: If the other object is not a `PromptTemplate` or str. + + + + + + + +```python +langchain_core.prompts.prompt.PromptTemplate.format( + kwargs: typing.Any = {} +) -> str +``` + + + + + + +Format the prompt with the inputs. + +**Parameters:** + + +Any arguments to be passed to the prompt template. + + +**Returns:** `str` + +A formatted string. + + + + + + + +```python +langchain_core.prompts.prompt.PromptTemplate.from_examples( + examples: list[str], + suffix: str, + input_variables: list[str], + example_separator: str = '\n\n', + prefix: str = '', + kwargs: typing.Any = {} +) -> langchain_core.prompts.prompt.PromptTemplate +``` + + + + + + +classmethod + +Take examples in list format with prefix and suffix to create a prompt. + +Intended to be used as a way to dynamically create a prompt from examples. + +**Parameters:** + + +List of examples to use in the prompt. + + + +String to go after the list of examples. + +Should generally set up the user's input. + + + +A list of variable names the final prompt template will +expect. + + + +The separator to use in between examples. + + + +String that should go before any examples. + +Generally includes examples. + + +**Returns:** `PromptTemplate` + +The final prompt generated. + + + + + + + +```python +langchain_core.prompts.prompt.PromptTemplate.from_file( + template_file: str | pathlib.Path, + encoding: str | None = None, + kwargs: typing.Any = {} +) -> langchain_core.prompts.prompt.PromptTemplate +``` + + + + + + +classmethod + +Load a prompt from a file. + +**Parameters:** + + +The path to the file containing the prompt template. + + + +The encoding system for opening the template file. + +If not provided, will use the OS default. + + +**Returns:** `PromptTemplate` + +The prompt loaded from the file. + + + + + + + +```python +langchain_core.prompts.prompt.PromptTemplate.from_template( + template: str, + template_format: langchain_core.prompts.string.PromptTemplateFormat = 'f-string', + partial_variables: langchain_core.prompts.dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.prompts.prompt.PromptTemplate +``` + + + + + + +classmethod + +Load a prompt template from a template. + +!!! warning "Security" + + Prefer using `template_format='f-string'` instead of + `template_format='jinja2'`, or make sure to NEVER accept jinja2 templates + from untrusted sources as they may lead to arbitrary Python code execution. + + As of LangChain 0.0.329, Jinja2 templates will be rendered using Jinja2's + SandboxedEnvironment by default. This sand-boxing should be treated as a + best-effort approach rather than a guarantee of security, as it is an + opt-out rather than opt-in approach. + + Despite the sandboxing, we recommend to never use jinja2 templates from + untrusted sources. + +**Parameters:** + + +The template to load. + + + +The format of the template. + +Use `jinja2` for jinja2, `mustache` for mustache, and `f-string` for +f-strings. + + + +A dictionary of variables that can be used to partially +fill in the template. + +For example, if the template is `'{variable1} {variable2}'`, and +`partial_variables` is `{"variable1": "foo"}`, then the final prompt +will be `'foo {variable2}'`. + + + +Any other arguments to pass to the prompt template. + + +**Returns:** `PromptTemplate` + +The prompt template loaded from the template. + + + + + + + +```python +langchain_core.prompts.prompt.PromptTemplate.get_input_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + +Get the input schema for the prompt. + +**Parameters:** + + +The runnable configuration. + + +**Returns:** `type[BaseModel]` + +The input schema for the prompt. + + + + + + + +```python +langchain_core.prompts.prompt.PromptTemplate.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "prompts", "prompt"]` + + + + + + + +```python +langchain_core.prompts.prompt.PromptTemplate.pre_init_validation( + values: langchain_core.prompts.dict +) -> typing.Any +``` + + + + + + +classmethod + +Check that template and input variables are consistent. + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/string.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/string.mdx new file mode 100644 index 0000000..53805e5 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/string.mdx @@ -0,0 +1,597 @@ +--- +layout: overview +slug: langchain-core/langchain_core/prompts/string +title: langchain_core.prompts.string +--- + +`BasePrompt` schema definition. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`StringPromptTemplate`](#langchain_core-prompts-string-StringPromptTemplate) | String prompt that exposes the format method, returning a prompt. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_create_model_recursive`](#langchain_core-prompts-string-_create_model_recursive) | - | +| [`_get_jinja2_variables_from_template`](#langchain_core-prompts-string-_get_jinja2_variables_from_template) | - | +| [`check_valid_template`](#langchain_core-prompts-string-check_valid_template) | Check that template string is valid. | +| [`get_template_variables`](#langchain_core-prompts-string-get_template_variables) | Get the variables from the template. | +| [`is_subsequence`](#langchain_core-prompts-string-is_subsequence) | Return `True` if child is subsequence of parent. | +| [`jinja2_formatter`](#langchain_core-prompts-string-jinja2_formatter) | Format a template using jinja2. | +| [`mustache_formatter`](#langchain_core-prompts-string-mustache_formatter) | Format a template using mustache. | +| [`mustache_schema`](#langchain_core-prompts-string-mustache_schema) | Get the variables from a mustache template. | +| [`mustache_template_vars`](#langchain_core-prompts-string-mustache_template_vars) | Get the top-level variables from a mustache template. | +| [`validate_jinja2`](#langchain_core-prompts-string-validate_jinja2) | Validate that the input variables are valid for the template. | + +### Data + +[`DEFAULT_FORMATTER_MAPPING`](#langchain_core-prompts-string-DEFAULT_FORMATTER_MAPPING) + +[`DEFAULT_VALIDATOR_MAPPING`](#langchain_core-prompts-string-DEFAULT_VALIDATOR_MAPPING) + +[`Defs`](#langchain_core-prompts-string-Defs) + +[`PromptTemplateFormat`](#langchain_core-prompts-string-PromptTemplateFormat) + +[`_HAS_JINJA2`](#langchain_core-prompts-string-_HAS_JINJA2) + +### API + + + + + +```python +class langchain_core.prompts.string.StringPromptTemplate() +``` + + + + + + +Abstract + +**Bases:** [BasePromptTemplate](/langchain-core/langchain_core/prompts/base#langchain_core-prompts-base-BasePromptTemplate) + +String prompt that exposes the format method, returning a prompt. + + + + + + +```python +langchain_core.prompts.string.StringPromptTemplate.aformat_prompt( + kwargs: typing.Any = {} +) -> langchain_core.prompt_values.PromptValue +``` + + + + + + +async + +Async format the prompt with the inputs. + +**Parameters:** + + +Any arguments to be passed to the prompt template. + + +**Returns:** `PromptValue` + +A formatted string. + + + + + + + +```python +langchain_core.prompts.string.StringPromptTemplate.format( + kwargs: typing.Any = {} +) -> str +``` + + + + + + +abstract + + + + + + + +```python +langchain_core.prompts.string.StringPromptTemplate.format_prompt( + kwargs: typing.Any = {} +) -> langchain_core.prompt_values.PromptValue +``` + + + + + + +Format the prompt with the inputs. + +**Parameters:** + + +Any arguments to be passed to the prompt template. + + +**Returns:** `PromptValue` + +A formatted string. + + + + + + + +```python +langchain_core.prompts.string.StringPromptTemplate.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "prompts", "base"]` + + + + + + + +```python +langchain_core.prompts.string.StringPromptTemplate.pretty_print() -> None +``` + + + + + + +Print a pretty representation of the prompt. + + + + + + + +```python +langchain_core.prompts.string.StringPromptTemplate.pretty_repr( + html: bool = False +) -> str +``` + + + + + + +Get a pretty representation of the prompt. + +**Parameters:** + + +Whether to return an HTML-formatted string. + + +**Returns:** `str` + +A pretty representation of the prompt. + + + + + + + + + +```python +langchain_core.prompts.string._create_model_recursive( + name: str, + defs: langchain_core.prompts.string.Defs +) -> type[pydantic.BaseModel] +``` + + + + + + + + + + + + + +```python +langchain_core.prompts.string._get_jinja2_variables_from_template( + template: str +) -> set[str] +``` + + + + + + + + + + + + + +```python +langchain_core.prompts.string.check_valid_template( + template: str, + template_format: str, + input_variables: list[str] +) -> None +``` + + + + + + +Check that template string is valid. + +**Parameters:** + + +The template string. + + + +The template format. + +Should be one of `'f-string'` or `'jinja2'`. + + + +The input variables. + + +**Raises:** + +- `ValueError`: If the template format is not supported. +- `ValueError`: If the prompt schema is invalid. + + + + + + + + +```python +langchain_core.prompts.string.get_template_variables( + template: str, + template_format: str +) -> list[str] +``` + + + + + + +Get the variables from the template. + +**Parameters:** + + +The template string. + + + +The template format. + +Should be one of `'f-string'`, `'mustache'` or `'jinja2'`. + + +**Returns:** `list[str]` + +The variables from the template. + +**Raises:** + +- `ValueError`: If the template format is not supported. + + + + + + + + +```python +langchain_core.prompts.string.is_subsequence( + child: collections.abc.Sequence, + parent: collections.abc.Sequence +) -> bool +``` + + + + + + +Return `True` if child is subsequence of parent. + + + + + + + + +```python +langchain_core.prompts.string.jinja2_formatter( + template: str, + kwargs: typing.Any = {} +) -> str +``` + + + + + + +Format a template using jinja2. + +!!! warning "Security" + + As of LangChain 0.0.329, this method uses Jinja2's `SandboxedEnvironment` by + default. However, this sandboxing should be treated as a best-effort approach + rather than a guarantee of security. + + Do not accept jinja2 templates from untrusted sources as they may lead + to arbitrary Python code execution. + + [More information.](https://jinja.palletsprojects.com/en/3.1.x/sandbox/) + +**Parameters:** + + +The template string. + + + +The variables to format the template with. + + +**Returns:** `str` + +The formatted string. + +**Raises:** + +- `ImportError`: If jinja2 is not installed. + + + + + + + + +```python +langchain_core.prompts.string.mustache_formatter( + template: str, + kwargs: typing.Any = {} +) -> str +``` + + + + + + +Format a template using mustache. + +**Parameters:** + + +The template string. + + + +The variables to format the template with. + + +**Returns:** `str` + +The formatted string. + + + + + + + + +```python +langchain_core.prompts.string.mustache_schema( + template: str +) -> type[pydantic.BaseModel] +``` + + + + + + +Get the variables from a mustache template. + +**Parameters:** + + +The template string. + + +**Returns:** `type[BaseModel]` + +The variables from the template as a Pydantic model. + + + + + + + + +```python +langchain_core.prompts.string.mustache_template_vars( + template: str +) -> set[str] +``` + + + + + + +Get the top-level variables from a mustache template. + +For nested variables like `{{person.name}}`, only the top-level key (`person`) is +returned. + +**Parameters:** + + +The template string. + + +**Returns:** `set[str]` + +The top-level variables from the template. + + + + + + + + +```python +langchain_core.prompts.string.validate_jinja2( + template: str, + input_variables: list[str] +) -> None +``` + + + + + + +Validate that the input variables are valid for the template. + +Issues a warning if missing or extra variables are found. + +**Parameters:** + + +The template string. + + + +The input variables. + + + + + + + + + +```python +langchain_core.prompts.string.DEFAULT_FORMATTER_MAPPING: dict[str, Callable[..., str]] = {'f-string': formatter.format, 'mustache': mustache_formatter, 'jinja2': jinja2_... +``` + + + + + + + + + +```python +langchain_core.prompts.string.DEFAULT_VALIDATOR_MAPPING: dict[str, Callable] = {'f-string': formatter.validate_input_variables, 'jinja2': validate_jinja2} +``` + + + + + + + + + +```python +langchain_core.prompts.string.Defs = dict[str, 'Defs'] +``` + + + + + + + + + +```python +langchain_core.prompts.string.PromptTemplateFormat = Literal['f-string', 'mustache', 'jinja2'] +``` + + + + + + + + + +```python +langchain_core.prompts.string._HAS_JINJA2 = True +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/structured.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/structured.mdx new file mode 100644 index 0000000..1184130 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/prompts/structured.mdx @@ -0,0 +1,211 @@ +--- +layout: overview +slug: langchain-core/langchain_core/prompts/structured +title: langchain_core.prompts.structured +--- + +Structured prompt template for a language model. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`StructuredPrompt`](#langchain_core-prompts-structured-StructuredPrompt) | Structured prompt template for a language model. | + +### API + + + + + +```python +class langchain_core.prompts.structured.StructuredPrompt( + messages: collections.abc.Sequence[langchain_core.prompts.chat.MessageLikeRepresentation], + schema_: langchain_core.prompts.dict | type[pydantic.BaseModel] | None = None, + structured_output_kwargs: langchain_core.prompts.dict[str, typing.Any] | None = None, + template_format: langchain_core.prompts.string.PromptTemplateFormat = 'f-string', + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [ChatPromptTemplate](/langchain-core/langchain_core/prompts/chat#langchain_core-prompts-chat-ChatPromptTemplate) + +Structured prompt template for a language model. + + + +Schema for the structured prompt. + + + + + + + + +```python +langchain_core.prompts.structured.StructuredPrompt.__or__( + other: langchain_core.runnables.base.Runnable[typing.Any, langchain_core.runnables.base.Other] | collections.abc.Callable[[Iterator[Any]], collections.abc.Iterator[langchain_core.runnables.base.Other]] | collections.abc.Callable[[AsyncIterator[Any]], collections.abc.AsyncIterator[langchain_core.runnables.base.Other]] | collections.abc.Callable[[Any], langchain_core.runnables.base.Other] | collections.abc.Mapping[str, langchain_core.runnables.base.Runnable[typing.Any, langchain_core.runnables.base.Other] | collections.abc.Callable[[Any], langchain_core.runnables.base.Other] | typing.Any] +) -> langchain_core.runnables.base.RunnableSerializable[langchain_core.prompts.dict, langchain_core.runnables.base.Other] +``` + + + + + + + + + + + + +```python +langchain_core.prompts.structured.StructuredPrompt.from_messages_and_schema( + messages: collections.abc.Sequence[langchain_core.prompts.chat.MessageLikeRepresentation], + schema: langchain_core.prompts.dict | type, + kwargs: typing.Any = {} +) -> langchain_core.prompts.chat.ChatPromptTemplate +``` + + + + + + +classmethod + +Create a chat prompt template from a variety of message formats. + +**Parameters:** + + +Sequence of message representations. + +A message can be represented using the following formats: + +1. `BaseMessagePromptTemplate` +2. `BaseMessage` +3. 2-tuple of `(message type, template)`; e.g., + `("human", "{user_input}")` +4. 2-tuple of `(message class, template)` +5. A string which is shorthand for `("human", template)`; e.g., + `"{user_input}"` + + + +A dictionary representation of function call, or a Pydantic model. + + + +Any additional kwargs to pass through to +`ChatModel.with_structured_output(schema, **kwargs)`. + + +**Returns:** `ChatPromptTemplate` + +A structured prompt template + +**Examples:** + + + +```python +Instantiation from a list of message templates: + +```python +from langchain_core.prompts import StructuredPrompt + + +class OutputSchema(BaseModel): + name: str + value: int + + +template = StructuredPrompt( + [ + ("human", "Hello, how are you?"), + ("ai", "I'm doing well, thanks!"), + ("human", "That's good to hear."), + ], + OutputSchema, +) +``` + + + + + + + + + +```python +langchain_core.prompts.structured.StructuredPrompt.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +For example, if the class is `langchain.llms.openai.OpenAI`, then the namespace +is `["langchain", "llms", "openai"]` + +**Returns:** `list[str]` + +The namespace of the LangChain object. + + + + + + + +```python +langchain_core.prompts.structured.StructuredPrompt.pipe( + others: langchain_core.runnables.base.Runnable[typing.Any, langchain_core.runnables.base.Other] | collections.abc.Callable[[Iterator[Any]], collections.abc.Iterator[langchain_core.runnables.base.Other]] | collections.abc.Callable[[AsyncIterator[Any]], collections.abc.AsyncIterator[langchain_core.runnables.base.Other]] | collections.abc.Callable[[Any], langchain_core.runnables.base.Other] | collections.abc.Mapping[str, langchain_core.runnables.base.Runnable[typing.Any, langchain_core.runnables.base.Other] | collections.abc.Callable[[Any], langchain_core.runnables.base.Other] | typing.Any] = (), + name: str | None = None +) -> langchain_core.runnables.base.RunnableSerializable[langchain_core.prompts.dict, langchain_core.runnables.base.Other] +``` + + + + + + +Pipe the structured prompt to a language model. + +**Parameters:** + + +The language model to pipe the structured prompt to. + + + +The name of the pipeline. + + +**Returns:** `RunnableSerializable[dict, Other]` + +A `RunnableSequence` object. + +**Raises:** + +- `NotImplementedError`: If the first element of `others` is not a language +model. + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/rate_limiters.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/rate_limiters.mdx new file mode 100644 index 0000000..535732a --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/rate_limiters.mdx @@ -0,0 +1,300 @@ +--- +layout: overview +slug: langchain-core/langchain_core/rate_limiters +title: langchain_core.rate_limiters +--- + +Interface for a rate limiter and an in-memory rate limiter. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseRateLimiter`](#langchain_core-rate_limiters-BaseRateLimiter) | Base class for rate limiters. | +| [`InMemoryRateLimiter`](#langchain_core-rate_limiters-InMemoryRateLimiter) | An in memory rate limiter based on a token bucket algorithm. | + +### Data + +[`__all__`](#langchain_core-rate_limiters-__all__) + +### API + + + + + +```python +class langchain_core.rate_limiters.BaseRateLimiter() +``` + + + + + + +Abstract + +**Bases:** `ABC` + +Base class for rate limiters. + +Usage of the base limiter is through the acquire and aacquire methods depending +on whether running in a sync or async context. + +Implementations are free to add a timeout parameter to their initialize method +to allow users to specify a timeout for acquiring the necessary tokens when +using a blocking call. + +Current limitations: + +- Rate limiting information is not surfaced in tracing or callbacks. This means + that the total time it takes to invoke a chat model will encompass both + the time spent waiting for tokens and the time spent making the request. + + + + + + +```python +langchain_core.rate_limiters.BaseRateLimiter.aacquire( + blocking: bool = True +) -> bool +``` + + + + + + +async abstract + +Attempt to acquire the necessary tokens for the rate limiter. + +This method blocks until the required tokens are available if `blocking` +is set to `True`. + +If `blocking` is set to `False`, the method will immediately return the result +of the attempt to acquire the tokens. + +**Parameters:** + + +If `True`, the method will block until the tokens are available. +If `False`, the method will return immediately with the result of +the attempt. + + +**Returns:** `bool` + +`True` if the tokens were successfully acquired, `False` otherwise. + + + + + + + +```python +langchain_core.rate_limiters.BaseRateLimiter.acquire( + blocking: bool = True +) -> bool +``` + + + + + + +abstract + +Attempt to acquire the necessary tokens for the rate limiter. + +This method blocks until the required tokens are available if `blocking` +is set to `True`. + +If `blocking` is set to `False`, the method will immediately return the result +of the attempt to acquire the tokens. + +**Parameters:** + + +If `True`, the method will block until the tokens are available. +If `False`, the method will return immediately with the result of +the attempt. + + +**Returns:** `bool` + +`True` if the tokens were successfully acquired, `False` otherwise. + + + + + + + + + +```python +class langchain_core.rate_limiters.InMemoryRateLimiter( + requests_per_second: float = 1, + check_every_n_seconds: float = 0.1, + max_bucket_size: float = 1 +) +``` + + + + + + +**Bases:** [BaseRateLimiter](#langchain_core-rate_limiters-BaseRateLimiter) + +An in memory rate limiter based on a token bucket algorithm. + +This is an in memory rate limiter, so it cannot rate limit across +different processes. + +The rate limiter only allows time-based rate limiting and does not +take into account any information about the input or the output, so it +cannot be used to rate limit based on the size of the request. + +It is thread safe and can be used in either a sync or async context. + +The in memory rate limiter is based on a token bucket. The bucket is filled +with tokens at a given rate. Each request consumes a token. If there are +not enough tokens in the bucket, the request is blocked until there are +enough tokens. + +These tokens have nothing to do with LLM tokens. They are just +a way to keep track of how many requests can be made at a given time. + +Current limitations: + +- The rate limiter is not designed to work across different processes. It is + an in-memory rate limiter, but it is thread safe. +- The rate limiter only supports time-based rate limiting. It does not take + into account the size of the request or any other factors. + + + + + + + + + + + + + + +```python +langchain_core.rate_limiters.InMemoryRateLimiter._consume() -> bool +``` + + + + + + +Try to consume a token. + +**Returns:** `bool` + +True means that the tokens were consumed, and the caller can proceed to + + + + + + + +```python +langchain_core.rate_limiters.InMemoryRateLimiter.aacquire( + blocking: bool = True +) -> bool +``` + + + + + + +async + +Attempt to acquire a token from the rate limiter. Async version. + +This method blocks until the required tokens are available if `blocking` +is set to `True`. + +If `blocking` is set to `False`, the method will immediately return the result +of the attempt to acquire the tokens. + +**Parameters:** + + +If `True`, the method will block until the tokens are available. +If `False`, the method will return immediately with the result of +the attempt. + + +**Returns:** `bool` + +`True` if the tokens were successfully acquired, `False` otherwise. + + + + + + + +```python +langchain_core.rate_limiters.InMemoryRateLimiter.acquire( + blocking: bool = True +) -> bool +``` + + + + + + +Attempt to acquire a token from the rate limiter. + +This method blocks until the required tokens are available if `blocking` +is set to `True`. + +If `blocking` is set to `False`, the method will immediately return the result +of the attempt to acquire the tokens. + +**Parameters:** + + +If `True`, the method will block until the tokens are available. +If `False`, the method will return immediately with the result of +the attempt. + + +**Returns:** `bool` + +`True` if the tokens were successfully acquired, `False` otherwise. + + + + + + + + + +```python +langchain_core.rate_limiters.__all__ = ['BaseRateLimiter', 'InMemoryRateLimiter'] +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/retrievers.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/retrievers.mdx new file mode 100644 index 0000000..cc706db --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/retrievers.mdx @@ -0,0 +1,433 @@ +--- +layout: overview +slug: langchain-core/langchain_core/retrievers +title: langchain_core.retrievers +--- + +**Retriever** class returns `Document` objects given a text **query**. + +It is more general than a vector store. A retriever does not need to be able to +store documents, only to return (or retrieve) it. Vector stores can be used as +the backbone of a retriever, but there are other types of retrievers as well. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseRetriever`](#langchain_core-retrievers-BaseRetriever) | Abstract base class for a document retrieval system. | +| [`LangSmithRetrieverParams`](#langchain_core-retrievers-LangSmithRetrieverParams) | LangSmith parameters for tracing. | + +### Data + +[`RetrieverInput`](#langchain_core-retrievers-RetrieverInput) + +[`RetrieverLike`](#langchain_core-retrievers-RetrieverLike) + +[`RetrieverOutput`](#langchain_core-retrievers-RetrieverOutput) + +[`RetrieverOutputLike`](#langchain_core-retrievers-RetrieverOutputLike) + +### API + + + + + +```python +class langchain_core.retrievers.BaseRetriever() +``` + + + + + + +Abstract + +**Bases:** [RunnableSerializable[RetrieverInput, RetrieverOutput]](/langchain-core/langchain_core/runnables/base#langchain_core-runnables-base-RunnableSerializable) + +Abstract base class for a document retrieval system. + +A retrieval system is defined as something that can take string queries and return +the most 'relevant' documents from some source. + +Usage: + +A retriever follows the standard `Runnable` interface, and should be used via the +standard `Runnable` methods of `invoke`, `ainvoke`, `batch`, `abatch`. + +Implementation: + +When implementing a custom retriever, the class should implement the +`_get_relevant_documents` method to define the logic for retrieving documents. + +Optionally, an async native implementations can be provided by overriding the +`_aget_relevant_documents` method. + +!!! example "Retriever that returns the first 5 documents from a list of documents" + + ```python + from langchain_core.documents import Document + from langchain_core.retrievers import BaseRetriever + + class SimpleRetriever(BaseRetriever): + docs: list[Document] + k: int = 5 + + def _get_relevant_documents(self, query: str) -> list[Document]: + """Return the first k documents from the list of documents""" + return self.docs[:self.k] + + async def _aget_relevant_documents(self, query: str) -> list[Document]: + """(Optional) async native implementation.""" + return self.docs[:self.k] + ``` + +!!! example "Simple retriever based on a scikit-learn vectorizer" + + ```python + from sklearn.metrics.pairwise import cosine_similarity + + + class TFIDFRetriever(BaseRetriever, BaseModel): + vectorizer: Any + docs: list[Document] + tfidf_array: Any + k: int = 4 + + class Config: + arbitrary_types_allowed = True + + def _get_relevant_documents(self, query: str) -> list[Document]: + # Ip -- (n_docs,x), Op -- (n_docs,n_Feats) + query_vec = self.vectorizer.transform([query]) + # Op -- (n_docs,1) -- Cosine Sim with each doc + results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,)) + return [self.docs[i] for i in results.argsort()[-self.k :][::-1]] + ``` + + + + + + + + + +Optional metadata associated with the retriever. + +This metadata will be associated with each call to this retriever, +and passed as arguments to the handlers defined in `callbacks`. + +You can use these to eg identify a specific instance of a retriever with its +use case. + + + + + + +Optional list of tags associated with the retriever. + +These tags will be associated with each call to this retriever, +and passed as arguments to the handlers defined in `callbacks`. + +You can use these to eg identify a specific instance of a retriever with its +use case. + + + + + +```python +langchain_core.retrievers.BaseRetriever.__init_subclass__( + kwargs: typing.Any = {} +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.retrievers.BaseRetriever._aget_relevant_documents( + query: str, + run_manager: langchain_core.callbacks.manager.AsyncCallbackManagerForRetrieverRun +) -> list[langchain_core.documents.Document] +``` + + + + + + +async + +Asynchronously get documents relevant to a query. + +**Parameters:** + + +String to find relevant documents for + + + +The callback handler to use + + +**Returns:** `list[Document]` + +List of relevant documents + + + + + + + +```python +langchain_core.retrievers.BaseRetriever._get_ls_params( + _kwargs: typing.Any = {} +) -> langchain_core.retrievers.LangSmithRetrieverParams +``` + + + + + + +Get standard params for tracing. + + + + + + + +```python +langchain_core.retrievers.BaseRetriever._get_relevant_documents( + query: str, + run_manager: langchain_core.callbacks.manager.CallbackManagerForRetrieverRun +) -> list[langchain_core.documents.Document] +``` + + + + + + +abstract + +Get documents relevant to a query. + +**Parameters:** + + +String to find relevant documents for. + + + +The callback handler to use. + + +**Returns:** `list[Document]` + +List of relevant documents. + + + + + + + +```python +langchain_core.retrievers.BaseRetriever.ainvoke( + input: str, + config: langchain_core.runnables.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + +async + +Asynchronously invoke the retriever to get relevant documents. + +Main entry point for asynchronous retriever invocations. + +Examples: + + +```python +await retriever.ainvoke("query") +``` + + + +**Parameters:** + + +The query string. + + + +Configuration for the retriever. + + + +Additional arguments to pass to the retriever. + + +**Returns:** `list[Document]` + +List of relevant documents. + + + + + + + +```python +langchain_core.retrievers.BaseRetriever.invoke( + input: str, + config: langchain_core.runnables.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + +Invoke the retriever to get relevant documents. + +Main entry point for synchronous retriever invocations. + +Examples: + + +```python +retriever.invoke("query") +``` + + + +**Parameters:** + + +The query string. + + + +Configuration for the retriever. + + + +Additional arguments to pass to the retriever. + + +**Returns:** `list[Document]` + +List of relevant documents. + + + + + + + + + +```python +class langchain_core.retrievers.LangSmithRetrieverParams +``` + + + + + + +**Bases:** `typing.TypedDict` + +LangSmith parameters for tracing. + + +Embedding model. + + + +Embedding provider. + + + +Retriever name. + + + +Vector store provider. + + + + + + + + +```python +langchain_core.retrievers.RetrieverInput = str +``` + + + + + + + + + +```python +langchain_core.retrievers.RetrieverLike = Runnable[RetrieverInput, RetrieverOutput] +``` + + + + + + + + + +```python +langchain_core.retrievers.RetrieverOutput = list[Document] +``` + + + + + + + + + +```python +langchain_core.retrievers.RetrieverOutputLike = Runnable[Any, RetrieverOutput] +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables.mdx new file mode 100644 index 0000000..32dc3b1 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables.mdx @@ -0,0 +1,117 @@ +--- +layout: overview +slug: langchain-core/langchain_core/runnables +title: langchain_core.runnables +--- + +LangChain **Runnable** and the **LangChain Expression Language (LCEL)**. + +The LangChain Expression Language (LCEL) offers a declarative method to build +production-grade programs that harness the power of LLMs. + +Programs created using LCEL and LangChain `Runnable` objects inherently suppor +synchronous asynchronous, batch, and streaming operations. + +Support for **async** allows servers hosting LCEL based programs to scale bette for +higher concurrent loads. + +**Batch** operations allow for processing multiple inputs in parallel. + +**Streaming** of intermediate outputs, as they're being generated, allows for creating +more responsive UX. + +This module contains schema and implementation of LangChain `Runnable` object +primitives. + +## Submodules + +- **[`langchain_core.runnables.base`](/langchain-core/langchain_core/runnables/base)** +- **[`langchain_core.runnables.branch`](/langchain-core/langchain_core/runnables/branch)** +- **[`langchain_core.runnables.config`](/langchain-core/langchain_core/runnables/config)** +- **[`langchain_core.runnables.configurable`](/langchain-core/langchain_core/runnables/configurable)** +- **[`langchain_core.runnables.fallbacks`](/langchain-core/langchain_core/runnables/fallbacks)** +- **[`langchain_core.runnables.graph`](/langchain-core/langchain_core/runnables/graph)** +- **[`langchain_core.runnables.graph_ascii`](/langchain-core/langchain_core/runnables/graph_ascii)** +- **[`langchain_core.runnables.graph_mermaid`](/langchain-core/langchain_core/runnables/graph_mermaid)** +- **[`langchain_core.runnables.graph_png`](/langchain-core/langchain_core/runnables/graph_png)** +- **[`langchain_core.runnables.history`](/langchain-core/langchain_core/runnables/history)** +- **[`langchain_core.runnables.passthrough`](/langchain-core/langchain_core/runnables/passthrough)** +- **[`langchain_core.runnables.retry`](/langchain-core/langchain_core/runnables/retry)** +- **[`langchain_core.runnables.router`](/langchain-core/langchain_core/runnables/router)** +- **[`langchain_core.runnables.schema`](/langchain-core/langchain_core/runnables/schema)** +- **[`langchain_core.runnables.utils`](/langchain-core/langchain_core/runnables/utils)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`__dir__`](#langchain_core-runnables-__dir__) | - | +| [`__getattr__`](#langchain_core-runnables-__getattr__) | - | + +### Data + +[`__all__`](#langchain_core-runnables-__all__) + +[`_dynamic_imports`](#langchain_core-runnables-_dynamic_imports) + +### API + + + + + +```python +langchain_core.runnables.__dir__() -> list[str] +``` + + + + + + + + + + + + + +```python +langchain_core.runnables.__getattr__( + attr_name: str +) -> object +``` + + + + + + + + + + + + + +```python +langchain_core.runnables.__all__ = ('AddableDict', 'ConfigurableField', 'ConfigurableFieldMultiOption', 'Configurab... +``` + + + + + + + + + +```python +langchain_core.runnables._dynamic_imports = {'chain': 'base', 'Runnable': 'base', 'RunnableBinding': 'base', 'RunnableGenera... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/base.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/base.mdx new file mode 100644 index 0000000..ce322bf --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/base.mdx @@ -0,0 +1,5475 @@ +--- +layout: overview +slug: langchain-core/langchain_core/runnables/base +title: langchain_core.runnables.base +--- + +Base classes and utilities for `Runnable` objects. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`Runnable`](#langchain_core-runnables-base-Runnable) | A unit of work that can be invoked, batched, streamed, transformed and composed. | +| [`RunnableBinding`](#langchain_core-runnables-base-RunnableBinding) | Wrap a `Runnable` with additional functionality. | +| [`RunnableBindingBase`](#langchain_core-runnables-base-RunnableBindingBase) | `Runnable` that delegates calls to another `Runnable` with a set of `**kwargs`. | +| [`RunnableEach`](#langchain_core-runnables-base-RunnableEach) | RunnableEach class. | +| [`RunnableEachBase`](#langchain_core-runnables-base-RunnableEachBase) | RunnableEachBase class. | +| [`RunnableGenerator`](#langchain_core-runnables-base-RunnableGenerator) | `Runnable` that runs a generator function. | +| [`RunnableLambda`](#langchain_core-runnables-base-RunnableLambda) | `RunnableLambda` converts a python callable into a `Runnable`. | +| [`RunnableParallel`](#langchain_core-runnables-base-RunnableParallel) | Runnable that runs a mapping of `Runnable`s in parallel. | +| [`RunnableSequence`](#langchain_core-runnables-base-RunnableSequence) | Sequence of `Runnable` objects, where the output of one is the input of the next. | +| [`RunnableSerializable`](#langchain_core-runnables-base-RunnableSerializable) | Runnable that can be serialized to JSON. | +| [`_RunnableCallableAsync`](#langchain_core-runnables-base-_RunnableCallableAsync) | - | +| [`_RunnableCallableAsyncIterator`](#langchain_core-runnables-base-_RunnableCallableAsyncIterator) | - | +| [`_RunnableCallableIterator`](#langchain_core-runnables-base-_RunnableCallableIterator) | - | +| [`_RunnableCallableSync`](#langchain_core-runnables-base-_RunnableCallableSync) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_seq_input_schema`](#langchain_core-runnables-base-_seq_input_schema) | - | +| [`_seq_output_schema`](#langchain_core-runnables-base-_seq_output_schema) | - | +| [`chain`](#langchain_core-runnables-base-chain) | Decorate a function to make it a `Runnable`. | +| [`coerce_to_runnable`](#langchain_core-runnables-base-coerce_to_runnable) | Coerce a `Runnable`-like object into a `Runnable`. | + +### Data + +[`Other`](#langchain_core-runnables-base-Other) + +[`RunnableLike`](#langchain_core-runnables-base-RunnableLike) + +[`RunnableMap`](#langchain_core-runnables-base-RunnableMap) + +[`_RUNNABLE_GENERIC_NUM_ARGS`](#langchain_core-runnables-base-_RUNNABLE_GENERIC_NUM_ARGS) + +[`_RUNNABLE_SEQUENCE_MIN_STEPS`](#langchain_core-runnables-base-_RUNNABLE_SEQUENCE_MIN_STEPS) + +### API + + + + + +```python +class langchain_core.runnables.base.Runnable() +``` + + + + + + +Abstract + +**Bases:** `Generic[Input, Output]` + +A unit of work that can be invoked, batched, streamed, transformed and composed. + +Key Methods +=========== + +- `invoke`/`ainvoke`: Transforms a single input into an output. +- `batch`/`abatch`: Efficiently transforms multiple inputs into outputs. +- `stream`/`astream`: Streams output from a single input as it's produced. +- `astream_log`: Streams output and selected intermediate results from an + input. + +Built-in optimizations: + +- **Batch**: By default, batch runs invoke() in parallel using a thread pool + executor. Override to optimize batching. + +- **Async**: Methods with `'a'` prefix are asynchronous. By default, they execute + the sync counterpart using asyncio's thread pool. + Override for native async. + +All methods accept an optional config argument, which can be used to configure +execution, add tags and metadata for tracing and debugging etc. + +Runnables expose schematic information about their input, output and config via +the `input_schema` property, the `output_schema` property and `config_schema` +method. + +Composition +=========== + +Runnable objects can be composed together to create chains in a declarative way. + +Any chain constructed this way will automatically have sync, async, batch, and +streaming support. + +The main composition primitives are `RunnableSequence` and `RunnableParallel`. + +**`RunnableSequence`** invokes a series of runnables sequentially, with +one Runnable's output serving as the next's input. Construct using +the `|` operator or by passing a list of runnables to `RunnableSequence`. + +**`RunnableParallel`** invokes runnables concurrently, providing the same input +to each. Construct it using a dict literal within a sequence or by passing a +dict to `RunnableParallel`. + + +For example, + + + +```python +from langchain_core.runnables import RunnableLambda + +# A RunnableSequence constructed using the `|` operator +sequence = RunnableLambda(lambda x: x + 1) | RunnableLambda(lambda x: x * 2) +sequence.invoke(1) # 4 +sequence.batch([1, 2, 3]) # [4, 6, 8] + + +# A sequence that contains a RunnableParallel constructed using a dict literal +sequence = RunnableLambda(lambda x: x + 1) | { + "mul_2": RunnableLambda(lambda x: x * 2), + "mul_5": RunnableLambda(lambda x: x * 5), +} +sequence.invoke(1) # {'mul_2': 4, 'mul_5': 10} +``` + + + +Standard Methods +================ + +All `Runnable`s expose additional methods that can be used to modify their +behavior (e.g., add a retry policy, add lifecycle listeners, make them +configurable, etc.). + +These methods will work on any `Runnable`, including `Runnable` chains +constructed by composing other `Runnable`s. +See the individual methods for details. + +For example, + + + +```python +from langchain_core.runnables import RunnableLambda + +import random + +def add_one(x: int) -> int: + return x + 1 + + +def buggy_double(y: int) -> int: + """Buggy code that will fail 70% of the time""" + if random.random() > 0.3: + print('This code failed, and will probably be retried!') # noqa: T201 + raise ValueError('Triggered buggy code') + return y * 2 + +sequence = ( + RunnableLambda(add_one) | + RunnableLambda(buggy_double).with_retry( # Retry on failure + stop_after_attempt=10, + wait_exponential_jitter=False + ) +) + +print(sequence.input_schema.model_json_schema()) # Show inferred input schema +print(sequence.output_schema.model_json_schema()) # Show inferred output schema +print(sequence.invoke(2)) # invoke the sequence (note the retry above!!) +``` + + + +Debugging and tracing +===================== + +As the chains get longer, it can be useful to be able to see intermediate results +to debug and trace the chain. + +You can set the global debug flag to True to enable debug output for all chains: + + + +```python +from langchain_core.globals import set_debug + +set_debug(True) +``` + + + +Alternatively, you can pass existing or custom callbacks to any given chain: + + + +```python +from langchain_core.tracers import ConsoleCallbackHandler + +chain.invoke(..., config={"callbacks": [ConsoleCallbackHandler()]}) +``` + + + +For a UI (and much more) checkout [LangSmith](https://docs.langchain.com/langsmith/home). + + + +Input type. + +The type of input this `Runnable` accepts specified as a type annotation. + + + +Output Type. + +The type of output this `Runnable` produces specified as a type annotation. + + + +List configurable fields for this `Runnable`. + + + +The type of input this `Runnable` accepts specified as a Pydantic model. + + + +The name of the `Runnable`. Used for debugging and tracing. + + + +Output schema. + +The type of output this `Runnable` produces specified as a Pydantic model. + + + + + +```python +langchain_core.runnables.base.Runnable.__or__( + other: langchain_core.runnables.base.Runnable[typing.Any, langchain_core.runnables.base.Other] | collections.abc.Callable[[Iterator[Any]], collections.abc.Iterator[langchain_core.runnables.base.Other]] | collections.abc.Callable[[AsyncIterator[Any]], collections.abc.AsyncIterator[langchain_core.runnables.base.Other]] | collections.abc.Callable[[Any], langchain_core.runnables.base.Other] | collections.abc.Mapping[str, langchain_core.runnables.base.Runnable[typing.Any, langchain_core.runnables.base.Other] | collections.abc.Callable[[Any], langchain_core.runnables.base.Other] | typing.Any] +) -> langchain_core.runnables.base.RunnableSerializable[langchain_core.runnables.utils.Input, langchain_core.runnables.base.Other] +``` + + + + + + +Runnable "or" operator. + +Compose this `Runnable` with another object to create a +`RunnableSequence`. + +**Parameters:** + + +Another `Runnable` or a `Runnable`-like object. + + +**Returns:** `RunnableSerializable[Input, Other]` + +A new `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.Runnable.__ror__( + other: langchain_core.runnables.base.Runnable[langchain_core.runnables.base.Other, typing.Any] | collections.abc.Callable[[Iterator[Other]], collections.abc.Iterator[typing.Any]] | collections.abc.Callable[[AsyncIterator[Other]], collections.abc.AsyncIterator[typing.Any]] | collections.abc.Callable[[Other], typing.Any] | collections.abc.Mapping[str, langchain_core.runnables.base.Runnable[langchain_core.runnables.base.Other, typing.Any] | collections.abc.Callable[[Other], typing.Any] | typing.Any] +) -> langchain_core.runnables.base.RunnableSerializable[langchain_core.runnables.base.Other, langchain_core.runnables.utils.Output] +``` + + + + + + +Runnable "reverse-or" operator. + +Compose this `Runnable` with another object to create a +`RunnableSequence`. + +**Parameters:** + + +Another `Runnable` or a `Runnable`-like object. + + +**Returns:** `RunnableSerializable[Other, Output]` + +A new `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.Runnable._abatch_with_config( + func: collections.abc.Callable[[list[Input]], collections.abc.Awaitable[list[Exception | langchain_core.runnables.utils.Output]]] | collections.abc.Callable[[list[Input], list[AsyncCallbackManagerForChainRun]], collections.abc.Awaitable[list[Exception | langchain_core.runnables.utils.Output]]] | collections.abc.Callable[[list[Input], list[AsyncCallbackManagerForChainRun], list[RunnableConfig]], collections.abc.Awaitable[list[Exception | langchain_core.runnables.utils.Output]]], + inputs: list[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | list[langchain_core.runnables.config.RunnableConfig] | None = None, + return_exceptions: bool = False, + run_type: str | None = None, + kwargs: typing.Any | None = {} +) -> list[langchain_core.runnables.utils.Output] +``` + + + + + + +async + +Transform a list of inputs to a list of outputs, with callbacks. + +Helper method to transform an `Input` value to an `Output` value, +with callbacks. + +Use this method to implement `invoke` in subclasses. + + + + + + + +```python +langchain_core.runnables.base.Runnable._acall_with_config( + func: collections.abc.Callable[[Input], collections.abc.Awaitable[langchain_core.runnables.utils.Output]] | collections.abc.Callable[[Input, AsyncCallbackManagerForChainRun], collections.abc.Awaitable[langchain_core.runnables.utils.Output]] | collections.abc.Callable[[Input, AsyncCallbackManagerForChainRun, RunnableConfig], collections.abc.Awaitable[langchain_core.runnables.utils.Output]], + input_: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None, + run_type: str | None = None, + serialized: dict[str, typing.Any] | None = None, + kwargs: typing.Any | None = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + +async + +Async call with config. + +Helper method to transform an `Input` value to an `Output` value, +with callbacks. + +Use this method to implement `ainvoke` in subclasses. + + + + + + + +```python +langchain_core.runnables.base.Runnable._atransform_stream_with_config( + inputs: collections.abc.AsyncIterator[langchain_core.runnables.utils.Input], + transformer: collections.abc.Callable[[AsyncIterator[Input]], collections.abc.AsyncIterator[langchain_core.runnables.utils.Output]] | collections.abc.Callable[[AsyncIterator[Input], AsyncCallbackManagerForChainRun], collections.abc.AsyncIterator[langchain_core.runnables.utils.Output]] | collections.abc.Callable[[AsyncIterator[Input], AsyncCallbackManagerForChainRun, RunnableConfig], collections.abc.AsyncIterator[langchain_core.runnables.utils.Output]], + config: langchain_core.runnables.config.RunnableConfig | None, + run_type: str | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.Output] +``` + + + + + + +async + +Transform a stream with config. + +Helper method to transform an Async `Iterator` of `Input` values into an +Async `Iterator` of `Output` values, with callbacks. + +Use this to implement `astream` or `atransform` in `Runnable` subclasses. + + + + + + + +```python +langchain_core.runnables.base.Runnable._batch_with_config( + func: collections.abc.Callable[[list[Input]], list[Exception | langchain_core.runnables.utils.Output]] | collections.abc.Callable[[list[Input], list[CallbackManagerForChainRun]], list[Exception | langchain_core.runnables.utils.Output]] | collections.abc.Callable[[list[Input], list[CallbackManagerForChainRun], list[RunnableConfig]], list[Exception | langchain_core.runnables.utils.Output]], + inputs: list[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | list[langchain_core.runnables.config.RunnableConfig] | None = None, + return_exceptions: bool = False, + run_type: str | None = None, + kwargs: typing.Any | None = {} +) -> list[langchain_core.runnables.utils.Output] +``` + + + + + + +Transform a list of inputs to a list of outputs, with callbacks. + +Helper method to transform an `Input` value to an `Output` value, +with callbacks. Use this method to implement `invoke` in subclasses. + + + + + + + +```python +langchain_core.runnables.base.Runnable._call_with_config( + func: collections.abc.Callable[[Input], langchain_core.runnables.utils.Output] | collections.abc.Callable[[Input, CallbackManagerForChainRun], langchain_core.runnables.utils.Output] | collections.abc.Callable[[Input, CallbackManagerForChainRun, RunnableConfig], langchain_core.runnables.utils.Output], + input_: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None, + run_type: str | None = None, + serialized: dict[str, typing.Any] | None = None, + kwargs: typing.Any | None = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + +Call with config. + +Helper method to transform an `Input` value to an `Output` value, +with callbacks. + +Use this method to implement `invoke` in subclasses. + + + + + + + +```python +langchain_core.runnables.base.Runnable._transform_stream_with_config( + inputs: collections.abc.Iterator[langchain_core.runnables.utils.Input], + transformer: collections.abc.Callable[[Iterator[Input]], collections.abc.Iterator[langchain_core.runnables.utils.Output]] | collections.abc.Callable[[Iterator[Input], CallbackManagerForChainRun], collections.abc.Iterator[langchain_core.runnables.utils.Output]] | collections.abc.Callable[[Iterator[Input], CallbackManagerForChainRun, RunnableConfig], collections.abc.Iterator[langchain_core.runnables.utils.Output]], + config: langchain_core.runnables.config.RunnableConfig | None, + run_type: str | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.Iterator[langchain_core.runnables.utils.Output] +``` + + + + + + +Transform a stream with config. + +Helper method to transform an `Iterator` of `Input` values into an +`Iterator` of `Output` values, with callbacks. + +Use this to implement `stream` or `transform` in `Runnable` subclasses. + + + + + + + +```python +langchain_core.runnables.base.Runnable.abatch( + inputs: list[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | list[langchain_core.runnables.config.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any | None = {} +) -> list[langchain_core.runnables.utils.Output] +``` + + + + + + +async + +Default implementation runs `ainvoke` in parallel using `asyncio.gather`. + +The default implementation of `batch` works well for IO bound runnables. + +Subclasses must override this method if they can batch more efficiently; +e.g., if the underlying `Runnable` uses an API which supports a batch mode. + +**Parameters:** + + +A list of inputs to the `Runnable`. + + + +A config to use when invoking the `Runnable`. + +The config supports standard keys like `'tags'`, `'metadata'` for +tracing purposes, `'max_concurrency'` for controlling how much work to +do in parallel, and other keys. + +Please refer to `RunnableConfig` for more details. + + + +Whether to return exceptions instead of raising them. + + + +Additional keyword arguments to pass to the `Runnable`. + + +**Returns:** `list[Output]` + +A list of outputs from the `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.Runnable.abatch_as_completed( + inputs: collections.abc.Sequence[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | collections.abc.Sequence[langchain_core.runnables.config.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any | None = {} +) -> collections.abc.AsyncIterator[tuple[int, langchain_core.runnables.utils.Output | Exception]] +``` + + + + + + +async + +Run `ainvoke` in parallel on a list of inputs. + +Yields results as they complete. + +**Parameters:** + + +A list of inputs to the `Runnable`. + + + +A config to use when invoking the `Runnable`. + +The config supports standard keys like `'tags'`, `'metadata'` for +tracing purposes, `'max_concurrency'` for controlling how much work to +do in parallel, and other keys. + +Please refer to `RunnableConfig` for more details. + + + +Whether to return exceptions instead of raising them. + + + +Additional keyword arguments to pass to the `Runnable`. + + + + + + + + +```python +langchain_core.runnables.base.Runnable.ainvoke( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + +async + +Transform a single input into an output. + +**Parameters:** + + +The input to the `Runnable`. + + + +A config to use when invoking the `Runnable`. + +The config supports standard keys like `'tags'`, `'metadata'` for +tracing purposes, `'max_concurrency'` for controlling how much work to +do in parallel, and other keys. + +Please refer to `RunnableConfig` for more details. + + +**Returns:** `Output` + +The output of the `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.Runnable.as_tool( + args_schema: type[pydantic.BaseModel] | None = None, + name: str | None = None, + description: str | None = None, + arg_types: dict[str, type] | None = None +) -> langchain_core.tools.BaseTool +``` + + + + + + +Create a `BaseTool` from a `Runnable`. + +`as_tool` will instantiate a `BaseTool` with a name, description, and +`args_schema` from a `Runnable`. Where possible, schemas are inferred +from `runnable.get_input_schema`. + +Alternatively (e.g., if the `Runnable` takes a dict as input and the specific +`dict` keys are not typed), the schema can be specified directly with +`args_schema`. + +You can also pass `arg_types` to just specify the required arguments and their +types. + +!!! example "`TypedDict` input" + + ```python + from typing_extensions import TypedDict + from langchain_core.runnables import RunnableLambda + + + class Args(TypedDict): + a: int + b: list[int] + + + def f(x: Args) -> str: + return str(x["a"] * max(x["b"])) + + + runnable = RunnableLambda(f) + as_tool = runnable.as_tool() + as_tool.invoke({"a": 3, "b": [1, 2]}) + ``` + +!!! example "`dict` input, specifying schema via `args_schema`" + + ```python + from typing import Any + from pydantic import BaseModel, Field + from langchain_core.runnables import RunnableLambda + + def f(x: dict[str, Any]) -> str: + return str(x["a"] * max(x["b"])) + + class FSchema(BaseModel): + """Apply a function to an integer and list of integers.""" + + a: int = Field(..., description="Integer") + b: list[int] = Field(..., description="List of ints") + + runnable = RunnableLambda(f) + as_tool = runnable.as_tool(FSchema) + as_tool.invoke({"a": 3, "b": [1, 2]}) + ``` + +!!! example "`dict` input, specifying schema via `arg_types`" + + ```python + from typing import Any + from langchain_core.runnables import RunnableLambda + + + def f(x: dict[str, Any]) -> str: + return str(x["a"] * max(x["b"])) + + + runnable = RunnableLambda(f) + as_tool = runnable.as_tool(arg_types={"a": int, "b": list[int]}) + as_tool.invoke({"a": 3, "b": [1, 2]}) + ``` + +!!! example "`str` input" + + ```python + from langchain_core.runnables import RunnableLambda + + + def f(x: str) -> str: + return x + "a" + + + def g(x: str) -> str: + return x + "z" + + + runnable = RunnableLambda(f) | g + as_tool = runnable.as_tool() + as_tool.invoke("b") + ``` + +**Parameters:** + + +The schema for the tool. + + + +The name of the tool. + + + +The description of the tool. + + + +A dictionary of argument names to types. + + +**Returns:** `BaseTool` + +A `BaseTool` instance. + + + + + + + +```python +langchain_core.runnables.base.Runnable.assign( + kwargs: langchain_core.runnables.base.Runnable[dict[str, typing.Any], typing.Any] | collections.abc.Callable[[dict[str, Any]], typing.Any] | collections.abc.Mapping[str, langchain_core.runnables.base.Runnable[dict[str, typing.Any], typing.Any] | collections.abc.Callable[[dict[str, Any]], typing.Any]] = {} +) -> langchain_core.runnables.base.RunnableSerializable[typing.Any, typing.Any] +``` + + + + + + +Assigns new fields to the `dict` output of this `Runnable`. + + + +```python +from langchain_core.language_models.fake import FakeStreamingListLLM +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import SystemMessagePromptTemplate +from langchain_core.runnables import Runnable +from operator import itemgetter + +prompt = ( + SystemMessagePromptTemplate.from_template("You are a nice assistant.") + + "{question}" +) +model = FakeStreamingListLLM(responses=["foo-lish"]) + +chain: Runnable = prompt | model | {"str": StrOutputParser()} + +chain_with_assign = chain.assign(hello=itemgetter("str") | model) + +print(chain_with_assign.input_schema.model_json_schema()) +# {'title': 'PromptInput', 'type': 'object', 'properties': +{'question': {'title': 'Question', 'type': 'string'}}} +print(chain_with_assign.output_schema.model_json_schema()) +# {'title': 'RunnableSequenceOutput', 'type': 'object', 'properties': +{'str': {'title': 'Str', +'type': 'string'}, 'hello': {'title': 'Hello', 'type': 'string'}}} +``` + + + +**Parameters:** + + +A mapping of keys to `Runnable` or `Runnable`-like objects +that will be invoked with the entire output dict of this `Runnable`. + + +**Returns:** `RunnableSerializable[Any, Any]` + +A new `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.Runnable.astream( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.Output] +``` + + + + + + +async + +Default implementation of `astream`, which calls `ainvoke`. + +Subclasses must override this method if they support streaming output. + +**Parameters:** + + +The input to the `Runnable`. + + + +The config to use for the `Runnable`. + + + +Additional keyword arguments to pass to the `Runnable`. + + + + + + + + +```python +langchain_core.runnables.base.Runnable.astream_events( + input: typing.Any, + config: langchain_core.runnables.config.RunnableConfig | None = None, + version: typing.Literal['v1', 'v2'] = 'v2', + include_names: collections.abc.Sequence[str] | None = None, + include_types: collections.abc.Sequence[str] | None = None, + include_tags: collections.abc.Sequence[str] | None = None, + exclude_names: collections.abc.Sequence[str] | None = None, + exclude_types: collections.abc.Sequence[str] | None = None, + exclude_tags: collections.abc.Sequence[str] | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.schema.StreamEvent] +``` + + + + + + +async + +Generate a stream of events. + +Use to create an iterator over `StreamEvent` that provide real-time information +about the progress of the `Runnable`, including `StreamEvent` from intermediate +results. + +A `StreamEvent` is a dictionary with the following schema: + +- `event`: Event names are of the format: + `on_[runnable_type]_(start|stream|end)`. +- `name`: The name of the `Runnable` that generated the event. +- `run_id`: Randomly generated ID associated with the given execution of the + `Runnable` that emitted the event. A child `Runnable` that gets invoked as + part of the execution of a parent `Runnable` is assigned its own unique ID. +- `parent_ids`: The IDs of the parent runnables that generated the event. The + root `Runnable` will have an empty list. The order of the parent IDs is from + the root to the immediate parent. Only available for v2 version of the API. + The v1 version of the API will return an empty list. +- `tags`: The tags of the `Runnable` that generated the event. +- `metadata`: The metadata of the `Runnable` that generated the event. +- `data`: The data associated with the event. The contents of this field + depend on the type of event. See the table below for more details. + +Below is a table that illustrates some events that might be emitted by various +chains. Metadata fields have been omitted from the table for brevity. +Chain definitions have been included after the table. + +!!! note + This reference table is for the v2 version of the schema. + +| event | name | chunk | input | output | +| ---------------------- | -------------------- | ----------------------------------- | ------------------------------------------------- | --------------------------------------------------- | +| `on_chat_model_start` | `'[model name]'` | | `{"messages": [[SystemMessage, HumanMessage]]}` | | +| `on_chat_model_stream` | `'[model name]'` | `AIMessageChunk(content="hello")` | | | +| `on_chat_model_end` | `'[model name]'` | | `{"messages": [[SystemMessage, HumanMessage]]}` | `AIMessageChunk(content="hello world")` | +| `on_llm_start` | `'[model name]'` | | `{'input': 'hello'}` | | +| `on_llm_stream` | `'[model name]'` | `'Hello' ` | | | +| `on_llm_end` | `'[model name]'` | | `'Hello human!'` | | +| `on_chain_start` | `'format_docs'` | | | | +| `on_chain_stream` | `'format_docs'` | `'hello world!, goodbye world!'` | | | +| `on_chain_end` | `'format_docs'` | | `[Document(...)]` | `'hello world!, goodbye world!'` | +| `on_tool_start` | `'some_tool'` | | `{"x": 1, "y": "2"}` | | +| `on_tool_end` | `'some_tool'` | | | `{"x": 1, "y": "2"}` | +| `on_retriever_start` | `'[retriever name]'` | | `{"query": "hello"}` | | +| `on_retriever_end` | `'[retriever name]'` | | `{"query": "hello"}` | `[Document(...), ..]` | +| `on_prompt_start` | `'[template_name]'` | | `{"question": "hello"}` | | +| `on_prompt_end` | `'[template_name]'` | | `{"question": "hello"}` | `ChatPromptValue(messages: [SystemMessage, ...])` | + +In addition to the standard events, users can also dispatch custom events (see example below). + +Custom events will be only be surfaced with in the v2 version of the API! + +A custom event has following format: + +| Attribute | Type | Description | +| ----------- | ------ | --------------------------------------------------------------------------------------------------------- | +| `name` | `str` | A user defined name for the event. | +| `data` | `Any` | The data associated with the event. This can be anything, though we suggest making it JSON serializable. | + +Here are declarations associated with the standard events shown above: + +`format_docs`: + + + +```python +def format_docs(docs: list[Document]) -> str: + '''Format the docs.''' + return ", ".join([doc.page_content for doc in docs]) + + +format_docs = RunnableLambda(format_docs) +``` + + + +`some_tool`: + + + +```python +@tool +def some_tool(x: int, y: str) -> dict: + '''Some_tool.''' + return {"x": x, "y": y} +``` + + + +`prompt`: + + + +```python +template = ChatPromptTemplate.from_messages( + [ + ("system", "You are Cat Agent 007"), + ("human", "{question}"), + ] +).with_config({"run_name": "my_template", "tags": ["my_template"]}) +``` + + + +!!! example + + ```python + from langchain_core.runnables import RunnableLambda + + + async def reverse(s: str) -> str: + return s[::-1] + + + chain = RunnableLambda(func=reverse) + + events = [ + event async for event in chain.astream_events("hello", version="v2") + ] + + # Will produce the following events + # (run_id, and parent_ids has been omitted for brevity): + [ + { + "data": {"input": "hello"}, + "event": "on_chain_start", + "metadata": {}, + "name": "reverse", + "tags": [], + }, + { + "data": {"chunk": "olleh"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "reverse", + "tags": [], + }, + { + "data": {"output": "olleh"}, + "event": "on_chain_end", + "metadata": {}, + "name": "reverse", + "tags": [], + }, + ] + ``` + +```python title="Dispatch custom event" +from langchain_core.callbacks.manager import ( + adispatch_custom_event, +) +from langchain_core.runnables import RunnableLambda, RunnableConfig +import asyncio + + +async def slow_thing(some_input: str, config: RunnableConfig) -> str: + """Do something that takes a long time.""" + await asyncio.sleep(1) # Placeholder for some slow operation + await adispatch_custom_event( + "progress_event", + {"message": "Finished step 1 of 3"}, + config=config # Must be included for python < 3.10 + ) + await asyncio.sleep(1) # Placeholder for some slow operation + await adispatch_custom_event( + "progress_event", + {"message": "Finished step 2 of 3"}, + config=config # Must be included for python < 3.10 + ) + await asyncio.sleep(1) # Placeholder for some slow operation + return "Done" + +slow_thing = RunnableLambda(slow_thing) + +async for event in slow_thing.astream_events("some_input", version="v2"): + print(event) +``` + +**Parameters:** + + +The input to the `Runnable`. + + + +The config to use for the `Runnable`. + + + +The version of the schema to use, either `'v2'` or `'v1'`. + +Users should use `'v2'`. + +`'v1'` is for backwards compatibility and will be deprecated +in `0.4.0`. + +No default will be assigned until the API is stabilized. +custom events will only be surfaced in `'v2'`. + + + +Only include events from `Runnable` objects with matching names. + + + +Only include events from `Runnable` objects with matching types. + + + +Only include events from `Runnable` objects with matching tags. + + + +Exclude events from `Runnable` objects with matching names. + + + +Exclude events from `Runnable` objects with matching types. + + + +Exclude events from `Runnable` objects with matching tags. + + + +Additional keyword arguments to pass to the `Runnable`. + +These will be passed to `astream_log` as this implementation +of `astream_events` is built on top of `astream_log`. + + +**Raises:** + +- `NotImplementedError`: If the version is not `'v1'` or `'v2'`. + + + + + + + +```python +langchain_core.runnables.base.Runnable.astream_log( + input: typing.Any, + config: langchain_core.runnables.config.RunnableConfig | None = None, + diff: bool = True, + with_streamed_output_list: bool = True, + include_names: collections.abc.Sequence[str] | None = None, + include_types: collections.abc.Sequence[str] | None = None, + include_tags: collections.abc.Sequence[str] | None = None, + exclude_names: collections.abc.Sequence[str] | None = None, + exclude_types: collections.abc.Sequence[str] | None = None, + exclude_tags: collections.abc.Sequence[str] | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[langchain_core.tracers.log_stream.RunLogPatch] | collections.abc.AsyncIterator[langchain_core.tracers.log_stream.RunLog] +``` + + + + + + +async + +Stream all output from a `Runnable`, as reported to the callback system. + +This includes all inner runs of LLMs, Retrievers, Tools, etc. + +Output is streamed as Log objects, which include a list of +Jsonpatch ops that describe how the state of the run has changed in each +step, and the final state of the run. + +The Jsonpatch ops can be applied in order to construct state. + +**Parameters:** + + +The input to the `Runnable`. + + + +The config to use for the `Runnable`. + + + +Whether to yield diffs between each step or the current state. + + + +Whether to yield the `streamed_output` list. + + + +Only include logs with these names. + + + +Only include logs with these types. + + + +Only include logs with these tags. + + + +Exclude logs with these names. + + + +Exclude logs with these types. + + + +Exclude logs with these tags. + + + +Additional keyword arguments to pass to the `Runnable`. + + + + + + + + +```python +langchain_core.runnables.base.Runnable.atransform( + input: collections.abc.AsyncIterator[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.Output] +``` + + + + + + +async + +Transform inputs to outputs. + +Default implementation of atransform, which buffers input and calls `astream`. + +Subclasses must override this method if they can start producing output while +input is still being generated. + +**Parameters:** + + +An async iterator of inputs to the `Runnable`. + + + +The config to use for the `Runnable`. + + + +Additional keyword arguments to pass to the `Runnable`. + + + + + + + + +```python +langchain_core.runnables.base.Runnable.batch( + inputs: list[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | list[langchain_core.runnables.config.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any | None = {} +) -> list[langchain_core.runnables.utils.Output] +``` + + + + + + +Default implementation runs invoke in parallel using a thread pool executor. + +The default implementation of batch works well for IO bound runnables. + +Subclasses must override this method if they can batch more efficiently; +e.g., if the underlying `Runnable` uses an API which supports a batch mode. + +**Parameters:** + + +A list of inputs to the `Runnable`. + + + +A config to use when invoking the `Runnable`. The config supports +standard keys like `'tags'`, `'metadata'` for +tracing purposes, `'max_concurrency'` for controlling how much work +to do in parallel, and other keys. + +Please refer to `RunnableConfig` for more details. + + + +Whether to return exceptions instead of raising them. + + + +Additional keyword arguments to pass to the `Runnable`. + + +**Returns:** `list[Output]` + +A list of outputs from the `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.Runnable.batch_as_completed( + inputs: collections.abc.Sequence[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | collections.abc.Sequence[langchain_core.runnables.config.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any | None = {} +) -> collections.abc.Iterator[tuple[int, langchain_core.runnables.utils.Output | Exception]] +``` + + + + + + +Run `invoke` in parallel on a list of inputs. + +Yields results as they complete. + +**Parameters:** + + +A list of inputs to the `Runnable`. + + + +A config to use when invoking the `Runnable`. + +The config supports standard keys like `'tags'`, `'metadata'` for +tracing purposes, `'max_concurrency'` for controlling how much work to +do in parallel, and other keys. + +Please refer to `RunnableConfig` for more details. + + + +Whether to return exceptions instead of raising them. + + + +Additional keyword arguments to pass to the `Runnable`. + + + + + + + + +```python +langchain_core.runnables.base.Runnable.bind( + kwargs: typing.Any = {} +) -> langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + +Bind arguments to a `Runnable`, returning a new `Runnable`. + +Useful when a `Runnable` in a chain requires an argument that is not +in the output of the previous `Runnable` or included in the user input. + +**Parameters:** + + +The arguments to bind to the `Runnable`. + + +**Returns:** `Runnable[Input, Output]` + +A new `Runnable` with the arguments bound. + + + + + + + +```python +langchain_core.runnables.base.Runnable.config_schema( + include: collections.abc.Sequence[str] | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + +The type of config this `Runnable` accepts specified as a Pydantic model. + +To mark a field as configurable, see the `configurable_fields` +and `configurable_alternatives` methods. + +**Parameters:** + + +A list of fields to include in the config schema. + + +**Returns:** `type[BaseModel]` + +A Pydantic model that can be used to validate config. + + + + + + + +```python +langchain_core.runnables.base.Runnable.get_config_jsonschema( + include: collections.abc.Sequence[str] | None = None +) -> dict[str, typing.Any] +``` + + + + + + +Get a JSON schema that represents the config of the `Runnable`. + +!!! version-added "Added in `langchain-core` 0.3.0" + +**Parameters:** + + +A list of fields to include in the config schema. + + +**Returns:** `dict[str, Any]` + +A JSON schema that represents the config of the `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.Runnable.get_graph( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> langchain_core.runnables.graph.Graph +``` + + + + + + +Return a graph representation of this `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.Runnable.get_input_jsonschema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> dict[str, typing.Any] +``` + + + + + + +Get a JSON schema that represents the input to the `Runnable`. + +!!! version-added "Added in `langchain-core` 0.3.0" + +**Parameters:** + + +A config to use when generating the schema. + + +**Returns:** `dict[str, Any]` + +A JSON schema that represents the input to the `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.Runnable.get_input_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + +Get a Pydantic model that can be used to validate input to the `Runnable`. + +`Runnable` objects that leverage the `configurable_fields` and +`configurable_alternatives` methods will have a dynamic input schema that +depends on which configuration the `Runnable` is invoked with. + +This method allows to get an input schema for a specific configuration. + +**Parameters:** + + +A config to use when generating the schema. + + +**Returns:** `type[BaseModel]` + +A Pydantic model that can be used to validate input. + + + + + + + +```python +langchain_core.runnables.base.Runnable.get_name( + suffix: str | None = None, + name: str | None = None +) -> str +``` + + + + + + +Get the name of the `Runnable`. + +**Parameters:** + + +An optional suffix to append to the name. + + + +An optional name to use instead of the `Runnable`'s name. + + +**Returns:** `str` + +The name of the `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.Runnable.get_output_jsonschema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> dict[str, typing.Any] +``` + + + + + + +Get a JSON schema that represents the output of the `Runnable`. + +!!! version-added "Added in `langchain-core` 0.3.0" + +**Parameters:** + + +A config to use when generating the schema. + + +**Returns:** `dict[str, Any]` + +A JSON schema that represents the output of the `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.Runnable.get_output_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + +Get a Pydantic model that can be used to validate output to the `Runnable`. + +`Runnable` objects that leverage the `configurable_fields` and +`configurable_alternatives` methods will have a dynamic output schema that +depends on which configuration the `Runnable` is invoked with. + +This method allows to get an output schema for a specific configuration. + +**Parameters:** + + +A config to use when generating the schema. + + +**Returns:** `type[BaseModel]` + +A Pydantic model that can be used to validate output. + + + + + + + +```python +langchain_core.runnables.base.Runnable.get_prompts( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> list[langchain_core.prompts.base.BasePromptTemplate] +``` + + + + + + +Return a list of prompts used by this `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.Runnable.invoke( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + +abstract + +Transform a single input into an output. + +**Parameters:** + + +The input to the `Runnable`. + + + +A config to use when invoking the `Runnable`. + +The config supports standard keys like `'tags'`, `'metadata'` for +tracing purposes, `'max_concurrency'` for controlling how much work to +do in parallel, and other keys. + +Please refer to `RunnableConfig` for more details. + + +**Returns:** `Output` + +The output of the `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.Runnable.map() -> langchain_core.runnables.base.Runnable[list[langchain_core.runnables.utils.Input], list[langchain_core.runnables.utils.Output]] +``` + + + + + + +Return a new `Runnable` that maps a list of inputs to a list of outputs. + +Calls `invoke` with each input. + +**Returns:** `Runnable[list[Input], list[Output]]` + +A new `Runnable` that maps a list of inputs to a list of outputs. + + + + + + + +```python +langchain_core.runnables.base.Runnable.pick( + keys: str | list[str] +) -> langchain_core.runnables.base.RunnableSerializable[typing.Any, typing.Any] +``` + + + + + + +Pick keys from the output `dict` of this `Runnable`. + +!!! example "Pick a single key" + + ```python + import json + + from langchain_core.runnables import RunnableLambda, RunnableMap + + as_str = RunnableLambda(str) + as_json = RunnableLambda(json.loads) + chain = RunnableMap(str=as_str, json=as_json) + + chain.invoke("[1, 2, 3]") + # -> {"str": "[1, 2, 3]", "json": [1, 2, 3]} + + json_only_chain = chain.pick("json") + json_only_chain.invoke("[1, 2, 3]") + # -> [1, 2, 3] + ``` + +!!! example "Pick a list of keys" + + ```python + from typing import Any + + import json + + from langchain_core.runnables import RunnableLambda, RunnableMap + + as_str = RunnableLambda(str) + as_json = RunnableLambda(json.loads) + + + def as_bytes(x: Any) -> bytes: + return bytes(x, "utf-8") + + + chain = RunnableMap( + str=as_str, json=as_json, bytes=RunnableLambda(as_bytes) + ) + + chain.invoke("[1, 2, 3]") + # -> {"str": "[1, 2, 3]", "json": [1, 2, 3], "bytes": b"[1, 2, 3]"} + + json_and_bytes_chain = chain.pick(["json", "bytes"]) + json_and_bytes_chain.invoke("[1, 2, 3]") + # -> {"json": [1, 2, 3], "bytes": b"[1, 2, 3]"} + ``` + +**Parameters:** + + +A key or list of keys to pick from the output dict. + + +**Returns:** `RunnableSerializable[Any, Any]` + +a new `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.Runnable.pipe( + others: langchain_core.runnables.base.Runnable[typing.Any, langchain_core.runnables.base.Other] | collections.abc.Callable[[Any], langchain_core.runnables.base.Other] = (), + name: str | None = None +) -> langchain_core.runnables.base.RunnableSerializable[langchain_core.runnables.utils.Input, langchain_core.runnables.base.Other] +``` + + + + + + +Pipe `Runnable` objects. + +Compose this `Runnable` with `Runnable`-like objects to make a +`RunnableSequence`. + +Equivalent to `RunnableSequence(self, *others)` or `self | others[0] | ...` + +**Parameters:** + + +Other `Runnable` or `Runnable`-like objects to compose + + + +An optional name for the resulting `RunnableSequence`. + + +**Returns:** `RunnableSerializable[Input, Other]` + +A new `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.Runnable.stream( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.Iterator[langchain_core.runnables.utils.Output] +``` + + + + + + +Default implementation of `stream`, which calls `invoke`. + +Subclasses must override this method if they support streaming output. + +**Parameters:** + + +The input to the `Runnable`. + + + +The config to use for the `Runnable`. + + + +Additional keyword arguments to pass to the `Runnable`. + + + + + + + + +```python +langchain_core.runnables.base.Runnable.transform( + input: collections.abc.Iterator[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.Iterator[langchain_core.runnables.utils.Output] +``` + + + + + + +Transform inputs to outputs. + +Default implementation of transform, which buffers input and calls `astream`. + +Subclasses must override this method if they can start producing output while +input is still being generated. + +**Parameters:** + + +An iterator of inputs to the `Runnable`. + + + +The config to use for the `Runnable`. + + + +Additional keyword arguments to pass to the `Runnable`. + + + + + + + + +```python +langchain_core.runnables.base.Runnable.with_alisteners( + on_start: langchain_core.tracers.root_listeners.AsyncListener | None = None, + on_end: langchain_core.tracers.root_listeners.AsyncListener | None = None, + on_error: langchain_core.tracers.root_listeners.AsyncListener | None = None +) -> langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + +Bind async lifecycle listeners to a `Runnable`. + +Returns a new `Runnable`. + +The Run object contains information about the run, including its `id`, +`type`, `input`, `output`, `error`, `start_time`, `end_time`, and +any tags or metadata added to the run. + +**Parameters:** + + +Called asynchronously before the `Runnable` starts running, +with the `Run` object. + + + +Called asynchronously after the `Runnable` finishes running, +with the `Run` object. + + + +Called asynchronously if the `Runnable` throws an error, +with the `Run` object. + + +**Returns:** `Runnable[Input, Output]` + +A new `Runnable` with the listeners bound. + + + + + + + +```python +langchain_core.runnables.base.Runnable.with_config( + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + +Bind config to a `Runnable`, returning a new `Runnable`. + +**Parameters:** + + +The config to bind to the `Runnable`. + + + +Additional keyword arguments to pass to the `Runnable`. + + +**Returns:** `Runnable[Input, Output]` + +A new `Runnable` with the config bound. + + + + + + + +```python +langchain_core.runnables.base.Runnable.with_fallbacks( + fallbacks: collections.abc.Sequence[langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output]], + exceptions_to_handle: tuple[type[BaseException], ...] = (Exception,), + exception_key: str | None = None +) -> langchain_core.runnables.fallbacks.RunnableWithFallbacks[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + +Add fallbacks to a `Runnable`, returning a new `Runnable`. + +The new `Runnable` will try the original `Runnable`, and then each fallback +in order, upon failures. + +**Parameters:** + + +A sequence of runnables to try if the original `Runnable` +fails. + + + +A tuple of exception types to handle. + + + +If `string` is specified then handled exceptions will be +passed to fallbacks as part of the input under the specified key. + +If `None`, exceptions will not be passed to fallbacks. + +If used, the base `Runnable` and its fallbacks must accept a +dictionary as input. + + + +A sequence of runnables to try if the original `Runnable` +fails. + + + +A tuple of exception types to handle. + + + +If `string` is specified then handled exceptions will be +passed to fallbacks as part of the input under the specified key. + +If `None`, exceptions will not be passed to fallbacks. + +If used, the base `Runnable` and its fallbacks must accept a +dictionary as input. + + +**Returns:** `RunnableWithFallbacksT[Input, Output]` + +A new `Runnable` that will try the original `Runnable`, and then each +Fallback in order, upon failures. + + + + + + + +```python +langchain_core.runnables.base.Runnable.with_listeners( + on_start: collections.abc.Callable[[Run], None] | collections.abc.Callable[[Run, RunnableConfig], None] | None = None, + on_end: collections.abc.Callable[[Run], None] | collections.abc.Callable[[Run, RunnableConfig], None] | None = None, + on_error: collections.abc.Callable[[Run], None] | collections.abc.Callable[[Run, RunnableConfig], None] | None = None +) -> langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + +Bind lifecycle listeners to a `Runnable`, returning a new `Runnable`. + +The Run object contains information about the run, including its `id`, +`type`, `input`, `output`, `error`, `start_time`, `end_time`, and +any tags or metadata added to the run. + +**Parameters:** + + +Called before the `Runnable` starts running, with the `Run` +object. + + + +Called after the `Runnable` finishes running, with the `Run` +object. + + + +Called if the `Runnable` throws an error, with the `Run` +object. + + +**Returns:** `Runnable[Input, Output]` + +A new `Runnable` with the listeners bound. + + + + + + + +```python +langchain_core.runnables.base.Runnable.with_retry( + retry_if_exception_type: tuple[type[BaseException], ...] = (Exception,), + wait_exponential_jitter: bool = True, + exponential_jitter_params: langchain_core.runnables.retry.ExponentialJitterParams | None = None, + stop_after_attempt: int = 3 +) -> langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + +Create a new `Runnable` that retries the original `Runnable` on exceptions. + +**Parameters:** + + +A tuple of exception types to retry on. + + + +Whether to add jitter to the wait +time between retries. + + + +The maximum number of attempts to make before +giving up. + + + +Parameters for +`tenacity.wait_exponential_jitter`. Namely: `initial`, `max`, +`exp_base`, and `jitter` (all `float` values). + + +**Returns:** `Runnable[Input, Output]` + +A new `Runnable` that retries the original `Runnable` on exceptions. + + + + + + + +```python +langchain_core.runnables.base.Runnable.with_types( + input_type: type[langchain_core.runnables.utils.Input] | None = None, + output_type: type[langchain_core.runnables.utils.Output] | None = None +) -> langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + +Bind input and output types to a `Runnable`, returning a new `Runnable`. + +**Parameters:** + + +The input type to bind to the `Runnable`. + + + +The output type to bind to the `Runnable`. + + +**Returns:** `Runnable[Input, Output]` + +A new `Runnable` with the types bound. + + + + + + + + + +```python +class langchain_core.runnables.base.RunnableBinding() +``` + + + + + + +**Bases:** [RunnableBindingBase[Input, Output]](#langchain_core-runnables-base-RunnableBindingBase) + +Wrap a `Runnable` with additional functionality. + +A `RunnableBinding` can be thought of as a "runnable decorator" that +preserves the essential features of `Runnable`; i.e., batching, streaming, +and async support, while adding additional functionality. + +Any class that inherits from `Runnable` can be bound to a `RunnableBinding`. +Runnables expose a standard set of methods for creating `RunnableBindings` +or sub-classes of `RunnableBindings` (e.g., `RunnableRetry`, +`RunnableWithFallbacks`) that add additional functionality. + +These methods include: + +- `bind`: Bind kwargs to pass to the underlying `Runnable` when running it. +- `with_config`: Bind config to pass to the underlying `Runnable` when running + it. +- `with_listeners`: Bind lifecycle listeners to the underlying `Runnable`. +- `with_types`: Override the input and output types of the underlying + `Runnable`. +- `with_retry`: Bind a retry policy to the underlying `Runnable`. +- `with_fallbacks`: Bind a fallback policy to the underlying `Runnable`. + +Example: +`bind`: Bind kwargs to pass to the underlying `Runnable` when running it. + + ```python + # Create a Runnable binding that invokes the chat model with the + # additional kwarg `stop=['-']` when running it. + from langchain_openai import ChatOpenAI + + model = ChatOpenAI() + model.invoke('Say "Parrot-MAGIC"', stop=["-"]) # Should return `Parrot` + # Using it the easy way via `bind` method which returns a new + # RunnableBinding + runnable_binding = model.bind(stop=["-"]) + runnable_binding.invoke('Say "Parrot-MAGIC"') # Should return `Parrot` + ``` + Can also be done by instantiating a `RunnableBinding` directly (not + recommended): + + ```python + from langchain_core.runnables import RunnableBinding + + runnable_binding = RunnableBinding( + bound=model, + kwargs={"stop": ["-"]}, # <-- Note the additional kwargs + ) + runnable_binding.invoke('Say "Parrot-MAGIC"') # Should return `Parrot` + ``` + + + + + + +```python +langchain_core.runnables.base.RunnableBinding.__getattr__( + name: str +) -> typing.Any +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableBinding.bind( + kwargs: typing.Any = {} +) -> langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + +Bind additional kwargs to a `Runnable`, returning a new `Runnable`. + +**Parameters:** + + +The kwargs to bind to the `Runnable`. + + +**Returns:** `Runnable[Input, Output]` + +A new `Runnable` with the same type and config as the original, + + + + + + + +```python +langchain_core.runnables.base.RunnableBinding.with_config( + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableBinding.with_listeners( + on_start: collections.abc.Callable[[Run], None] | collections.abc.Callable[[Run, RunnableConfig], None] | None = None, + on_end: collections.abc.Callable[[Run], None] | collections.abc.Callable[[Run, RunnableConfig], None] | None = None, + on_error: collections.abc.Callable[[Run], None] | collections.abc.Callable[[Run, RunnableConfig], None] | None = None +) -> langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + +Bind lifecycle listeners to a `Runnable`, returning a new `Runnable`. + +The `Run` object contains information about the run, including its `id`, +`type`, `input`, `output`, `error`, `start_time`, `end_time`, and +any tags or metadata added to the run. + +**Parameters:** + + +Called before the `Runnable` starts running, with the `Run` +object. + + + +Called after the `Runnable` finishes running, with the `Run` +object. + + + +Called if the `Runnable` throws an error, with the `Run` +object. + + +**Returns:** `Runnable[Input, Output]` + +A new `Runnable` with the listeners bound. + + + + + + + +```python +langchain_core.runnables.base.RunnableBinding.with_retry( + kwargs: typing.Any = {} +) -> langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableBinding.with_types( + input_type: type[langchain_core.runnables.utils.Input] | pydantic.BaseModel | None = None, + output_type: type[langchain_core.runnables.utils.Output] | pydantic.BaseModel | None = None +) -> langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + + + +```python +class langchain_core.runnables.base.RunnableBindingBase( + bound: langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output], + kwargs: collections.abc.Mapping[str, typing.Any] | None = None, + config: langchain_core.runnables.config.RunnableConfig | None = None, + config_factories: list[collections.abc.Callable[[RunnableConfig], langchain_core.runnables.config.RunnableConfig]] | None = None, + custom_input_type: type[langchain_core.runnables.utils.Input] | pydantic.BaseModel | None = None, + custom_output_type: type[langchain_core.runnables.utils.Output] | pydantic.BaseModel | None = None, + other_kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [RunnableSerializable[Input, Output]](#langchain_core-runnables-base-RunnableSerializable) + +`Runnable` that delegates calls to another `Runnable` with a set of `**kwargs`. + +Use only if creating a new `RunnableBinding` subclass with different `__init__` +args. + +See documentation for `RunnableBinding` for more details. + + + + + + + + + +The underlying `Runnable` that this `Runnable` delegates to. + + + +The config to bind to the underlying `Runnable`. + + + +The config factories to bind to the underlying `Runnable`. + + + + + + +Override the input type of the underlying `Runnable` with a custom type. + +The type can be a Pydantic model, or a type annotation (e.g., `list[str]`). + + + +Override the output type of the underlying `Runnable` with a custom type. + +The type can be a Pydantic model, or a type annotation (e.g., `list[str]`). + + + +kwargs to pass to the underlying `Runnable` when running. + +For example, when the `Runnable` binding is invoked the underlying +`Runnable` will be invoked with the same input but with these additional +kwargs. + + + + + + + + +```python +langchain_core.runnables.base.RunnableBindingBase._merge_configs( + configs: langchain_core.runnables.config.RunnableConfig | None = () +) -> langchain_core.runnables.config.RunnableConfig +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableBindingBase.abatch( + inputs: list[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | list[langchain_core.runnables.config.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any | None = {} +) -> list[langchain_core.runnables.utils.Output] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableBindingBase.abatch_as_completed( + inputs: collections.abc.Sequence[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | collections.abc.Sequence[langchain_core.runnables.config.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any | None = {} +) -> collections.abc.AsyncIterator[tuple[int, langchain_core.runnables.utils.Output | Exception]] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableBindingBase.ainvoke( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableBindingBase.astream( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.Output] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableBindingBase.astream_events( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.schema.StreamEvent] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableBindingBase.atransform( + input: collections.abc.AsyncIterator[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.Output] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableBindingBase.batch( + inputs: list[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | list[langchain_core.runnables.config.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any | None = {} +) -> list[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableBindingBase.batch_as_completed( + inputs: collections.abc.Sequence[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | collections.abc.Sequence[langchain_core.runnables.config.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any | None = {} +) -> collections.abc.Iterator[tuple[int, langchain_core.runnables.utils.Output | Exception]] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableBindingBase.get_graph( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> langchain_core.runnables.graph.Graph +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableBindingBase.get_input_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableBindingBase.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "schema", "runnable"]` + + + + + + + +```python +langchain_core.runnables.base.RunnableBindingBase.get_name( + suffix: str | None = None, + name: str | None = None +) -> str +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableBindingBase.get_output_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableBindingBase.invoke( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableBindingBase.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `True` as this class is serializable. + + + + + + + +```python +langchain_core.runnables.base.RunnableBindingBase.stream( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.Iterator[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableBindingBase.transform( + input: collections.abc.Iterator[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.Iterator[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + + + +```python +class langchain_core.runnables.base.RunnableEach() +``` + + + + + + +**Bases:** [RunnableEachBase[Input, Output]](#langchain_core-runnables-base-RunnableEachBase) + +RunnableEach class. + +`Runnable` that calls another `Runnable` for each element of the input sequence. + +It allows you to call multiple inputs with the bounded `Runnable`. + +`RunnableEach` makes it easy to run multiple inputs for the `Runnable`. +In the below example, we associate and run three inputs +with a `Runnable`: + + ```python + from langchain_core.runnables.base import RunnableEach + from langchain_openai import ChatOpenAI + from langchain_core.prompts import ChatPromptTemplate + from langchain_core.output_parsers import StrOutputParser + prompt = ChatPromptTemplate.from_template("Tell me a short joke about + {topic}") + model = ChatOpenAI() + output_parser = StrOutputParser() + runnable = prompt | model | output_parser + runnable_each = RunnableEach(bound=runnable) + output = runnable_each.invoke([{'topic':'Computer Science'}, + {'topic':'Art'}, + {'topic':'Biology'}]) + print(output) # noqa: T201 + + ``` + + + + + + +```python +langchain_core.runnables.base.RunnableEach.bind( + kwargs: typing.Any = {} +) -> langchain_core.runnables.base.RunnableEach[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableEach.get_name( + suffix: str | None = None, + name: str | None = None +) -> str +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableEach.with_alisteners( + on_start: langchain_core.tracers.root_listeners.AsyncListener | None = None, + on_end: langchain_core.tracers.root_listeners.AsyncListener | None = None, + on_error: langchain_core.tracers.root_listeners.AsyncListener | None = None +) -> langchain_core.runnables.base.RunnableEach[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + +Bind async lifecycle listeners to a `Runnable`. + +Returns a new `Runnable`. + +The `Run` object contains information about the run, including its `id`, +`type`, `input`, `output`, `error`, `start_time`, `end_time`, and +any tags or metadata added to the run. + +**Parameters:** + + +Called asynchronously before the `Runnable` starts running, +with the `Run` object. + + + +Called asynchronously after the `Runnable` finishes running, +with the `Run` object. + + + +Called asynchronously if the `Runnable` throws an error, +with the `Run` object. + + +**Returns:** `RunnableEach[Input, Output]` + +A new `Runnable` with the listeners bound. + + + + + + + +```python +langchain_core.runnables.base.RunnableEach.with_config( + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.runnables.base.RunnableEach[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableEach.with_listeners( + on_start: collections.abc.Callable[[Run], None] | collections.abc.Callable[[Run, RunnableConfig], None] | None = None, + on_end: collections.abc.Callable[[Run], None] | collections.abc.Callable[[Run, RunnableConfig], None] | None = None, + on_error: collections.abc.Callable[[Run], None] | collections.abc.Callable[[Run, RunnableConfig], None] | None = None +) -> langchain_core.runnables.base.RunnableEach[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + +Bind lifecycle listeners to a `Runnable`, returning a new `Runnable`. + +The `Run` object contains information about the run, including its `id`, +`type`, `input`, `output`, `error`, `start_time`, `end_time`, and +any tags or metadata added to the run. + +**Parameters:** + + +Called before the `Runnable` starts running, with the `Run` +object. + + + +Called after the `Runnable` finishes running, with the `Run` +object. + + + +Called if the `Runnable` throws an error, with the `Run` +object. + + +**Returns:** `RunnableEach[Input, Output]` + +A new `Runnable` with the listeners bound. + + + + + + + + + +```python +class langchain_core.runnables.base.RunnableEachBase() +``` + + + + + + +**Bases:** [RunnableSerializable[list[Input], list[Output]]](#langchain_core-runnables-base-RunnableSerializable) + +RunnableEachBase class. + +`Runnable` that calls another `Runnable` for each element of the input sequence. + +Use only if creating a new `RunnableEach` subclass with different `__init__` +args. + +See documentation for `RunnableEach` for more details. + + + + + + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableEachBase._ainvoke( + inputs: list[langchain_core.runnables.utils.Input], + run_manager: langchain_core.callbacks.manager.AsyncCallbackManagerForChainRun, + config: langchain_core.runnables.config.RunnableConfig, + kwargs: typing.Any = {} +) -> list[langchain_core.runnables.utils.Output] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableEachBase._invoke( + inputs: list[langchain_core.runnables.utils.Input], + run_manager: langchain_core.callbacks.manager.CallbackManagerForChainRun, + config: langchain_core.runnables.config.RunnableConfig, + kwargs: typing.Any = {} +) -> list[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableEachBase.ainvoke( + input: list[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> list[langchain_core.runnables.utils.Output] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableEachBase.astream_events( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.schema.StreamEvent] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableEachBase.get_graph( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> langchain_core.runnables.graph.Graph +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableEachBase.get_input_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableEachBase.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "schema", "runnable"]` + + + + + + + +```python +langchain_core.runnables.base.RunnableEachBase.get_output_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableEachBase.invoke( + input: list[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> list[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableEachBase.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `True` as this class is serializable. + + + + + + + + + +```python +class langchain_core.runnables.base.RunnableGenerator( + transform: collections.abc.Callable[[Iterator[Input]], collections.abc.Iterator[langchain_core.runnables.utils.Output]] | collections.abc.Callable[[AsyncIterator[Input]], collections.abc.AsyncIterator[langchain_core.runnables.utils.Output]], + atransform: collections.abc.Callable[[AsyncIterator[Input]], collections.abc.AsyncIterator[langchain_core.runnables.utils.Output]] | None = None, + name: str | None = None +) +``` + + + + + + +**Bases:** [Runnable[Input, Output]](#langchain_core-runnables-base-Runnable) + +`Runnable` that runs a generator function. + +`RunnableGenerator`s can be instantiated directly or by using a generator within +a sequence. + +`RunnableGenerator`s can be used to implement custom behavior, such as custom +output parsers, while preserving streaming capabilities. Given a generator function +with a signature `Iterator[A] -> Iterator[B]`, wrapping it in a +`RunnableGenerator` allows it to emit output chunks as soon as they are streamed +in from the previous step. + +!!! note + If a generator function has a `signature A -> Iterator[B]`, such that it + requires its input from the previous step to be completed before emitting chunks + (e.g., most LLMs need the entire prompt available to start generating), it can + instead be wrapped in a `RunnableLambda`. + +Here is an example to show the basic mechanics of a `RunnableGenerator`: + + ```python + from typing import Any, AsyncIterator, Iterator + + from langchain_core.runnables import RunnableGenerator + + + def gen(input: Iterator[Any]) -> Iterator[str]: + for token in ["Have", " a", " nice", " day"]: + yield token + + + runnable = RunnableGenerator(gen) + runnable.invoke(None) # "Have a nice day" + list(runnable.stream(None)) # ["Have", " a", " nice", " day"] + runnable.batch([None, None]) # ["Have a nice day", "Have a nice day"] + + + # Async version: + async def agen(input: AsyncIterator[Any]) -> AsyncIterator[str]: + for token in ["Have", " a", " nice", " day"]: + yield token + + + runnable = RunnableGenerator(agen) + await runnable.ainvoke(None) # "Have a nice day" + [p async for p in runnable.astream(None)] # ["Have", " a", " nice", " day"] + ``` + +`RunnableGenerator` makes it easy to implement custom behavior within a streaming +context. Below we show an example: + + ```python + from langchain_core.prompts import ChatPromptTemplate + from langchain_core.runnables import RunnableGenerator, RunnableLambda + from langchain_openai import ChatOpenAI + from langchain_core.output_parsers import StrOutputParser + + + model = ChatOpenAI() + chant_chain = ( + ChatPromptTemplate.from_template("Give me a 3 word chant about {topic}") + | model + | StrOutputParser() + ) + + + def character_generator(input: Iterator[str]) -> Iterator[str]: + for token in input: + if "," in token or "." in token: + yield "👏" + token + else: + yield token + + + runnable = chant_chain | character_generator + assert type(runnable.last) is RunnableGenerator + "".join(runnable.stream({"topic": "waste"})) # Reduce👏, Reuse👏, Recycle👏. + + + # Note that RunnableLambda can be used to delay streaming of one step in a + # sequence until the previous step is finished: + def reverse_generator(input: str) -> Iterator[str]: + # Yield characters of input in reverse order. + for character in input[::-1]: + yield character + + + runnable = chant_chain | RunnableLambda(reverse_generator) + "".join(runnable.stream({"topic": "waste"})) # ".elcycer ,esuer ,ecudeR" + ``` + + + + + + + + +```python +langchain_core.runnables.base.RunnableGenerator.__eq__( + other: object +) -> bool +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableGenerator.__repr__() -> str +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableGenerator.ainvoke( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableGenerator.astream( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableGenerator.atransform( + input: collections.abc.AsyncIterator[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableGenerator.get_input_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableGenerator.get_output_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableGenerator.invoke( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableGenerator.stream( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.Iterator[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableGenerator.transform( + input: collections.abc.Iterator[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.Iterator[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + + + +```python +class langchain_core.runnables.base.RunnableLambda( + func: collections.abc.Callable[[Input], collections.abc.Iterator[langchain_core.runnables.utils.Output]] | collections.abc.Callable[[Input], langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output]] | collections.abc.Callable[[Input], langchain_core.runnables.utils.Output] | collections.abc.Callable[[Input, RunnableConfig], langchain_core.runnables.utils.Output] | collections.abc.Callable[[Input, CallbackManagerForChainRun], langchain_core.runnables.utils.Output] | collections.abc.Callable[[Input, CallbackManagerForChainRun, RunnableConfig], langchain_core.runnables.utils.Output] | collections.abc.Callable[[Input], collections.abc.Awaitable[langchain_core.runnables.utils.Output]] | collections.abc.Callable[[Input], collections.abc.AsyncIterator[langchain_core.runnables.utils.Output]] | collections.abc.Callable[[Input, RunnableConfig], collections.abc.Awaitable[langchain_core.runnables.utils.Output]] | collections.abc.Callable[[Input, AsyncCallbackManagerForChainRun], collections.abc.Awaitable[langchain_core.runnables.utils.Output]] | collections.abc.Callable[[Input, AsyncCallbackManagerForChainRun, RunnableConfig], collections.abc.Awaitable[langchain_core.runnables.utils.Output]], + afunc: collections.abc.Callable[[Input], collections.abc.Awaitable[langchain_core.runnables.utils.Output]] | collections.abc.Callable[[Input], collections.abc.AsyncIterator[langchain_core.runnables.utils.Output]] | collections.abc.Callable[[Input, RunnableConfig], collections.abc.Awaitable[langchain_core.runnables.utils.Output]] | collections.abc.Callable[[Input, AsyncCallbackManagerForChainRun], collections.abc.Awaitable[langchain_core.runnables.utils.Output]] | collections.abc.Callable[[Input, AsyncCallbackManagerForChainRun, RunnableConfig], collections.abc.Awaitable[langchain_core.runnables.utils.Output]] | None = None, + name: str | None = None +) +``` + + + + + + +**Bases:** [Runnable[Input, Output]](#langchain_core-runnables-base-Runnable) + +`RunnableLambda` converts a python callable into a `Runnable`. + +Wrapping a callable in a `RunnableLambda` makes the callable usable +within either a sync or async context. + +`RunnableLambda` can be composed as any other `Runnable` and provides +seamless integration with LangChain tracing. + +`RunnableLambda` is best suited for code that does not need to support +streaming. If you need to support streaming (i.e., be able to operate +on chunks of inputs and yield chunks of outputs), use `RunnableGenerator` +instead. + +Note that if a `RunnableLambda` returns an instance of `Runnable`, that +instance is invoked (or streamed) during execution. + +**Examples:** + + + +```python +# This is a RunnableLambda +from langchain_core.runnables import RunnableLambda + + +def add_one(x: int) -> int: + return x + 1 + + +runnable = RunnableLambda(add_one) + +runnable.invoke(1) # returns 2 +runnable.batch([1, 2, 3]) # returns [2, 3, 4] + +# Async is supported by default by delegating to the sync implementation +await runnable.ainvoke(1) # returns 2 +await runnable.abatch([1, 2, 3]) # returns [2, 3, 4] + + +# Alternatively, can provide both synd and sync implementations +async def add_one_async(x: int) -> int: + return x + 1 + + +runnable = RunnableLambda(add_one, afunc=add_one_async) +runnable.invoke(1) # Uses add_one +await runnable.ainvoke(1) # Uses add_one_async +``` + + + + + +The type of the input to this `Runnable`. + + + +The type of the output of this `Runnable` as a type annotation. + + + + + + + + + +The dependencies of this `Runnable`. + + + + + + + + +```python +langchain_core.runnables.base.RunnableLambda.__eq__( + other: object +) -> bool +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableLambda.__repr__() -> str +``` + + + + + + +Return a string representation of this `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.RunnableLambda._ainvoke( + value: langchain_core.runnables.utils.Input, + run_manager: langchain_core.callbacks.manager.AsyncCallbackManagerForChainRun, + config: langchain_core.runnables.config.RunnableConfig, + kwargs: typing.Any = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableLambda._atransform( + chunks: collections.abc.AsyncIterator[langchain_core.runnables.utils.Input], + run_manager: langchain_core.callbacks.manager.AsyncCallbackManagerForChainRun, + config: langchain_core.runnables.config.RunnableConfig, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.Output] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableLambda._invoke( + input_: langchain_core.runnables.utils.Input, + run_manager: langchain_core.callbacks.manager.CallbackManagerForChainRun, + config: langchain_core.runnables.config.RunnableConfig, + kwargs: typing.Any = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableLambda._transform( + chunks: collections.abc.Iterator[langchain_core.runnables.utils.Input], + run_manager: langchain_core.callbacks.manager.CallbackManagerForChainRun, + config: langchain_core.runnables.config.RunnableConfig, + kwargs: typing.Any = {} +) -> collections.abc.Iterator[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableLambda.ainvoke( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + +async + +Invoke this `Runnable` asynchronously. + +**Parameters:** + + +The input to this `Runnable`. + + + +The config to use. + + + +Additional keyword arguments. + + +**Returns:** `Output` + +The output of this `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.RunnableLambda.astream( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.Output] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableLambda.atransform( + input: collections.abc.AsyncIterator[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.Output] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableLambda.get_graph( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> langchain_core.runnables.graph.Graph +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableLambda.get_input_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + +The Pydantic schema for the input to this `Runnable`. + +**Parameters:** + + +The config to use. + + +**Returns:** `type[BaseModel]` + +The input schema for this `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.RunnableLambda.get_output_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableLambda.invoke( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + +Invoke this `Runnable` synchronously. + +**Parameters:** + + +The input to this `Runnable`. + + + +The config to use. + + + +Additional keyword arguments. + + +**Returns:** `Output` + +The output of this `Runnable`. + +**Raises:** + +- `TypeError`: If the `Runnable` is a coroutine function. + + + + + + + +```python +langchain_core.runnables.base.RunnableLambda.stream( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.Iterator[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableLambda.transform( + input: collections.abc.Iterator[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.Iterator[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + + + +```python +class langchain_core.runnables.base.RunnableParallel( + steps__: collections.abc.Mapping[str, langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, typing.Any] | collections.abc.Callable[[Input], typing.Any] | collections.abc.Mapping[str, langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, typing.Any] | collections.abc.Callable[[Input], typing.Any]]] | None = None, + kwargs: langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, typing.Any] | collections.abc.Callable[[Input], typing.Any] | collections.abc.Mapping[str, langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, typing.Any] | collections.abc.Callable[[Input], typing.Any]] = {} +) +``` + + + + + + +**Bases:** [RunnableSerializable[Input, dict[str, Any]]](#langchain_core-runnables-base-RunnableSerializable) + +Runnable that runs a mapping of `Runnable`s in parallel. + +Returns a mapping of their outputs. + +`RunnableParallel` is one of the two main composition primitives, +alongside `RunnableSequence`. It invokes `Runnable`s concurrently, providing the +same input to each. + +A `RunnableParallel` can be instantiated directly or by using a dict literal +within a sequence. + +Here is a simple example that uses functions to illustrate the use of +`RunnableParallel`: + + ```python + from langchain_core.runnables import RunnableLambda + + + def add_one(x: int) -> int: + return x + 1 + + + def mul_two(x: int) -> int: + return x * 2 + + + def mul_three(x: int) -> int: + return x * 3 + + + runnable_1 = RunnableLambda(add_one) + runnable_2 = RunnableLambda(mul_two) + runnable_3 = RunnableLambda(mul_three) + + sequence = runnable_1 | { # this dict is coerced to a RunnableParallel + "mul_two": runnable_2, + "mul_three": runnable_3, + } + # Or equivalently: + # sequence = runnable_1 | RunnableParallel( + # {"mul_two": runnable_2, "mul_three": runnable_3} + # ) + # Also equivalently: + # sequence = runnable_1 | RunnableParallel( + # mul_two=runnable_2, + # mul_three=runnable_3, + # ) + + sequence.invoke(1) + await sequence.ainvoke(1) + + sequence.batch([1, 2, 3]) + await sequence.abatch([1, 2, 3]) + ``` + +`RunnableParallel` makes it easy to run `Runnable`s in parallel. In the below +example, we simultaneously stream output from two different `Runnable` objects: + + ```python + from langchain_core.prompts import ChatPromptTemplate + from langchain_core.runnables import RunnableParallel + from langchain_openai import ChatOpenAI + + model = ChatOpenAI() + joke_chain = ( + ChatPromptTemplate.from_template("tell me a joke about {topic}") | model + ) + poem_chain = ( + ChatPromptTemplate.from_template("write a 2-line poem about {topic}") + | model + ) + + runnable = RunnableParallel(joke=joke_chain, poem=poem_chain) + + # Display stream + output = {key: "" for key, _ in runnable.output_schema()} + for chunk in runnable.stream({"topic": "bear"}): + for key in chunk: + output[key] = output[key] + chunk[key].content + print(output) # noqa: T201 + ``` + + + +The type of the input to the `Runnable`. + + + +Get the config specs of the `Runnable`. + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableParallel.__repr__() -> str +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableParallel._atransform( + inputs: collections.abc.AsyncIterator[langchain_core.runnables.utils.Input], + run_manager: langchain_core.callbacks.manager.AsyncCallbackManagerForChainRun, + config: langchain_core.runnables.config.RunnableConfig +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.AddableDict] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableParallel._transform( + inputs: collections.abc.Iterator[langchain_core.runnables.utils.Input], + run_manager: langchain_core.callbacks.manager.CallbackManagerForChainRun, + config: langchain_core.runnables.config.RunnableConfig +) -> collections.abc.Iterator[langchain_core.runnables.utils.AddableDict] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableParallel.ainvoke( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> dict[str, typing.Any] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableParallel.astream( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.AsyncIterator[dict[str, typing.Any]] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableParallel.atransform( + input: collections.abc.AsyncIterator[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[dict[str, typing.Any]] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableParallel.get_graph( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> langchain_core.runnables.graph.Graph +``` + + + + + + +Get the graph representation of the `Runnable`. + +**Parameters:** + + +The config to use. + + +**Returns:** `Graph` + +The graph representation of the `Runnable`. + +**Raises:** + +- `ValueError`: If a `Runnable` has no first or last node. + + + + + + + +```python +langchain_core.runnables.base.RunnableParallel.get_input_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + +Get the input schema of the `Runnable`. + +**Parameters:** + + +The config to use. + + +**Returns:** `type[BaseModel]` + +The input schema of the `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.RunnableParallel.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "schema", "runnable"]` + + + + + + + +```python +langchain_core.runnables.base.RunnableParallel.get_name( + suffix: str | None = None, + name: str | None = None +) -> str +``` + + + + + + +Get the name of the `Runnable`. + +**Parameters:** + + +The suffix to use. + + + +The name to use. + + +**Returns:** `str` + +The name of the `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.RunnableParallel.get_output_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + +Get the output schema of the `Runnable`. + +**Parameters:** + + +The config to use. + + +**Returns:** `type[BaseModel]` + +The output schema of the `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.RunnableParallel.invoke( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> dict[str, typing.Any] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableParallel.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `True` as this class is serializable. + + + + + + + +```python +langchain_core.runnables.base.RunnableParallel.stream( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.Iterator[dict[str, typing.Any]] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableParallel.transform( + input: collections.abc.Iterator[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.Iterator[dict[str, typing.Any]] +``` + + + + + + + + + + + + + + +```python +class langchain_core.runnables.base.RunnableSequence( + steps: langchain_core.runnables.base.RunnableLike = (), + name: str | None = None, + first: langchain_core.runnables.base.Runnable[typing.Any, typing.Any] | None = None, + middle: list[langchain_core.runnables.base.Runnable[typing.Any, typing.Any]] | None = None, + last: langchain_core.runnables.base.Runnable[typing.Any, typing.Any] | None = None +) +``` + + + + + + +**Bases:** [RunnableSerializable[Input, Output]](#langchain_core-runnables-base-RunnableSerializable) + +Sequence of `Runnable` objects, where the output of one is the input of the next. + +**`RunnableSequence`** is the most important composition operator in LangChain +as it is used in virtually every chain. + +A `RunnableSequence` can be instantiated directly or more commonly by using the +`|` operator where either the left or right operands (or both) must be a +`Runnable`. + +Any `RunnableSequence` automatically supports sync, async, batch. + +The default implementations of `batch` and `abatch` utilize threadpools and +asyncio gather and will be faster than naive invocation of `invoke` or `ainvoke` +for IO bound `Runnable`s. + +Batching is implemented by invoking the batch method on each component of the +`RunnableSequence` in order. + +A `RunnableSequence` preserves the streaming properties of its components, so if +all components of the sequence implement a `transform` method -- which +is the method that implements the logic to map a streaming input to a streaming +output -- then the sequence will be able to stream input to output! + +If any component of the sequence does not implement transform then the +streaming will only begin after this component is run. If there are +multiple blocking components, streaming begins after the last one. + +!!! note + `RunnableLambdas` do not support `transform` by default! So if you need to + use a `RunnableLambdas` be careful about where you place them in a + `RunnableSequence` (if you need to use the `stream`/`astream` methods). + + If you need arbitrary logic and need streaming, you can subclass + Runnable, and implement `transform` for whatever logic you need. + +Here is a simple example that uses simple functions to illustrate the use of +`RunnableSequence`: + + ```python + from langchain_core.runnables import RunnableLambda + + + def add_one(x: int) -> int: + return x + 1 + + + def mul_two(x: int) -> int: + return x * 2 + + + runnable_1 = RunnableLambda(add_one) + runnable_2 = RunnableLambda(mul_two) + sequence = runnable_1 | runnable_2 + # Or equivalently: + # sequence = RunnableSequence(first=runnable_1, last=runnable_2) + sequence.invoke(1) + await sequence.ainvoke(1) + + sequence.batch([1, 2, 3]) + await sequence.abatch([1, 2, 3]) + ``` + +Here's an example that uses streams JSON output generated by an LLM: + + ```python + from langchain_core.output_parsers.json import SimpleJsonOutputParser + from langchain_openai import ChatOpenAI + + prompt = PromptTemplate.from_template( + "In JSON format, give me a list of {topic} and their " + "corresponding names in French, Spanish and in a " + "Cat Language." + ) + + model = ChatOpenAI() + chain = prompt | model | SimpleJsonOutputParser() + + async for chunk in chain.astream({"topic": "colors"}): + print("-") # noqa: T201 + print(chunk, sep="", flush=True) # noqa: T201 + ``` + + + +The type of the input to the `Runnable`. + + + +The type of the output of the `Runnable`. + + + +Get the config specs of the `Runnable`. + + + +The first `Runnable` in the sequence. + + + +The last `Runnable` in the sequence. + + + +The middle `Runnable` in the sequence. + + + + + + +All the `Runnable`s that make up the sequence in order. + + + + + +```python +langchain_core.runnables.base.RunnableSequence.__or__( + other: langchain_core.runnables.base.Runnable[typing.Any, langchain_core.runnables.base.Other] | collections.abc.Callable[[Iterator[Any]], collections.abc.Iterator[langchain_core.runnables.base.Other]] | collections.abc.Callable[[AsyncIterator[Any]], collections.abc.AsyncIterator[langchain_core.runnables.base.Other]] | collections.abc.Callable[[Any], langchain_core.runnables.base.Other] | collections.abc.Mapping[str, langchain_core.runnables.base.Runnable[typing.Any, langchain_core.runnables.base.Other] | collections.abc.Callable[[Any], langchain_core.runnables.base.Other] | typing.Any] +) -> langchain_core.runnables.base.RunnableSerializable[langchain_core.runnables.utils.Input, langchain_core.runnables.base.Other] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableSequence.__repr__() -> str +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableSequence.__ror__( + other: langchain_core.runnables.base.Runnable[langchain_core.runnables.base.Other, typing.Any] | collections.abc.Callable[[Iterator[Other]], collections.abc.Iterator[typing.Any]] | collections.abc.Callable[[AsyncIterator[Other]], collections.abc.AsyncIterator[typing.Any]] | collections.abc.Callable[[Other], typing.Any] | collections.abc.Mapping[str, langchain_core.runnables.base.Runnable[langchain_core.runnables.base.Other, typing.Any] | collections.abc.Callable[[Other], typing.Any] | typing.Any] +) -> langchain_core.runnables.base.RunnableSerializable[langchain_core.runnables.base.Other, langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableSequence._atransform( + inputs: collections.abc.AsyncIterator[langchain_core.runnables.utils.Input], + run_manager: langchain_core.callbacks.manager.AsyncCallbackManagerForChainRun, + config: langchain_core.runnables.config.RunnableConfig, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.Output] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableSequence._transform( + inputs: collections.abc.Iterator[langchain_core.runnables.utils.Input], + run_manager: langchain_core.callbacks.manager.CallbackManagerForChainRun, + config: langchain_core.runnables.config.RunnableConfig, + kwargs: typing.Any = {} +) -> collections.abc.Iterator[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableSequence.abatch( + inputs: list[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | list[langchain_core.runnables.config.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any | None = {} +) -> list[langchain_core.runnables.utils.Output] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableSequence.ainvoke( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableSequence.astream( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.Output] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableSequence.atransform( + input: collections.abc.AsyncIterator[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.Output] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.base.RunnableSequence.batch( + inputs: list[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | list[langchain_core.runnables.config.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any | None = {} +) -> list[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableSequence.get_graph( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> langchain_core.runnables.graph.Graph +``` + + + + + + +Get the graph representation of the `Runnable`. + +**Parameters:** + + +The config to use. + + +**Returns:** `Graph` + +The graph representation of the `Runnable`. + +**Raises:** + +- `ValueError`: If a `Runnable` has no first or last node. + + + + + + + +```python +langchain_core.runnables.base.RunnableSequence.get_input_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + +Get the input schema of the `Runnable`. + +**Parameters:** + + +The config to use. + + +**Returns:** `type[BaseModel]` + +The input schema of the `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.RunnableSequence.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "schema", "runnable"]` + + + + + + + +```python +langchain_core.runnables.base.RunnableSequence.get_output_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + +Get the output schema of the `Runnable`. + +**Parameters:** + + +The config to use. + + +**Returns:** `type[BaseModel]` + +The output schema of the `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.RunnableSequence.invoke( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableSequence.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `True` as this class is serializable. + + + + + + + +```python +langchain_core.runnables.base.RunnableSequence.stream( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.Iterator[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.base.RunnableSequence.transform( + input: collections.abc.Iterator[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.Iterator[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + + + +```python +class langchain_core.runnables.base.RunnableSerializable() +``` + + + + + + +**Bases:** [Serializable](/langchain-core/langchain_core/load/serializable#langchain_core-load-serializable-Serializable), [Runnable[Input, Output]](#langchain_core-runnables-base-Runnable) + +Runnable that can be serialized to JSON. + + + + + + +The name of the `Runnable`. + +Used for debugging and tracing. + + + + + +```python +langchain_core.runnables.base.RunnableSerializable.configurable_alternatives( + which: langchain_core.runnables.utils.ConfigurableField, + default_key: str = 'default', + prefix_keys: bool = False, + kwargs: langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] | collections.abc.Callable[[], langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output]] = {} +) -> langchain_core.runnables.base.RunnableSerializable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + +Configure alternatives for `Runnable` objects that can be set at runtime. + +!!! example + + ```python + from langchain_anthropic import ChatAnthropic + from langchain_core.runnables.utils import ConfigurableField + from langchain_openai import ChatOpenAI + + model = ChatAnthropic( + model_name="claude-sonnet-4-5-20250929" + ).configurable_alternatives( + ConfigurableField(id="llm"), + default_key="anthropic", + openai=ChatOpenAI(), + ) + + # uses the default model ChatAnthropic + print(model.invoke("which organization created you?").content) + + # uses ChatOpenAI + print( + model.with_config(configurable={"llm": "openai"}) + .invoke("which organization created you?") + .content + ) + ``` + +**Parameters:** + + +The `ConfigurableField` instance that will be used to select the +alternative. + + + +The default key to use if no alternative is selected. + + + +Whether to prefix the keys with the `ConfigurableField` id. + + + +A dictionary of keys to `Runnable` instances or callables that +return `Runnable` instances. + + +**Returns:** `RunnableSerializable[Input, Output]` + +A new `Runnable` with the alternatives configured. + + + + + + + +```python +langchain_core.runnables.base.RunnableSerializable.configurable_fields( + kwargs: langchain_core.runnables.utils.AnyConfigurableField = {} +) -> langchain_core.runnables.base.RunnableSerializable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + +Configure particular `Runnable` fields at runtime. + +!!! example + + ```python + from langchain_core.runnables import ConfigurableField + from langchain_openai import ChatOpenAI + + model = ChatOpenAI(max_tokens=20).configurable_fields( + max_tokens=ConfigurableField( + id="output_token_number", + name="Max tokens in the output", + description="The maximum number of tokens in the output", + ) + ) + + # max_tokens = 20 + print( + "max_tokens_20: ", model.invoke("tell me something about chess").content + ) + + # max_tokens = 200 + print( + "max_tokens_200: ", + model.with_config(configurable={"output_token_number": 200}) + .invoke("tell me something about chess") + .content, + ) + ``` + +**Parameters:** + + +A dictionary of `ConfigurableField` instances to configure. + + +**Returns:** `RunnableSerializable[Input, Output]` + +A new `Runnable` with the fields configured. + +**Raises:** + +- `ValueError`: If a configuration key is not found in the `Runnable`. + + + + + + + +```python +langchain_core.runnables.base.RunnableSerializable.to_json() -> langchain_core.load.serializable.SerializedConstructor | langchain_core.load.serializable.SerializedNotImplemented +``` + + + + + + +Serialize the `Runnable` to JSON. + +**Returns:** `SerializedConstructor | SerializedNotImplemented` + +A JSON-serializable representation of the `Runnable`. + + + + + + + + + +```python +class langchain_core.runnables.base._RunnableCallableAsync() +``` + + + + + + +Protocol + +**Bases:** `Protocol[Input, Output]` + + + + + +```python +langchain_core.runnables.base._RunnableCallableAsync.__call__( + _in: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig +) -> collections.abc.Awaitable[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + + + +```python +class langchain_core.runnables.base._RunnableCallableAsyncIterator() +``` + + + + + + +Protocol + +**Bases:** `Protocol[Input, Output]` + + + + + +```python +langchain_core.runnables.base._RunnableCallableAsyncIterator.__call__( + _in: collections.abc.AsyncIterator[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + + + +```python +class langchain_core.runnables.base._RunnableCallableIterator() +``` + + + + + + +Protocol + +**Bases:** `Protocol[Input, Output]` + + + + + +```python +langchain_core.runnables.base._RunnableCallableIterator.__call__( + _in: collections.abc.Iterator[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig +) -> collections.abc.Iterator[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + + + +```python +class langchain_core.runnables.base._RunnableCallableSync() +``` + + + + + + +Protocol + +**Bases:** `Protocol[Input, Output]` + + + + + +```python +langchain_core.runnables.base._RunnableCallableSync.__call__( + _in: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig +) -> langchain_core.runnables.utils.Output +``` + + + + + + + + + + + + + + +```python +langchain_core.runnables.base._seq_input_schema( + steps: list[langchain_core.runnables.base.Runnable[typing.Any, typing.Any]], + config: langchain_core.runnables.config.RunnableConfig | None +) -> type[pydantic.BaseModel] +``` + + + + + + + + + + + + + +```python +langchain_core.runnables.base._seq_output_schema( + steps: list[langchain_core.runnables.base.Runnable[typing.Any, typing.Any]], + config: langchain_core.runnables.config.RunnableConfig | None +) -> type[pydantic.BaseModel] +``` + + + + + + + + + + + + + +```python +langchain_core.runnables.base.chain( + func: collections.abc.Callable[[Input], langchain_core.runnables.utils.Output] | collections.abc.Callable[[Input], collections.abc.Iterator[langchain_core.runnables.utils.Output]] | collections.abc.Callable[[Input], collections.abc.Coroutine[typing.Any, typing.Any, langchain_core.runnables.utils.Output]] | collections.abc.Callable[[Input], collections.abc.AsyncIterator[langchain_core.runnables.utils.Output]] +) -> langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + +Decorate a function to make it a `Runnable`. + +Sets the name of the `Runnable` to the name of the function. +Any runnables called by the function will be traced as dependencies. + +**Parameters:** + + +A `Callable`. + + +**Returns:** `Runnable[Input, Output]` + +A `Runnable`. + + + + + + + + +```python +langchain_core.runnables.base.coerce_to_runnable( + thing: langchain_core.runnables.base.RunnableLike +) -> langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + +Coerce a `Runnable`-like object into a `Runnable`. + +**Parameters:** + + +A `Runnable`-like object. + + +**Returns:** `Runnable[Input, Output]` + +A `Runnable`. + +**Raises:** + +- `TypeError`: If the object is not `Runnable`-like. + + + + + + + + +```python +langchain_core.runnables.base.Other = TypeVar('Other') +``` + + + + + + + + + +```python +langchain_core.runnables.base.RunnableLike = Runnable[Input, Output] | Callable[[Input], Output] | Callable[[Input], Awaitabl... +``` + + + + + + + + + +```python +langchain_core.runnables.base.RunnableMap = RunnableParallel +``` + + + + + + + + + +```python +langchain_core.runnables.base._RUNNABLE_GENERIC_NUM_ARGS = 2 +``` + + + + + + + + + +```python +langchain_core.runnables.base._RUNNABLE_SEQUENCE_MIN_STEPS = 2 +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/branch.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/branch.mdx new file mode 100644 index 0000000..8e953fb --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/branch.mdx @@ -0,0 +1,287 @@ +--- +layout: overview +slug: langchain-core/langchain_core/runnables/branch +title: langchain_core.runnables.branch +--- + +Runnable that selects which branch to run based on a condition. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`RunnableBranch`](#langchain_core-runnables-branch-RunnableBranch) | `Runnable` that selects which branch to run based on a condition. | + +### Data + +[`_MIN_BRANCHES`](#langchain_core-runnables-branch-_MIN_BRANCHES) + +### API + + + + + +```python +class langchain_core.runnables.branch.RunnableBranch( + branches: tuple[langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, bool] | collections.abc.Callable[[Input], bool] | collections.abc.Callable[[Input], collections.abc.Awaitable[bool]], langchain_core.runnables.base.RunnableLike] | langchain_core.runnables.base.RunnableLike = () +) +``` + + + + + + +**Bases:** [RunnableSerializable[Input, Output]](/langchain-core/langchain_core/runnables/base#langchain_core-runnables-base-RunnableSerializable) + +`Runnable` that selects which branch to run based on a condition. + +The `Runnable` is initialized with a list of `(condition, Runnable)` pairs and +a default branch. + +When operating on an input, the first condition that evaluates to True is +selected, and the corresponding `Runnable` is run on the input. + +If no condition evaluates to `True`, the default branch is run on the input. + +**Examples:** + + + +```python +from langchain_core.runnables import RunnableBranch + +branch = RunnableBranch( + (lambda x: isinstance(x, str), lambda x: x.upper()), + (lambda x: isinstance(x, int), lambda x: x + 1), + (lambda x: isinstance(x, float), lambda x: x * 2), + lambda x: "goodbye", +) + +branch.invoke("hello") # "HELLO" +branch.invoke(None) # "goodbye" +``` + + + + + +A list of `(condition, Runnable)` pairs. + + + + + + +A `Runnable` to run if no condition is met. + + + + + + + + +```python +langchain_core.runnables.branch.RunnableBranch.ainvoke( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.branch.RunnableBranch.astream( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.Output] +``` + + + + + + +async + +First evaluates the condition, then delegate to `True` or `False` branch. + +**Parameters:** + + +The input to the `Runnable`. + + + +The configuration for the `Runnable`. + + + +Additional keyword arguments to pass to the `Runnable`. + + + + + + + + +```python +langchain_core.runnables.branch.RunnableBranch.get_input_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.branch.RunnableBranch.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "schema", "runnable"]` + + + + + + + +```python +langchain_core.runnables.branch.RunnableBranch.invoke( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + +First evaluates the condition, then delegate to `True` or `False` branch. + +**Parameters:** + + +The input to the `Runnable`. + + + +The configuration for the `Runnable`. + + + +Additional keyword arguments to pass to the `Runnable`. + + +**Returns:** `Output` + +The output of the branch that was run. + + + + + + + +```python +langchain_core.runnables.branch.RunnableBranch.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `True` as this class is serializable. + + + + + + + +```python +langchain_core.runnables.branch.RunnableBranch.stream( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.Iterator[langchain_core.runnables.utils.Output] +``` + + + + + + +First evaluates the condition, then delegate to `True` or `False` branch. + +**Parameters:** + + +The input to the `Runnable`. + + + +The configuration for the `Runnable`. + + + +Additional keyword arguments to pass to the `Runnable`. + + + + + + + + + + +```python +langchain_core.runnables.branch._MIN_BRANCHES = 2 +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/config.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/config.mdx new file mode 100644 index 0000000..e0637c4 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/config.mdx @@ -0,0 +1,777 @@ +--- +layout: overview +slug: langchain-core/langchain_core/runnables/config +title: langchain_core.runnables.config +--- + +Configuration utilities for `Runnable` objects. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ContextThreadPoolExecutor`](#langchain_core-runnables-config-ContextThreadPoolExecutor) | ThreadPoolExecutor that copies the context to the child thread. | +| [`EmptyDict`](#langchain_core-runnables-config-EmptyDict) | Empty dict type. | +| [`RunnableConfig`](#langchain_core-runnables-config-RunnableConfig) | Configuration for a `Runnable`. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_set_config_context`](#langchain_core-runnables-config-_set_config_context) | Set the child Runnable config + tracing context. | +| [`acall_func_with_variable_args`](#langchain_core-runnables-config-acall_func_with_variable_args) | Async call function that may optionally accept a run_manager and/or config. | +| [`call_func_with_variable_args`](#langchain_core-runnables-config-call_func_with_variable_args) | Call function that may optionally accept a run_manager and/or config. | +| [`ensure_config`](#langchain_core-runnables-config-ensure_config) | Ensure that a config is a dict with all keys present. | +| [`get_async_callback_manager_for_config`](#langchain_core-runnables-config-get_async_callback_manager_for_config) | Get an async callback manager for a config. | +| [`get_callback_manager_for_config`](#langchain_core-runnables-config-get_callback_manager_for_config) | Get a callback manager for a config. | +| [`get_config_list`](#langchain_core-runnables-config-get_config_list) | Get a list of configs from a single config or a list of configs. | +| [`get_executor_for_config`](#langchain_core-runnables-config-get_executor_for_config) | Get an executor for a config. | +| [`merge_configs`](#langchain_core-runnables-config-merge_configs) | Merge multiple configs into one. | +| [`patch_config`](#langchain_core-runnables-config-patch_config) | Patch a config with new values. | +| [`run_in_executor`](#langchain_core-runnables-config-run_in_executor) | Run a function in an executor. | +| [`set_config_context`](#langchain_core-runnables-config-set_config_context) | Set the child Runnable config + tracing context. | + +### Data + +[`CONFIG_KEYS`](#langchain_core-runnables-config-CONFIG_KEYS) + +[`COPIABLE_KEYS`](#langchain_core-runnables-config-COPIABLE_KEYS) + +[`DEFAULT_RECURSION_LIMIT`](#langchain_core-runnables-config-DEFAULT_RECURSION_LIMIT) + +[`P`](#langchain_core-runnables-config-P) + +[`T`](#langchain_core-runnables-config-T) + +[`var_child_runnable_config`](#langchain_core-runnables-config-var_child_runnable_config) + +### API + + + + + +```python +class langchain_core.runnables.config.ContextThreadPoolExecutor() +``` + + + + + + +**Bases:** `ThreadPoolExecutor` + +ThreadPoolExecutor that copies the context to the child thread. + + + + + + +```python +langchain_core.runnables.config.ContextThreadPoolExecutor.map( + fn: collections.abc.Callable[..., langchain_core.runnables.config.T], + iterables: collections.abc.Iterable[typing.Any] = (), + kwargs: typing.Any = {} +) -> collections.abc.Iterator[langchain_core.runnables.config.T] +``` + + + + + + +Map a function to multiple iterables. + +**Parameters:** + + +The function to map. + + + +The iterables to map over. + + + +The timeout for the map. + + + +The chunksize for the map. + + +**Returns:** `Iterator[T]` + +The iterator for the mapped function. + + + + + + + +```python +langchain_core.runnables.config.ContextThreadPoolExecutor.submit( + func: collections.abc.Callable[langchain_core.runnables.config.P, langchain_core.runnables.config.T], + args: langchain_core.runnables.config.P.args = (), + kwargs: langchain_core.runnables.config.P.kwargs = {} +) -> concurrent.futures.Future[langchain_core.runnables.config.T] +``` + + + + + + +Submit a function to the executor. + +**Parameters:** + + +The function to submit. + + + +The positional arguments to the function. + + + +The keyword arguments to the function. + + +**Returns:** `Future[T]` + +The future for the function. + + + + + + + + + +```python +class langchain_core.runnables.config.EmptyDict +``` + + + + + + +**Bases:** `typing.TypedDict` + +Empty dict type. + + + + + + + +```python +class langchain_core.runnables.config.RunnableConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Configuration for a `Runnable`. + +!!! note Custom values + + The `TypedDict` has `total=False` set intentionally to: + + - Allow partial configs to be created and merged together via `merge_configs` + - Support config propagation from parent to child runnables via + `var_child_runnable_config` (a `ContextVar` that automatically passes + config down the call stack without explicit parameter passing), where + configs are merged rather than replaced + + !!! example + + ```python + # Parent sets tags + chain.invoke(input, config={"tags": ["parent"]}) + # Child automatically inherits and can add: + # ensure_config({"tags": ["child"]}) -> {"tags": ["parent", "child"]} + ``` + + +Callbacks for this call and any sub-calls (e.g. a Chain calling an LLM). + +Tags are passed to all callbacks, metadata is passed to handle*Start callbacks. + + + +Runtime values for attributes previously made configurable on this `Runnable`, +or sub-`Runnable` objects, through `configurable_fields` or +`configurable_alternatives`. + +Check `output_schema` for a description of the attributes that have been made +configurable. + + + +Maximum number of parallel calls to make. + +If not provided, defaults to `ThreadPoolExecutor`'s default. + + + +Metadata for this call and any sub-calls (e.g. a Chain calling an LLM). + +Keys should be strings, values should be JSON-serializable. + + + +Maximum number of times a call can recurse. + +If not provided, defaults to `25`. + + + +Unique identifier for the tracer run for this call. + +If not provided, a new UUID will be generated. + + + +Name for the tracer run for this call. + +Defaults to the name of the class. + + + +Tags for this call and any sub-calls (e.g. a Chain calling an LLM). + +You can use these to filter calls. + + + + + + + + +```python +langchain_core.runnables.config._set_config_context( + config: langchain_core.runnables.config.RunnableConfig +) -> tuple[contextvars.Token[langchain_core.runnables.config.RunnableConfig | None], dict[str, typing.Any] | None] +``` + + + + + + +Set the child Runnable config + tracing context. + +**Parameters:** + + +The config to set. + + +**Returns:** `tuple[Token[RunnableConfig | None], dict[str, Any] | None]` + +The token to reset the config and the previous tracing context. + + + + + + + + +```python +langchain_core.runnables.config.acall_func_with_variable_args( + func: collections.abc.Callable[[Input], collections.abc.Awaitable[langchain_core.runnables.utils.Output]] | collections.abc.Callable[[Input, RunnableConfig], collections.abc.Awaitable[langchain_core.runnables.utils.Output]] | collections.abc.Callable[[Input, AsyncCallbackManagerForChainRun], collections.abc.Awaitable[langchain_core.runnables.utils.Output]] | collections.abc.Callable[[Input, AsyncCallbackManagerForChainRun, RunnableConfig], collections.abc.Awaitable[langchain_core.runnables.utils.Output]], + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig, + run_manager: langchain_core.callbacks.manager.AsyncCallbackManagerForChainRun | None = None, + kwargs: typing.Any = {} +) -> collections.abc.Awaitable[langchain_core.runnables.utils.Output] +``` + + + + + + +Async call function that may optionally accept a run_manager and/or config. + +**Parameters:** + + +The function to call. + + + +The input to the function. + + + +The config to pass to the function. + + + +The run manager to pass to the function. + + + +The keyword arguments to pass to the function. + + +**Returns:** `Awaitable[Output]` + +The output of the function. + + + + + + + + +```python +langchain_core.runnables.config.call_func_with_variable_args( + func: collections.abc.Callable[[Input], langchain_core.runnables.utils.Output] | collections.abc.Callable[[Input, RunnableConfig], langchain_core.runnables.utils.Output] | collections.abc.Callable[[Input, CallbackManagerForChainRun], langchain_core.runnables.utils.Output] | collections.abc.Callable[[Input, CallbackManagerForChainRun, RunnableConfig], langchain_core.runnables.utils.Output], + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig, + run_manager: langchain_core.callbacks.manager.CallbackManagerForChainRun | None = None, + kwargs: typing.Any = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + +Call function that may optionally accept a run_manager and/or config. + +**Parameters:** + + +The function to call. + + + +The input to the function. + + + +The config to pass to the function. + + + +The run manager to pass to the function. + + + +The keyword arguments to pass to the function. + + +**Returns:** `Output` + +The output of the function. + + + + + + + + +```python +langchain_core.runnables.config.ensure_config( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> langchain_core.runnables.config.RunnableConfig +``` + + + + + + +Ensure that a config is a dict with all keys present. + +**Parameters:** + + +The config to ensure. + + +**Returns:** `RunnableConfig` + +The ensured config. + + + + + + + + +```python +langchain_core.runnables.config.get_async_callback_manager_for_config( + config: langchain_core.runnables.config.RunnableConfig +) -> langchain_core.callbacks.manager.AsyncCallbackManager +``` + + + + + + +Get an async callback manager for a config. + +**Parameters:** + + +The config. + + +**Returns:** `AsyncCallbackManager` + +The async callback manager. + + + + + + + + +```python +langchain_core.runnables.config.get_callback_manager_for_config( + config: langchain_core.runnables.config.RunnableConfig +) -> langchain_core.callbacks.manager.CallbackManager +``` + + + + + + +Get a callback manager for a config. + +**Parameters:** + + +The config. + + +**Returns:** `CallbackManager` + +The callback manager. + + + + + + + + +```python +langchain_core.runnables.config.get_config_list( + config: langchain_core.runnables.config.RunnableConfig | collections.abc.Sequence[langchain_core.runnables.config.RunnableConfig] | None, + length: int +) -> list[langchain_core.runnables.config.RunnableConfig] +``` + + + + + + +Get a list of configs from a single config or a list of configs. + + It is useful for subclasses overriding batch() or abatch(). + +**Parameters:** + + +The config or list of configs. + + + +The length of the list. + + +**Returns:** `list[RunnableConfig]` + +The list of configs. + +**Raises:** + +- `ValueError`: If the length of the list is not equal to the length of the inputs. + + + + + + + + +```python +langchain_core.runnables.config.get_executor_for_config( + config: langchain_core.runnables.config.RunnableConfig | None +) -> collections.abc.Generator[concurrent.futures.Executor, None, None] +``` + + + + + + +Get an executor for a config. + +**Parameters:** + + +The config. + + + + + + + + + +```python +langchain_core.runnables.config.merge_configs( + configs: langchain_core.runnables.config.RunnableConfig | None = () +) -> langchain_core.runnables.config.RunnableConfig +``` + + + + + + +Merge multiple configs into one. + +**Parameters:** + + +The configs to merge. + + +**Returns:** `RunnableConfig` + +The merged config. + + + + + + + + +```python +langchain_core.runnables.config.patch_config( + config: langchain_core.runnables.config.RunnableConfig | None, + callbacks: langchain_core.callbacks.base.BaseCallbackManager | None = None, + recursion_limit: int | None = None, + max_concurrency: int | None = None, + run_name: str | None = None, + configurable: dict[str, typing.Any] | None = None +) -> langchain_core.runnables.config.RunnableConfig +``` + + + + + + +Patch a config with new values. + +**Parameters:** + + +The config to patch. + + + +The callbacks to set. + + + +The recursion limit to set. + + + +The max concurrency to set. + + + +The run name to set. + + + +The configurable to set. + + +**Returns:** `RunnableConfig` + +The patched config. + + + + + + + + +```python +langchain_core.runnables.config.run_in_executor( + executor_or_config: concurrent.futures.Executor | langchain_core.runnables.config.RunnableConfig | None, + func: collections.abc.Callable[langchain_core.runnables.config.P, langchain_core.runnables.config.T], + args: langchain_core.runnables.config.P.args = (), + kwargs: langchain_core.runnables.config.P.kwargs = {} +) -> langchain_core.runnables.config.T +``` + + + + + + +async + +Run a function in an executor. + +**Parameters:** + + +The executor or config to run in. + + + +The function. + + + +The positional arguments to the function. + + + +The keyword arguments to the function. + + +**Returns:** `T` + +The output of the function. + + + + + + + + +```python +langchain_core.runnables.config.set_config_context( + config: langchain_core.runnables.config.RunnableConfig +) -> collections.abc.Generator[contextvars.Context, None, None] +``` + + + + + + +Set the child Runnable config + tracing context. + +**Parameters:** + + +The config to set. + + + + + + + + + +```python +langchain_core.runnables.config.CONFIG_KEYS = ['tags', 'metadata', 'callbacks', 'run_name', 'max_concurrency', 'recursion_limi... +``` + + + + + + + + + +```python +langchain_core.runnables.config.COPIABLE_KEYS = ['tags', 'metadata', 'callbacks', 'configurable'] +``` + + + + + + + + + +```python +langchain_core.runnables.config.DEFAULT_RECURSION_LIMIT = 25 +``` + + + + + + + + + +```python +langchain_core.runnables.config.P = ParamSpec('P') +``` + + + + + + + + + +```python +langchain_core.runnables.config.T = TypeVar('T') +``` + + + + + + + + + +```python +langchain_core.runnables.config.var_child_runnable_config: ContextVar[RunnableConfig | None] = ContextVar('child_runnable_config', default=None) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/configurable.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/configurable.mdx new file mode 100644 index 0000000..1ec1bec --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/configurable.mdx @@ -0,0 +1,796 @@ +--- +layout: overview +slug: langchain-core/langchain_core/runnables/configurable +title: langchain_core.runnables.configurable +--- + +`Runnable` objects that can be dynamically configured. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`DynamicRunnable`](#langchain_core-runnables-configurable-DynamicRunnable) | Serializable `Runnable` that can be dynamically configured. | +| [`RunnableConfigurableAlternatives`](#langchain_core-runnables-configurable-RunnableConfigurableAlternatives) | `Runnable` that can be dynamically configured. | +| [`RunnableConfigurableFields`](#langchain_core-runnables-configurable-RunnableConfigurableFields) | `Runnable` that can be dynamically configured. | +| [`StrEnum`](#langchain_core-runnables-configurable-StrEnum) | String enum. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_strremoveprefix`](#langchain_core-runnables-configurable-_strremoveprefix) | `str.removeprefix()` is only available in Python 3.9+. | +| [`make_options_spec`](#langchain_core-runnables-configurable-make_options_spec) | Make options spec. | +| [`prefix_config_spec`](#langchain_core-runnables-configurable-prefix_config_spec) | Prefix the id of a `ConfigurableFieldSpec`. | + +### Data + +[`_enums_for_spec`](#langchain_core-runnables-configurable-_enums_for_spec) + +[`_enums_for_spec_lock`](#langchain_core-runnables-configurable-_enums_for_spec_lock) + +### API + + + + + +```python +class langchain_core.runnables.configurable.DynamicRunnable() +``` + + + + + + +**Bases:** [RunnableSerializable[Input, Output]](/langchain-core/langchain_core/runnables/base#langchain_core-runnables-base-RunnableSerializable) + +Serializable `Runnable` that can be dynamically configured. + +A `DynamicRunnable` should be initiated using the `configurable_fields` or +`configurable_alternatives` method of a `Runnable`. + + + + + + + + + +The configuration to use. + + + +The default `Runnable` to use. + + + + + + + + +```python +langchain_core.runnables.configurable.DynamicRunnable.__getattr__( + name: str +) -> typing.Any +``` + + + + + + + + + + + + +```python +langchain_core.runnables.configurable.DynamicRunnable._prepare( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> tuple[langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output], langchain_core.runnables.config.RunnableConfig] +``` + + + + + + +abstract + + + + + + + +```python +langchain_core.runnables.configurable.DynamicRunnable.abatch( + inputs: list[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | list[langchain_core.runnables.config.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any | None = {} +) -> list[langchain_core.runnables.utils.Output] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.configurable.DynamicRunnable.ainvoke( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.configurable.DynamicRunnable.astream( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.Output] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.configurable.DynamicRunnable.atransform( + input: collections.abc.AsyncIterator[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.Output] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.configurable.DynamicRunnable.batch( + inputs: list[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | list[langchain_core.runnables.config.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any | None = {} +) -> list[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.configurable.DynamicRunnable.get_graph( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> langchain_core.runnables.graph.Graph +``` + + + + + + + + + + + + +```python +langchain_core.runnables.configurable.DynamicRunnable.get_input_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.configurable.DynamicRunnable.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "schema", "runnable"]` + + + + + + + +```python +langchain_core.runnables.configurable.DynamicRunnable.get_output_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.configurable.DynamicRunnable.invoke( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + + + + + + + +```python +langchain_core.runnables.configurable.DynamicRunnable.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `True` as this class is serializable. + + + + + + + +```python +langchain_core.runnables.configurable.DynamicRunnable.prepare( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> tuple[langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output], langchain_core.runnables.config.RunnableConfig] +``` + + + + + + +Prepare the `Runnable` for invocation. + +**Parameters:** + + +The configuration to use. + + +**Returns:** `tuple[Runnable[Input, Output], RunnableConfig]` + +The prepared `Runnable` and configuration. + + + + + + + +```python +langchain_core.runnables.configurable.DynamicRunnable.stream( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.Iterator[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.configurable.DynamicRunnable.transform( + input: collections.abc.Iterator[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.Iterator[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.configurable.DynamicRunnable.with_config( + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + + + +```python +class langchain_core.runnables.configurable.RunnableConfigurableAlternatives() +``` + + + + + + +**Bases:** [DynamicRunnable[Input, Output]](#langchain_core-runnables-configurable-DynamicRunnable) + +`Runnable` that can be dynamically configured. + +A `RunnableConfigurableAlternatives` should be initiated using the +`configurable_alternatives` method of a `Runnable` or can be +initiated directly as well. + +Here is an example of using a `RunnableConfigurableAlternatives` that uses +alternative prompts to illustrate its functionality: + + ```python + from langchain_core.runnables import ConfigurableField + from langchain_openai import ChatOpenAI + + # This creates a RunnableConfigurableAlternatives for Prompt Runnable + # with two alternatives. + prompt = PromptTemplate.from_template( + "Tell me a joke about {topic}" + ).configurable_alternatives( + ConfigurableField(id="prompt"), + default_key="joke", + poem=PromptTemplate.from_template("Write a short poem about {topic}"), + ) + + # When invoking the created RunnableSequence, you can pass in the + # value for your ConfigurableField's id which in this case will either be + # `joke` or `poem`. + chain = prompt | ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) + + # The `with_config` method brings in the desired Prompt Runnable in your + # Runnable Sequence. + chain.with_config(configurable={"prompt": "poem"}).invoke({"topic": "bears"}) + ``` + +Equivalently, you can initialize `RunnableConfigurableAlternatives` directly +and use in LCEL in the same way: + + ```python + from langchain_core.runnables import ConfigurableField + from langchain_core.runnables.configurable import ( + RunnableConfigurableAlternatives, + ) + from langchain_openai import ChatOpenAI + + prompt = RunnableConfigurableAlternatives( + which=ConfigurableField(id="prompt"), + default=PromptTemplate.from_template("Tell me a joke about {topic}"), + default_key="joke", + prefix_keys=False, + alternatives={ + "poem": PromptTemplate.from_template("Write a short poem about {topic}") + }, + ) + chain = prompt | ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) + chain.with_config(configurable={"prompt": "poem"}).invoke({"topic": "bears"}) + ``` + + + +The alternatives to choose from. + + + + + + +The enum value to use for the default option. + + + +Whether to prefix configurable fields of each alternative with a namespace +of the form <which.id>==<alternative_key>, e.g. a key named "temperature" used by +the alternative named "gpt3" becomes "model==gpt3/temperature". + + + +The `ConfigurableField` to use to choose between alternatives. + + + + + +```python +langchain_core.runnables.configurable.RunnableConfigurableAlternatives._prepare( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> tuple[langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output], langchain_core.runnables.config.RunnableConfig] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.configurable.RunnableConfigurableAlternatives.configurable_fields( + kwargs: langchain_core.runnables.utils.AnyConfigurableField = {} +) -> langchain_core.runnables.base.RunnableSerializable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + + + +```python +class langchain_core.runnables.configurable.RunnableConfigurableFields() +``` + + + + + + +**Bases:** [DynamicRunnable[Input, Output]](#langchain_core-runnables-configurable-DynamicRunnable) + +`Runnable` that can be dynamically configured. + +A `RunnableConfigurableFields` should be initiated using the +`configurable_fields` method of a `Runnable`. + +Here is an example of using a `RunnableConfigurableFields` with LLMs: + + ```python + from langchain_core.prompts import PromptTemplate + from langchain_core.runnables import ConfigurableField + from langchain_openai import ChatOpenAI + + model = ChatOpenAI(temperature=0).configurable_fields( + temperature=ConfigurableField( + id="temperature", + name="LLM Temperature", + description="The temperature of the LLM", + ) + ) + # This creates a RunnableConfigurableFields for a chat model. + + # When invoking the created RunnableSequence, you can pass in the + # value for your ConfigurableField's id which in this case + # will be change in temperature + + prompt = PromptTemplate.from_template("Pick a random number above {x}") + chain = prompt | model + + chain.invoke({"x": 0}) + chain.invoke({"x": 0}, config={"configurable": {"temperature": 0.9}}) + ``` + +Here is an example of using a `RunnableConfigurableFields` with `HubRunnables`: + + ```python + from langchain_core.prompts import PromptTemplate + from langchain_core.runnables import ConfigurableField + from langchain_openai import ChatOpenAI + from langchain.runnables.hub import HubRunnable + + prompt = HubRunnable("rlm/rag-prompt").configurable_fields( + owner_repo_commit=ConfigurableField( + id="hub_commit", + name="Hub Commit", + description="The Hub commit to pull from", + ) + ) + + prompt.invoke({"question": "foo", "context": "bar"}) + + # Invoking prompt with `with_config` method + + prompt.invoke( + {"question": "foo", "context": "bar"}, + config={"configurable": {"hub_commit": "rlm/rag-prompt-llama"}}, + ) + ``` + + + +Get the configuration specs for the `RunnableConfigurableFields`. + + + +The configurable fields to use. + + + + + +```python +langchain_core.runnables.configurable.RunnableConfigurableFields._prepare( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> tuple[langchain_core.runnables.base.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output], langchain_core.runnables.config.RunnableConfig] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.configurable.RunnableConfigurableFields.configurable_fields( + kwargs: langchain_core.runnables.utils.AnyConfigurableField = {} +) -> langchain_core.runnables.base.RunnableSerializable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + + + +```python +class langchain_core.runnables.configurable.StrEnum +``` + + + + + + +**Bases:** `enum.Enum` + +String enum. + + + + + + + +```python +langchain_core.runnables.configurable._strremoveprefix( + s: str, + prefix: str +) -> str +``` + + + + + + +`str.removeprefix()` is only available in Python 3.9+. + + + + + + + + +```python +langchain_core.runnables.configurable.make_options_spec( + spec: langchain_core.runnables.utils.ConfigurableFieldSingleOption | langchain_core.runnables.utils.ConfigurableFieldMultiOption, + description: str | None +) -> langchain_core.runnables.utils.ConfigurableFieldSpec +``` + + + + + + +Make options spec. + +Make a `ConfigurableFieldSpec` for a `ConfigurableFieldSingleOption` or +`ConfigurableFieldMultiOption`. + +**Parameters:** + + +The `ConfigurableFieldSingleOption` or `ConfigurableFieldMultiOption`. + + + +The description to use if the spec does not have one. + + +**Returns:** `ConfigurableFieldSpec` + +The `ConfigurableFieldSpec`. + + + + + + + + +```python +langchain_core.runnables.configurable.prefix_config_spec( + spec: langchain_core.runnables.utils.ConfigurableFieldSpec, + prefix: str +) -> langchain_core.runnables.utils.ConfigurableFieldSpec +``` + + + + + + +Prefix the id of a `ConfigurableFieldSpec`. + +This is useful when a `RunnableConfigurableAlternatives` is used as a +`ConfigurableField` of another `RunnableConfigurableAlternatives`. + +**Parameters:** + + +The `ConfigurableFieldSpec` to prefix. + + + +The prefix to add. + + +**Returns:** `ConfigurableFieldSpec` + +The prefixed `ConfigurableFieldSpec`. + + + + + + + + +```python +langchain_core.runnables.configurable._enums_for_spec: WeakValueDictionary[ConfigurableFieldSingleOption | ConfigurableFieldMultiOption | ConfigurableField, type[StrEnum]] = WeakValueDictionary() +``` + + + + + + + + + +```python +langchain_core.runnables.configurable._enums_for_spec_lock = threading.Lock() +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/fallbacks.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/fallbacks.mdx new file mode 100644 index 0000000..870e22c --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/fallbacks.mdx @@ -0,0 +1,352 @@ +--- +layout: overview +slug: langchain-core/langchain_core/runnables/fallbacks +title: langchain_core.runnables.fallbacks +--- + +`Runnable` that can fallback to other `Runnable` objects if it fails. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`RunnableWithFallbacks`](#langchain_core-runnables-fallbacks-RunnableWithFallbacks) | `Runnable` that can fallback to other `Runnable` objects if it fails. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_is_runnable_type`](#langchain_core-runnables-fallbacks-_is_runnable_type) | - | +| [`_returns_runnable`](#langchain_core-runnables-fallbacks-_returns_runnable) | - | + +### API + + + + + +```python +class langchain_core.runnables.fallbacks.RunnableWithFallbacks() +``` + + + + + + +**Bases:** [RunnableSerializable[Input, Output]](/langchain-core/langchain_core/runnables/base#langchain_core-runnables-base-RunnableSerializable) + +`Runnable` that can fallback to other `Runnable` objects if it fails. + +External APIs (e.g., APIs for a language model) may at times experience +degraded performance or even downtime. + +In these cases, it can be useful to have a fallback `Runnable` that can be +used in place of the original `Runnable` (e.g., fallback to another LLM provider). + +Fallbacks can be defined at the level of a single `Runnable`, or at the level +of a chain of `Runnable`s. Fallbacks are tried in order until one succeeds or +all fail. + +While you can instantiate a `RunnableWithFallbacks` directly, it is usually +more convenient to use the `with_fallbacks` method on a `Runnable`. + + + + + + + + + + + + +If `string` is specified then handled exceptions will be passed to fallbacks as +part of the input under the specified key. + +If `None`, exceptions will not be passed to fallbacks. + +If used, the base `Runnable` and its fallbacks must accept a dictionary as input. + + + +The exceptions on which fallbacks should be tried. + +Any exception that is not a subclass of these exceptions will be raised immediately. + + + +A sequence of fallbacks to try. + + + + + + +The `Runnable` to run first. + + + +Iterator over the `Runnable` and its fallbacks. + + + + + +```python +langchain_core.runnables.fallbacks.RunnableWithFallbacks.__getattr__( + name: str +) -> typing.Any +``` + + + + + + +Get an attribute from the wrapped `Runnable` and its fallbacks. + +**Returns:** `Any` + +If the attribute is anything other than a method that outputs a `Runnable`, + + + + + + + +```python +langchain_core.runnables.fallbacks.RunnableWithFallbacks.abatch( + inputs: list[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | list[langchain_core.runnables.config.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any | None = {} +) -> list[langchain_core.runnables.utils.Output] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.fallbacks.RunnableWithFallbacks.ainvoke( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.fallbacks.RunnableWithFallbacks.astream( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.Output] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.fallbacks.RunnableWithFallbacks.batch( + inputs: list[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | list[langchain_core.runnables.config.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any | None = {} +) -> list[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.fallbacks.RunnableWithFallbacks.get_input_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.fallbacks.RunnableWithFallbacks.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "schema", "runnable"]` + + + + + + + +```python +langchain_core.runnables.fallbacks.RunnableWithFallbacks.get_output_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.fallbacks.RunnableWithFallbacks.invoke( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + + + + + + + +```python +langchain_core.runnables.fallbacks.RunnableWithFallbacks.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `True` as this class is serializable. + + + + + + + +```python +langchain_core.runnables.fallbacks.RunnableWithFallbacks.stream( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.Iterator[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + + + +```python +langchain_core.runnables.fallbacks._is_runnable_type( + type_: typing.Any +) -> bool +``` + + + + + + + + + + + + + +```python +langchain_core.runnables.fallbacks._returns_runnable( + attr: typing.Any +) -> bool +``` + + + + + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/graph.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/graph.mdx new file mode 100644 index 0000000..36770d0 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/graph.mdx @@ -0,0 +1,1150 @@ +--- +layout: overview +slug: langchain-core/langchain_core/runnables/graph +title: langchain_core.runnables.graph +--- + +Graph used in `Runnable` objects. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`Branch`](#langchain_core-runnables-graph-Branch) | Branch in a graph. | +| [`CurveStyle`](#langchain_core-runnables-graph-CurveStyle) | Enum for different curve styles supported by Mermaid. | +| [`Edge`](#langchain_core-runnables-graph-Edge) | Edge in a graph. | +| [`Graph`](#langchain_core-runnables-graph-Graph) | Graph of nodes and edges. | +| [`LabelsDict`](#langchain_core-runnables-graph-LabelsDict) | Dictionary of labels for nodes and edges in a graph. | +| [`MermaidDrawMethod`](#langchain_core-runnables-graph-MermaidDrawMethod) | Enum for different draw methods supported by Mermaid. | +| [`Node`](#langchain_core-runnables-graph-Node) | Node in a graph. | +| [`NodeStyles`](#langchain_core-runnables-graph-NodeStyles) | Schema for Hexadecimal color codes for different node types. | +| [`Stringifiable`](#langchain_core-runnables-graph-Stringifiable) | Protocol for objects that can be converted to a string. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_first_node`](#langchain_core-runnables-graph-_first_node) | Find the single node that is not a target of any edge. | +| [`_last_node`](#langchain_core-runnables-graph-_last_node) | Find the single node that is not a source of any edge. | +| [`is_uuid`](#langchain_core-runnables-graph-is_uuid) | Check if a string is a valid UUID. | +| [`node_data_json`](#langchain_core-runnables-graph-node_data_json) | Convert the data of a node to a JSON-serializable format. | +| [`node_data_str`](#langchain_core-runnables-graph-node_data_str) | Convert the data of a node to a string. | + +### API + + + + + +```python +class langchain_core.runnables.graph.Branch() +``` + + + + + + +**Bases:** `NamedTuple` + +Branch in a graph. + + + +A callable that returns a string representation of the condition. + + + +Optional dictionary of end node IDs for the branches. + + + + + + + +```python +class langchain_core.runnables.graph.CurveStyle +``` + + + + + + +**Bases:** `enum.Enum` + +Enum for different curve styles supported by Mermaid. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class langchain_core.runnables.graph.Edge() +``` + + + + + + +**Bases:** `NamedTuple` + +Edge in a graph. + + + +Whether the edge is conditional. + + + +Optional data associated with the edge. + + + +The source node id. + + + +The target node id. + + + + + +```python +langchain_core.runnables.graph.Edge.copy( + source: str | None = None, + target: str | None = None +) -> langchain_core.runnables.graph.Edge +``` + + + + + + +Return a copy of the edge with optional new source and target nodes. + +**Parameters:** + + +The new source node id. + + + +The new target node id. + + +**Returns:** `Edge` + +A copy of the edge with the new source and target nodes. + + + + + + + + + +```python +class langchain_core.runnables.graph.Graph( + nodes: dict[str, langchain_core.runnables.graph.Node] = dict(), + edges: list[langchain_core.runnables.graph.Edge] = list() +) +``` + + + + + + +Dataclass + +Graph of nodes and edges. + +**Parameters:** + + +Dictionary of nodes in the graph. Defaults to an empty dictionary. + + + +List of edges in the graph. Defaults to an empty list. + + + + + + + + + + + + +```python +langchain_core.runnables.graph.Graph.__bool__() -> bool +``` + + + + + + +Return whether the graph has any nodes. + + + + + + + +```python +langchain_core.runnables.graph.Graph.add_edge( + source: langchain_core.runnables.graph.Node, + target: langchain_core.runnables.graph.Node, + data: langchain_core.runnables.graph.Stringifiable | None = None, + conditional: bool = False +) -> langchain_core.runnables.graph.Edge +``` + + + + + + +Add an edge to the graph and return it. + +**Parameters:** + + +The source node of the edge. + + + +The target node of the edge. + + + +Optional data associated with the edge. + + + +Whether the edge is conditional. + + +**Returns:** `Edge` + +The edge that was added to the graph. + +**Raises:** + +- `ValueError`: If the source or target node is not in the graph. + + + + + + + +```python +langchain_core.runnables.graph.Graph.add_node( + data: type[pydantic.BaseModel] | langchain_core.runnables.base.Runnable | None, + id: str | None = None, + metadata: dict[str, typing.Any] | None = None +) -> langchain_core.runnables.graph.Node +``` + + + + + + +Add a node to the graph and return it. + +**Parameters:** + + +The data of the node. + + + +The id of the node. + + + +Optional metadata for the node. + + +**Returns:** `Node` + +The node that was added to the graph. + +**Raises:** + +- `ValueError`: If a node with the same id already exists. + + + + + + + +```python +langchain_core.runnables.graph.Graph.draw_ascii() -> str +``` + + + + + + +Draw the graph as an ASCII art string. + +**Returns:** `str` + +The ASCII art string. + + + + + + + +```python +langchain_core.runnables.graph.Graph.draw_mermaid( + with_styles: bool = True, + curve_style: langchain_core.runnables.graph.CurveStyle = CurveStyle.LINEAR, + node_colors: langchain_core.runnables.graph.NodeStyles | None = None, + wrap_label_n_words: int = 9, + frontmatter_config: dict[str, typing.Any] | None = None +) -> str +``` + + + + + + +Draw the graph as a Mermaid syntax string. + +Returns: + The Mermaid syntax string. + +**Parameters:** + + +Whether to include styles in the syntax. + + + +The style of the edges. + + + +The colors of the nodes. + + + +The number of words to wrap the node labels at. + + + +Mermaid frontmatter config. +Can be used to customize theme and styles. Will be converted to YAML and +added to the beginning of the mermaid graph. + +See more here: https://mermaid.js.org/config/configuration.html. + +Example config: + +\`\`\`python +{ + "config": { + "theme": "neutral", + "look": "handDrawn", + "themeVariables": {"primaryColor": "#e2e2e2"}, + } +} +\`\`\` + + + + + + + + +```python +langchain_core.runnables.graph.Graph.draw_mermaid_png( + curve_style: langchain_core.runnables.graph.CurveStyle = CurveStyle.LINEAR, + node_colors: langchain_core.runnables.graph.NodeStyles | None = None, + wrap_label_n_words: int = 9, + output_file_path: str | None = None, + draw_method: langchain_core.runnables.graph.MermaidDrawMethod = MermaidDrawMethod.API, + background_color: str = 'white', + padding: int = 10, + max_retries: int = 1, + retry_delay: float = 1.0, + frontmatter_config: dict[str, typing.Any] | None = None, + base_url: str | None = None, + proxies: dict[str, str] | None = None +) -> bytes +``` + + + + + + +Draw the graph as a PNG image using Mermaid. + +**Parameters:** + + +The style of the edges. + + + +The colors of the nodes. + + + +The number of words to wrap the node labels at. + + + +The path to save the image to. If `None`, the image +is not saved. + + + +The method to use to draw the graph. + + + +The color of the background. + + + +The padding around the graph. + + + +The maximum number of retries (`MermaidDrawMethod.API`). + + + +The delay between retries (`MermaidDrawMethod.API`). + + + +Mermaid frontmatter config. +Can be used to customize theme and styles. Will be converted to YAML and +added to the beginning of the mermaid graph. + +See more here: https://mermaid.js.org/config/configuration.html. + +Example config: + +\`\`\`python +{ + "config": { + "theme": "neutral", + "look": "handDrawn", + "themeVariables": {"primaryColor": "#e2e2e2"}, + } +} +\`\`\` + + + +The base URL of the Mermaid server for rendering via API. + + + +HTTP/HTTPS proxies for requests (e.g. `{"http": "http://127.0.0.1:7890"}`). + + +**Returns:** `bytes` + +The PNG image as bytes. + + + + + + + +```python +langchain_core.runnables.graph.Graph.draw_png( + output_file_path: str | None = None, + fontname: str | None = None, + labels: langchain_core.runnables.graph.LabelsDict | None = None +) -> bytes | None +``` + + + + + + +Draw the graph as a PNG image. + +**Parameters:** + + +The path to save the image to. If `None`, the image +is not saved. + + + +The name of the font to use. + + + +Optional labels for nodes and edges in the graph. Defaults to +`None`. + + +**Returns:** `bytes | None` + +The PNG image as bytes if output_file_path is None, None otherwise. + + + + + + + +```python +langchain_core.runnables.graph.Graph.extend( + graph: langchain_core.runnables.graph.Graph, + prefix: str = '' +) -> tuple[langchain_core.runnables.graph.Node | None, langchain_core.runnables.graph.Node | None] +``` + + + + + + +Add all nodes and edges from another graph. + +Note this doesn't check for duplicates, nor does it connect the graphs. + +**Parameters:** + + +The graph to add. + + + +The prefix to add to the node ids. + + +**Returns:** `tuple[Node | None, Node | None]` + +A tuple of the first and last nodes of the subgraph. + + + + + + + +```python +langchain_core.runnables.graph.Graph.first_node() -> langchain_core.runnables.graph.Node | None +``` + + + + + + +Find the single node that is not a target of any edge. + +If there is no such node, or there are multiple, return `None`. +When drawing the graph, this node would be the origin. + +**Returns:** `Node | None` + +The first node, or None if there is no such node or multiple + + + + + + + +```python +langchain_core.runnables.graph.Graph.last_node() -> langchain_core.runnables.graph.Node | None +``` + + + + + + +Find the single node that is not a source of any edge. + +If there is no such node, or there are multiple, return `None`. +When drawing the graph, this node would be the destination. + +**Returns:** `Node | None` + +The last node, or None if there is no such node or multiple + + + + + + + +```python +langchain_core.runnables.graph.Graph.next_id() -> str +``` + + + + + + +Return a new unique node identifier. + +It that can be used to add a node to the graph. + + + + + + + +```python +langchain_core.runnables.graph.Graph.print_ascii() -> None +``` + + + + + + +Print the graph as an ASCII art string. + + + + + + + +```python +langchain_core.runnables.graph.Graph.reid() -> langchain_core.runnables.graph.Graph +``` + + + + + + +Return a new graph with all nodes re-identified. + +Uses their unique, readable names where possible. + + + + + + + +```python +langchain_core.runnables.graph.Graph.remove_node( + node: langchain_core.runnables.graph.Node +) -> None +``` + + + + + + +Remove a node from the graph and all edges connected to it. + +**Parameters:** + + +The node to remove. + + + + + + + + +```python +langchain_core.runnables.graph.Graph.to_json( + with_schemas: bool = False +) -> dict[str, list[dict[str, typing.Any]]] +``` + + + + + + +Convert the graph to a JSON-serializable format. + +**Parameters:** + + +Whether to include the schemas of the nodes if they are +Pydantic models. + + +**Returns:** `dict[str, list[dict[str, Any]]]` + +A dictionary with the nodes and edges of the graph. + + + + + + + +```python +langchain_core.runnables.graph.Graph.trim_first_node() -> None +``` + + + + + + +Remove the first node if it exists and has a single outgoing edge. + +i.e., if removing it would not leave the graph without a "first" node. + + + + + + + +```python +langchain_core.runnables.graph.Graph.trim_last_node() -> None +``` + + + + + + +Remove the last node if it exists and has a single incoming edge. + +i.e., if removing it would not leave the graph without a "last" node. + + + + + + + + + +```python +class langchain_core.runnables.graph.LabelsDict +``` + + + + + + +**Bases:** `typing.TypedDict` + +Dictionary of labels for nodes and edges in a graph. + + +Labels for edges. + + + +Labels for nodes. + + + + + + + + +```python +class langchain_core.runnables.graph.MermaidDrawMethod +``` + + + + + + +**Bases:** `enum.Enum` + +Enum for different draw methods supported by Mermaid. + + + + + + + + + + + + + +```python +class langchain_core.runnables.graph.Node() +``` + + + + + + +**Bases:** `NamedTuple` + +Node in a graph. + + + +The data of the node. + + + +The unique identifier of the node. + + + +Optional metadata for the node. + + + +The name of the node. + + + + + +```python +langchain_core.runnables.graph.Node.copy( + id: str | None = None, + name: str | None = None +) -> langchain_core.runnables.graph.Node +``` + + + + + + +Return a copy of the node with optional new id and name. + +**Parameters:** + + +The new node id. + + + +The new node name. + + +**Returns:** `Node` + +A copy of the node with the new id and name. + + + + + + + + + +```python +class langchain_core.runnables.graph.NodeStyles( + default: str = 'fill:#f2f0ff,line-height:1.2', + first: str = 'fill-opacity:0', + last: str = 'fill:#bfb6fc' +) +``` + + + + + + +Dataclass + +Schema for Hexadecimal color codes for different node types. + +**Parameters:** + + +The default color code. + + + +The color code for the first node. + + + +The color code for the last node. + + + + + + + + + + + + + + + + + +```python +class langchain_core.runnables.graph.Stringifiable() +``` + + + + + + +Protocol + +Protocol for objects that can be converted to a string. + + + + + + +```python +langchain_core.runnables.graph.Stringifiable.__str__() -> str +``` + + + + + + +Convert the object to a string. + + + + + + + + + +```python +langchain_core.runnables.graph._first_node( + graph: langchain_core.runnables.graph.Graph, + exclude: collections.abc.Sequence[str] = () +) -> langchain_core.runnables.graph.Node | None +``` + + + + + + +Find the single node that is not a target of any edge. + +Exclude nodes/sources with IDs in the exclude list. + +If there is no such node, or there are multiple, return `None`. + +When drawing the graph, this node would be the origin. + + + + + + + + +```python +langchain_core.runnables.graph._last_node( + graph: langchain_core.runnables.graph.Graph, + exclude: collections.abc.Sequence[str] = () +) -> langchain_core.runnables.graph.Node | None +``` + + + + + + +Find the single node that is not a source of any edge. + +Exclude nodes/targets with IDs in the exclude list. + +If there is no such node, or there are multiple, return `None`. + +When drawing the graph, this node would be the destination. + + + + + + + + +```python +langchain_core.runnables.graph.is_uuid( + value: str +) -> bool +``` + + + + + + +Check if a string is a valid UUID. + +**Parameters:** + + +The string to check. + + +**Returns:** `bool` + +`True` if the string is a valid UUID, `False` otherwise. + + + + + + + + +```python +langchain_core.runnables.graph.node_data_json( + node: langchain_core.runnables.graph.Node, + with_schemas: bool = False +) -> dict[str, str | dict[str, typing.Any]] +``` + + + + + + +Convert the data of a node to a JSON-serializable format. + +**Parameters:** + + +The `Node` to convert. + + + +Whether to include the schema of the data if it is a Pydantic +model. + + +**Returns:** `dict[str, str | dict[str, Any]]` + +A dictionary with the type of the data and the data itself. + + + + + + + + +```python +langchain_core.runnables.graph.node_data_str( + id: str, + data: type[pydantic.BaseModel] | langchain_core.runnables.base.Runnable | None +) -> str +``` + + + + + + +Convert the data of a node to a string. + +**Parameters:** + + +The node id. + + + +The node data. + + +**Returns:** `str` + +A string representation of the data. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/graph_ascii.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/graph_ascii.mdx new file mode 100644 index 0000000..40dce98 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/graph_ascii.mdx @@ -0,0 +1,391 @@ +--- +layout: overview +slug: langchain-core/langchain_core/runnables/graph_ascii +title: langchain_core.runnables.graph_ascii +--- + +Draws DAG in ASCII. + +Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AsciiCanvas`](#langchain_core-runnables-graph_ascii-AsciiCanvas) | Class for drawing in ASCII. | +| [`VertexViewer`](#langchain_core-runnables-graph_ascii-VertexViewer) | VertexViewer class. | +| [`_EdgeViewer`](#langchain_core-runnables-graph_ascii-_EdgeViewer) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_build_sugiyama_layout`](#langchain_core-runnables-graph_ascii-_build_sugiyama_layout) | - | +| [`draw_ascii`](#langchain_core-runnables-graph_ascii-draw_ascii) | Build a DAG and draw it in ASCII. | + +### Data + +[`_HAS_GRANDALF`](#langchain_core-runnables-graph_ascii-_HAS_GRANDALF) + +### API + + + + + +```python +class langchain_core.runnables.graph_ascii.AsciiCanvas( + cols: int, + lines: int +) +``` + + + + + + +Class for drawing in ASCII. + + + + + + + + + + + +```python +langchain_core.runnables.graph_ascii.AsciiCanvas.box( + x0: int, + y0: int, + width: int, + height: int +) -> None +``` + + + + + + +Create a box on ASCII canvas. + +**Parameters:** + + +x coordinate of the box corner. + + + +y coordinate of the box corner. + + + +box width. + + + +box height. + + +**Raises:** + +- `ValueError`: if box dimensions are invalid. + + + + + + + +```python +langchain_core.runnables.graph_ascii.AsciiCanvas.draw() -> str +``` + + + + + + +Draws ASCII canvas on the screen. + +**Returns:** `str` + +The ASCII canvas string. + + + + + + + +```python +langchain_core.runnables.graph_ascii.AsciiCanvas.line( + x0: int, + y0: int, + x1: int, + y1: int, + char: str +) -> None +``` + + + + + + +Create a line on ASCII canvas. + +**Parameters:** + + +x coordinate where the line should start. + + + +y coordinate where the line should start. + + + +x coordinate where the line should end. + + + +y coordinate where the line should end. + + + +character to draw the line with. + + + + + + + + +```python +langchain_core.runnables.graph_ascii.AsciiCanvas.point( + x: int, + y: int, + char: str +) -> None +``` + + + + + + +Create a point on ASCII canvas. + +**Parameters:** + + +x coordinate. Should be `>= 0` and `<` number of columns in +the canvas. + + + +y coordinate. Should be `>= 0` an `<` number of lines in the +canvas. + + + +character to place in the specified point on the +canvas. + + +**Raises:** + +- `ValueError`: if char is not a single character or if +coordinates are out of bounds. + + + + + + + +```python +langchain_core.runnables.graph_ascii.AsciiCanvas.text( + x: int, + y: int, + text: str +) -> None +``` + + + + + + +Print a text on ASCII canvas. + +**Parameters:** + + +x coordinate where the text should start. + + + +y coordinate where the text should start. + + + +string that should be printed. + + + + + + + + + + +```python +class langchain_core.runnables.graph_ascii.VertexViewer( + name: str +) +``` + + + + + + +VertexViewer class. + +Class to define vertex box boundaries that will be accounted for during +graph building by grandalf. + + + +Height of the box. + + + + + + + + + +Height of the box. + + + +Width of the box. + + + + + + + +```python +class langchain_core.runnables.graph_ascii._EdgeViewer() +``` + + + + + + + + + + + + +```python +langchain_core.runnables.graph_ascii._EdgeViewer.setpath( + pts: list[tuple[float]] +) -> None +``` + + + + + + + + + + + + + + +```python +langchain_core.runnables.graph_ascii._build_sugiyama_layout( + vertices: collections.abc.Mapping[str, str], + edges: collections.abc.Sequence[langchain_core.runnables.graph.Edge] +) -> typing.Any +``` + + + + + + + + + + + + + +```python +langchain_core.runnables.graph_ascii.draw_ascii( + vertices: collections.abc.Mapping[str, str], + edges: collections.abc.Sequence[langchain_core.runnables.graph.Edge] +) -> str +``` + + + + + + +Build a DAG and draw it in ASCII. + +**Parameters:** + + +list of graph vertices. + + + +list of graph edges. + + +**Returns:** `str` + +ASCII representation + +**Raises:** + +- `ValueError`: if the canvas dimensions are invalid or if +edge coordinates are invalid. + + + + + + + + +```python +langchain_core.runnables.graph_ascii._HAS_GRANDALF = True +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/graph_mermaid.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/graph_mermaid.mdx new file mode 100644 index 0000000..aaa2d57 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/graph_mermaid.mdx @@ -0,0 +1,325 @@ +--- +layout: overview +slug: langchain-core/langchain_core/runnables/graph_mermaid +title: langchain_core.runnables.graph_mermaid +--- + +Mermaid graph drawing utilities. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_generate_mermaid_graph_styles`](#langchain_core-runnables-graph_mermaid-_generate_mermaid_graph_styles) | Generates Mermaid graph styles for different node types. | +| [`_render_mermaid_using_api`](#langchain_core-runnables-graph_mermaid-_render_mermaid_using_api) | Renders Mermaid graph using the Mermaid.INK API. | +| [`_render_mermaid_using_pyppeteer`](#langchain_core-runnables-graph_mermaid-_render_mermaid_using_pyppeteer) | Renders Mermaid graph using Pyppeteer. | +| [`_to_safe_id`](#langchain_core-runnables-graph_mermaid-_to_safe_id) | Convert a string into a Mermaid-compatible node id. | +| [`draw_mermaid`](#langchain_core-runnables-graph_mermaid-draw_mermaid) | Draws a Mermaid graph using the provided graph data. | +| [`draw_mermaid_png`](#langchain_core-runnables-graph_mermaid-draw_mermaid_png) | Draws a Mermaid graph as PNG using provided syntax. | + +### Data + +[`MARKDOWN_SPECIAL_CHARS`](#langchain_core-runnables-graph_mermaid-MARKDOWN_SPECIAL_CHARS) + +[`_HAS_PYPPETEER`](#langchain_core-runnables-graph_mermaid-_HAS_PYPPETEER) + +[`_HAS_REQUESTS`](#langchain_core-runnables-graph_mermaid-_HAS_REQUESTS) + +### API + + + + + +```python +langchain_core.runnables.graph_mermaid._generate_mermaid_graph_styles( + node_colors: langchain_core.runnables.graph.NodeStyles +) -> str +``` + + + + + + +Generates Mermaid graph styles for different node types. + + + + + + + + +```python +langchain_core.runnables.graph_mermaid._render_mermaid_using_api( + mermaid_syntax: str, + output_file_path: str | None = None, + background_color: str | None = 'white', + file_type: typing.Literal['jpeg', 'png', 'webp'] | None = 'png', + max_retries: int = 1, + retry_delay: float = 1.0, + proxies: dict[str, str] | None = None, + base_url: str | None = None +) -> bytes +``` + + + + + + +Renders Mermaid graph using the Mermaid.INK API. + + + + + + + + +```python +langchain_core.runnables.graph_mermaid._render_mermaid_using_pyppeteer( + mermaid_syntax: str, + output_file_path: str | None = None, + background_color: str | None = 'white', + padding: int = 10, + device_scale_factor: int = 3 +) -> bytes +``` + + + + + + +async + +Renders Mermaid graph using Pyppeteer. + + + + + + + + +```python +langchain_core.runnables.graph_mermaid._to_safe_id( + label: str +) -> str +``` + + + + + + +Convert a string into a Mermaid-compatible node id. + +Keep [a-zA-Z0-9_-] characters unchanged. +Map every other character -> backslash + lowercase hex codepoint. + +Result is guaranteed to be unique and Mermaid-compatible, +so nodes with special characters always render correctly. + + + + + + + + +```python +langchain_core.runnables.graph_mermaid.draw_mermaid( + nodes: dict[str, langchain_core.runnables.graph.Node], + edges: list[langchain_core.runnables.graph.Edge], + first_node: str | None = None, + last_node: str | None = None, + with_styles: bool = True, + curve_style: langchain_core.runnables.graph.CurveStyle = CurveStyle.LINEAR, + node_styles: langchain_core.runnables.graph.NodeStyles | None = None, + wrap_label_n_words: int = 9, + frontmatter_config: dict[str, typing.Any] | None = None +) -> str +``` + + + + + + +Draws a Mermaid graph using the provided graph data. + +**Parameters:** + + +List of node ids. + + + +List of edges, object with a source, target and data. + + + +Id of the first node. + + + +Id of the last node. + + + +Whether to include styles in the graph. + + + +Curve style for the edges. + + + +Node colors for different types. + + + +Words to wrap the edge labels. + + + +Mermaid frontmatter config. +Can be used to customize theme and styles. Will be converted to YAML and +added to the beginning of the mermaid graph. + +See more here: https://mermaid.js.org/config/configuration.html. + +Example config: + +\`\`\`python +{ + "config": { + "theme": "neutral", + "look": "handDrawn", + "themeVariables": {"primaryColor": "#e2e2e2"}, + } +} +\`\`\` + + +**Returns:** `str` + +Mermaid graph syntax. + + + + + + + + +```python +langchain_core.runnables.graph_mermaid.draw_mermaid_png( + mermaid_syntax: str, + output_file_path: str | None = None, + draw_method: langchain_core.runnables.graph.MermaidDrawMethod = MermaidDrawMethod.API, + background_color: str | None = 'white', + padding: int = 10, + max_retries: int = 1, + retry_delay: float = 1.0, + base_url: str | None = None, + proxies: dict[str, str] | None = None +) -> bytes +``` + + + + + + +Draws a Mermaid graph as PNG using provided syntax. + +**Parameters:** + + +Mermaid graph syntax. + + + +Path to save the PNG image. + + + +Method to draw the graph. + + + +Background color of the image. + + + +Padding around the image. + + + +Maximum number of retries (MermaidDrawMethod.API). + + + +Delay between retries (MermaidDrawMethod.API). + + + +Base URL for the Mermaid.ink API. + + + +HTTP/HTTPS proxies for requests (e.g. `{"http": "http://127.0.0.1:7890"}`). + + +**Returns:** `bytes` + +PNG image bytes. + +**Raises:** + +- `ValueError`: If an invalid draw method is provided. + + + + + + + + +```python +langchain_core.runnables.graph_mermaid.MARKDOWN_SPECIAL_CHARS = '*_`' +``` + + + + + + + + + +```python +langchain_core.runnables.graph_mermaid._HAS_PYPPETEER = True +``` + + + + + + + + + +```python +langchain_core.runnables.graph_mermaid._HAS_REQUESTS = True +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/graph_png.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/graph_png.mdx new file mode 100644 index 0000000..2cc2878 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/graph_png.mdx @@ -0,0 +1,361 @@ +--- +layout: overview +slug: langchain-core/langchain_core/runnables/graph_png +title: langchain_core.runnables.graph_png +--- + +Helper class to draw a state graph into a PNG file. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`PngDrawer`](#langchain_core-runnables-graph_png-PngDrawer) | Helper class to draw a state graph into a PNG file. | + +### Data + +[`_HAS_PYGRAPHVIZ`](#langchain_core-runnables-graph_png-_HAS_PYGRAPHVIZ) + +### API + + + + + +```python +class langchain_core.runnables.graph_png.PngDrawer( + fontname: str | None = None, + labels: langchain_core.runnables.graph.LabelsDict | None = None +) +``` + + + + + + +Helper class to draw a state graph into a PNG file. + +It requires `graphviz` and `pygraphviz` to be installed. + + + + + + + + + + + +```python +langchain_core.runnables.graph_png.PngDrawer.add_edge( + viz: typing.Any, + source: str, + target: str, + label: str | None = None, + conditional: bool = False +) -> None +``` + + + + + + +Adds an edge to the graph. + +**Parameters:** + + +The graphviz object. + + + +The source node. + + + +The target node. + + + +The label for the edge. + + + +Whether the edge is conditional. + + + + + + + + +```python +langchain_core.runnables.graph_png.PngDrawer.add_edges( + viz: typing.Any, + graph: langchain_core.runnables.graph.Graph +) -> None +``` + + + + + + +Add edges to the graph. + +**Parameters:** + + +The graphviz object. + + + +The graph to draw. + + + + + + + + +```python +langchain_core.runnables.graph_png.PngDrawer.add_node( + viz: typing.Any, + node: str +) -> None +``` + + + + + + +Adds a node to the graph. + +**Parameters:** + + +The graphviz object. + + + +The node to add. + + + + + + + + +```python +langchain_core.runnables.graph_png.PngDrawer.add_nodes( + viz: typing.Any, + graph: langchain_core.runnables.graph.Graph +) -> None +``` + + + + + + +Add nodes to the graph. + +**Parameters:** + + +The graphviz object. + + + +The graph to draw. + + + + + + + + +```python +langchain_core.runnables.graph_png.PngDrawer.add_subgraph( + viz: typing.Any, + nodes: list[list[str]], + parent_prefix: list[str] | None = None +) -> None +``` + + + + + + +Add subgraphs to the graph. + +**Parameters:** + + +The graphviz object. + + + +The nodes to add. + + + +The prefix of the parent subgraph. + + + + + + + + +```python +langchain_core.runnables.graph_png.PngDrawer.draw( + graph: langchain_core.runnables.graph.Graph, + output_path: str | None = None +) -> bytes | None +``` + + + + + + +Draw the given state graph into a PNG file. + +Requires `graphviz` and `pygraphviz` to be installed. + +**Parameters:** + + +The graph to draw + + + +The path to save the PNG. If `None`, PNG bytes are returned. + + +**Returns:** `bytes | None` + +The PNG bytes if `output_path` is None, else None. + +**Raises:** + +- `ImportError`: If `pygraphviz` is not installed. + + + + + + + +```python +langchain_core.runnables.graph_png.PngDrawer.get_edge_label( + label: str +) -> str +``` + + + + + + +Returns the label to use for an edge. + +**Parameters:** + + +The original label. + + +**Returns:** `str` + +The new label. + + + + + + + +```python +langchain_core.runnables.graph_png.PngDrawer.get_node_label( + label: str +) -> str +``` + + + + + + +Returns the label to use for a node. + +**Parameters:** + + +The original label. + + +**Returns:** `str` + +The new label. + + + + + + + +```python +langchain_core.runnables.graph_png.PngDrawer.update_styles( + viz: typing.Any, + graph: langchain_core.runnables.graph.Graph +) -> None +``` + + + + + + +staticmethod + +Update the styles of the entrypoint and END nodes. + +**Parameters:** + + +The graphviz object. + + + +The graph to draw. + + + + + + + + + + +```python +langchain_core.runnables.graph_png._HAS_PYGRAPHVIZ = True +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/history.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/history.mdx new file mode 100644 index 0000000..5e306ee --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/history.mdx @@ -0,0 +1,497 @@ +--- +layout: overview +slug: langchain-core/langchain_core/runnables/history +title: langchain_core.runnables.history +--- + +`Runnable` that manages chat message history for another `Runnable`. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`RunnableWithMessageHistory`](#langchain_core-runnables-history-RunnableWithMessageHistory) | `Runnable` that manages chat message history for another `Runnable`. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_parameter_names`](#langchain_core-runnables-history-_get_parameter_names) | Get the parameter names of the `Callable`. | + +### Data + +[`GetSessionHistoryCallable`](#langchain_core-runnables-history-GetSessionHistoryCallable) + +[`MessagesOrDictWithMessages`](#langchain_core-runnables-history-MessagesOrDictWithMessages) + +### API + + + + + +```python +class langchain_core.runnables.history.RunnableWithMessageHistory( + runnable: langchain_core.runnables.base.Runnable[list[langchain_core.messages.BaseMessage], str | langchain_core.messages.BaseMessage | langchain_core.runnables.history.MessagesOrDictWithMessages] | langchain_core.runnables.base.Runnable[dict[str, typing.Any], str | langchain_core.messages.BaseMessage | langchain_core.runnables.history.MessagesOrDictWithMessages] | langchain_core.language_models.base.LanguageModelLike, + get_session_history: langchain_core.runnables.history.GetSessionHistoryCallable, + input_messages_key: str | None = None, + output_messages_key: str | None = None, + history_messages_key: str | None = None, + history_factory_config: collections.abc.Sequence[langchain_core.runnables.utils.ConfigurableFieldSpec] | None = None, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [RunnableBindingBase](/langchain-core/langchain_core/runnables/base#langchain_core-runnables-base-RunnableBindingBase) + +`Runnable` that manages chat message history for another `Runnable`. + +A chat message history is a sequence of messages that represent a conversation. + +`RunnableWithMessageHistory` wraps another `Runnable` and manages the chat message +history for it; it is responsible for reading and updating the chat message +history. + +The formats supported for the inputs and outputs of the wrapped `Runnable` +are described below. + +`RunnableWithMessageHistory` must always be called with a config that contains +the appropriate parameters for the chat message history factory. + +By default, the `Runnable` is expected to take a single configuration parameter +called `session_id` which is a string. This parameter is used to create a new +or look up an existing chat message history that matches the given `session_id`. + +In this case, the invocation would look like this: + +`with_history.invoke(..., config={"configurable": {"session_id": "bar"}})` +; e.g., `{"configurable": {"session_id": "<SESSION_ID>"}}`. + +The configuration can be customized by passing in a list of +`ConfigurableFieldSpec` objects to the `history_factory_config` parameter (see +example below). + +In the examples, we will use a chat message history with an in-memory +implementation to make it easy to experiment and see the results. + +For production use cases, you will want to use a persistent implementation +of chat message history, such as `RedisChatMessageHistory`. + +Example: Chat message history with an in-memory implementation for testing. + + ```python + from operator import itemgetter + + from langchain_openai.chat_models import ChatOpenAI + + from langchain_core.chat_history import BaseChatMessageHistory + from langchain_core.documents import Document + from langchain_core.messages import BaseMessage, AIMessage + from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder + from pydantic import BaseModel, Field + from langchain_core.runnables import ( + RunnableLambda, + ConfigurableFieldSpec, + RunnablePassthrough, + ) + from langchain_core.runnables.history import RunnableWithMessageHistory + + + class InMemoryHistory(BaseChatMessageHistory, BaseModel): + """In memory implementation of chat message history.""" + + messages: list[BaseMessage] = Field(default_factory=list) + + def add_messages(self, messages: list[BaseMessage]) -> None: + """Add a list of messages to the store""" + self.messages.extend(messages) + + def clear(self) -> None: + self.messages = [] + + # Here we use a global variable to store the chat message history. + # This will make it easier to inspect it to see the underlying results. + store = {} + + def get_by_session_id(session_id: str) -> BaseChatMessageHistory: + if session_id not in store: + store[session_id] = InMemoryHistory() + return store[session_id] + + + history = get_by_session_id("1") + history.add_message(AIMessage(content="hello")) + print(store) # noqa: T201 + + ``` + +Example where the wrapped `Runnable` takes a dictionary input: + + ```python + from typing import Optional + + from langchain_anthropic import ChatAnthropic + from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder + from langchain_core.runnables.history import RunnableWithMessageHistory + + + prompt = ChatPromptTemplate.from_messages( + [ + ("system", "You're an assistant who's good at {ability}"), + MessagesPlaceholder(variable_name="history"), + ("human", "{question}"), + ] + ) + + chain = prompt | ChatAnthropic(model="claude-2") + + chain_with_history = RunnableWithMessageHistory( + chain, + # Uses the get_by_session_id function defined in the example + # above. + get_by_session_id, + input_messages_key="question", + history_messages_key="history", + ) + + print( + chain_with_history.invoke( # noqa: T201 + {"ability": "math", "question": "What does cosine mean?"}, + config={"configurable": {"session_id": "foo"}}, + ) + ) + + # Uses the store defined in the example above. + print(store) # noqa: T201 + + print( + chain_with_history.invoke( # noqa: T201 + {"ability": "math", "question": "What's its inverse"}, + config={"configurable": {"session_id": "foo"}}, + ) + ) + + print(store) # noqa: T201 + ``` + +Example where the session factory takes two keys (`user_id` and `conversation_id`): + + ```python + store = {} + + + def get_session_history( + user_id: str, conversation_id: str + ) -> BaseChatMessageHistory: + if (user_id, conversation_id) not in store: + store[(user_id, conversation_id)] = InMemoryHistory() + return store[(user_id, conversation_id)] + + + prompt = ChatPromptTemplate.from_messages( + [ + ("system", "You're an assistant who's good at {ability}"), + MessagesPlaceholder(variable_name="history"), + ("human", "{question}"), + ] + ) + + chain = prompt | ChatAnthropic(model="claude-2") + + with_message_history = RunnableWithMessageHistory( + chain, + get_session_history=get_session_history, + input_messages_key="question", + history_messages_key="history", + history_factory_config=[ + ConfigurableFieldSpec( + id="user_id", + annotation=str, + name="User ID", + description="Unique identifier for the user.", + default="", + is_shared=True, + ), + ConfigurableFieldSpec( + id="conversation_id", + annotation=str, + name="Conversation ID", + description="Unique identifier for the conversation.", + default="", + is_shared=True, + ), + ], + ) + + with_message_history.invoke( + {"ability": "math", "question": "What does cosine mean?"}, + config={"configurable": {"user_id": "123", "conversation_id": "1"}}, + ) + ``` + + + + + + +Get the configuration specs for the `RunnableWithMessageHistory`. + + + +Function that returns a new `BaseChatMessageHistory`. + +This function should either take a single positional argument `session_id` of type +string and return a corresponding chat message history instance + + + +Configure fields that should be passed to the chat history factory. + +See `ConfigurableFieldSpec` for more details. + + + +Must be specified if the base `Runnable` accepts a `dict` as input and expects a +separate key for historical messages. + + + +Must be specified if the base `Runnable` accepts a `dict` as input. +The key in the input `dict` that contains the messages. + + + +Must be specified if the base `Runnable` returns a `dict` as output. +The key in the output `dict` that contains the messages. + + + + + +```python +langchain_core.runnables.history.RunnableWithMessageHistory._aenter_history( + value: dict[str, typing.Any], + config: langchain_core.runnables.config.RunnableConfig +) -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.history.RunnableWithMessageHistory._aexit_history( + run: langchain_core.tracers.schemas.Run, + config: langchain_core.runnables.config.RunnableConfig +) -> None +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.history.RunnableWithMessageHistory._enter_history( + value: typing.Any, + config: langchain_core.runnables.config.RunnableConfig +) -> list[langchain_core.messages.BaseMessage] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.history.RunnableWithMessageHistory._exit_history( + run: langchain_core.tracers.schemas.Run, + config: langchain_core.runnables.config.RunnableConfig +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.runnables.history.RunnableWithMessageHistory._get_input_messages( + input_val: str | langchain_core.messages.BaseMessage | collections.abc.Sequence[langchain_core.messages.BaseMessage] | dict +) -> list[langchain_core.messages.BaseMessage] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.history.RunnableWithMessageHistory._get_output_messages( + output_val: str | langchain_core.messages.BaseMessage | collections.abc.Sequence[langchain_core.messages.BaseMessage] | dict +) -> list[langchain_core.messages.BaseMessage] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.history.RunnableWithMessageHistory._merge_configs( + configs: langchain_core.runnables.config.RunnableConfig | None = () +) -> langchain_core.runnables.config.RunnableConfig +``` + + + + + + + + + + + + +```python +langchain_core.runnables.history.RunnableWithMessageHistory.get_input_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.history.RunnableWithMessageHistory.get_output_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + +Get a Pydantic model that can be used to validate output to the `Runnable`. + +`Runnable` objects that leverage the `configurable_fields` and +`configurable_alternatives` methods will have a dynamic output schema that +depends on which configuration the `Runnable` is invoked with. + +This method allows to get an output schema for a specific configuration. + +**Parameters:** + + +A config to use when generating the schema. + + +**Returns:** `type[BaseModel]` + +A Pydantic model that can be used to validate output. + + + + + + + + + +```python +langchain_core.runnables.history._get_parameter_names( + callable_: langchain_core.runnables.history.GetSessionHistoryCallable +) -> list[str] +``` + + + + + + +Get the parameter names of the `Callable`. + + + + + + + + +```python +langchain_core.runnables.history.GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory] +``` + + + + + + + + + +```python +langchain_core.runnables.history.MessagesOrDictWithMessages = Sequence['BaseMessage'] | dict[str, Any] +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/passthrough.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/passthrough.mdx new file mode 100644 index 0000000..ceb6d01 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/passthrough.mdx @@ -0,0 +1,1080 @@ +--- +layout: overview +slug: langchain-core/langchain_core/runnables/passthrough +title: langchain_core.runnables.passthrough +--- + +Implementation of the `RunnablePassthrough`. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`RunnableAssign`](#langchain_core-runnables-passthrough-RunnableAssign) | Runnable that assigns key-value pairs to `dict[str, Any]` inputs. | +| [`RunnablePassthrough`](#langchain_core-runnables-passthrough-RunnablePassthrough) | Runnable to passthrough inputs unchanged or with additional keys. | +| [`RunnablePick`](#langchain_core-runnables-passthrough-RunnablePick) | `Runnable` that picks keys from `dict[str, Any]` inputs. | + +### Functions + +| Name | Description | +|------|-------------| +| [`aidentity`](#langchain_core-runnables-passthrough-aidentity) | Async identity function. | +| [`identity`](#langchain_core-runnables-passthrough-identity) | Identity function. | + +### Data + +[`_graph_passthrough`](#langchain_core-runnables-passthrough-_graph_passthrough) + +### API + + + + + +```python +class langchain_core.runnables.passthrough.RunnableAssign( + mapper: langchain_core.runnables.base.RunnableParallel[dict[str, typing.Any]], + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [RunnableSerializable[dict[str, Any], dict[str, Any]]](/langchain-core/langchain_core/runnables/base#langchain_core-runnables-base-RunnableSerializable) + +Runnable that assigns key-value pairs to `dict[str, Any]` inputs. + +The `RunnableAssign` class takes input dictionaries and, through a +`RunnableParallel` instance, applies transformations, then combines +these with the original data, introducing new key-value pairs based +on the mapper's logic. + +**Examples:** + + + +```python +# This is a RunnableAssign +from langchain_core.runnables.passthrough import ( + RunnableAssign, + RunnableParallel, +) +from langchain_core.runnables.base import RunnableLambda + + +def add_ten(x: dict[str, int]) -> dict[str, int]: + return {"added": x["input"] + 10} + + +mapper = RunnableParallel( + { + "add_step": RunnableLambda(add_ten), + } +) + +runnable_assign = RunnableAssign(mapper) + +# Synchronous example +runnable_assign.invoke({"input": 5}) +# returns {'input': 5, 'add_step': {'added': 15}} + +# Asynchronous example +await runnable_assign.ainvoke({"input": 5}) +# returns {'input': 5, 'add_step': {'added': 15}} +``` + + + + + + + + + + + + + +```python +langchain_core.runnables.passthrough.RunnableAssign._ainvoke( + value: dict[str, typing.Any], + run_manager: langchain_core.callbacks.manager.AsyncCallbackManagerForChainRun, + config: langchain_core.runnables.config.RunnableConfig, + kwargs: typing.Any = {} +) -> dict[str, typing.Any] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.passthrough.RunnableAssign._atransform( + values: collections.abc.AsyncIterator[dict[str, typing.Any]], + run_manager: langchain_core.callbacks.manager.AsyncCallbackManagerForChainRun, + config: langchain_core.runnables.config.RunnableConfig, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[dict[str, typing.Any]] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.passthrough.RunnableAssign._invoke( + value: dict[str, typing.Any], + run_manager: langchain_core.callbacks.manager.CallbackManagerForChainRun, + config: langchain_core.runnables.config.RunnableConfig, + kwargs: typing.Any = {} +) -> dict[str, typing.Any] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.passthrough.RunnableAssign._transform( + values: collections.abc.Iterator[dict[str, typing.Any]], + run_manager: langchain_core.callbacks.manager.CallbackManagerForChainRun, + config: langchain_core.runnables.config.RunnableConfig, + kwargs: typing.Any = {} +) -> collections.abc.Iterator[dict[str, typing.Any]] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.passthrough.RunnableAssign.ainvoke( + input: dict[str, typing.Any], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> dict[str, typing.Any] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.passthrough.RunnableAssign.astream( + input: dict[str, typing.Any], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[dict[str, typing.Any]] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.passthrough.RunnableAssign.atransform( + input: collections.abc.AsyncIterator[dict[str, typing.Any]], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[dict[str, typing.Any]] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.passthrough.RunnableAssign.get_graph( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> langchain_core.runnables.graph.Graph +``` + + + + + + + + + + + + +```python +langchain_core.runnables.passthrough.RunnableAssign.get_input_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.passthrough.RunnableAssign.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "schema", "runnable"]` + + + + + + + +```python +langchain_core.runnables.passthrough.RunnableAssign.get_name( + suffix: str | None = None, + name: str | None = None +) -> str +``` + + + + + + + + + + + + +```python +langchain_core.runnables.passthrough.RunnableAssign.get_output_schema( + config: langchain_core.runnables.config.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.passthrough.RunnableAssign.invoke( + input: dict[str, typing.Any], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> dict[str, typing.Any] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.passthrough.RunnableAssign.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `True` as this class is serializable. + + + + + + + +```python +langchain_core.runnables.passthrough.RunnableAssign.stream( + input: dict[str, typing.Any], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.Iterator[dict[str, typing.Any]] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.passthrough.RunnableAssign.transform( + input: collections.abc.Iterator[dict[str, typing.Any]], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.Iterator[dict[str, typing.Any]] +``` + + + + + + + + + + + + + + +```python +class langchain_core.runnables.passthrough.RunnablePassthrough( + func: collections.abc.Callable[[Other], None] | collections.abc.Callable[[Other, RunnableConfig], None] | collections.abc.Callable[[Other], collections.abc.Awaitable[None]] | collections.abc.Callable[[Other, RunnableConfig], collections.abc.Awaitable[None]] | None = None, + afunc: collections.abc.Callable[[Other], collections.abc.Awaitable[None]] | collections.abc.Callable[[Other, RunnableConfig], collections.abc.Awaitable[None]] | None = None, + input_type: type[langchain_core.runnables.base.Other] | None = None, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [RunnableSerializable[Other, Other]](/langchain-core/langchain_core/runnables/base#langchain_core-runnables-base-RunnableSerializable) + +Runnable to passthrough inputs unchanged or with additional keys. + +This `Runnable` behaves almost like the identity function, except that it +can be configured to add additional keys to the output, if the input is a +dict. + +The examples below demonstrate this `Runnable` works using a few simple +chains. The chains rely on simple lambdas to make the examples easy to execute +and experiment with. + +In some cases, it may be useful to pass the input through while adding some +keys to the output. In this case, you can use the `assign` method: + + ```python + from langchain_core.runnables import RunnablePassthrough + + + def fake_llm(prompt: str) -> str: # Fake LLM for the example + return "completion" + + + runnable = { + "llm1": fake_llm, + "llm2": fake_llm, + } | RunnablePassthrough.assign( + total_chars=lambda inputs: len(inputs["llm1"] + inputs["llm2"]) + ) + + runnable.invoke("hello") + # {'llm1': 'completion', 'llm2': 'completion', 'total_chars': 20} + ``` + +**Examples:** + + + +```python +from langchain_core.runnables import ( + RunnableLambda, + RunnableParallel, + RunnablePassthrough, +) + +runnable = RunnableParallel( + origin=RunnablePassthrough(), modified=lambda x: x + 1 +) + +runnable.invoke(1) # {'origin': 1, 'modified': 2} + + +def fake_llm(prompt: str) -> str: # Fake LLM for the example + return "completion" + + +chain = RunnableLambda(fake_llm) | { + "original": RunnablePassthrough(), # Original LLM output + "parsed": lambda text: text[::-1], # Parsing logic +} + +chain.invoke("hello") # {'original': 'completion', 'parsed': 'noitelpmoc'} +``` + + + + + + + + + + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePassthrough.__repr_args__() -> typing.Any +``` + + + + + + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePassthrough.ainvoke( + input: langchain_core.runnables.base.Other, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> langchain_core.runnables.base.Other +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePassthrough.assign( + kwargs: langchain_core.runnables.base.Runnable[dict[str, typing.Any], typing.Any] | collections.abc.Callable[[dict[str, Any]], typing.Any] | collections.abc.Mapping[str, langchain_core.runnables.base.Runnable[dict[str, typing.Any], typing.Any] | collections.abc.Callable[[dict[str, Any]], typing.Any]] = {} +) -> langchain_core.runnables.passthrough.RunnableAssign +``` + + + + + + +classmethod + +Merge the Dict input with the output produced by the mapping argument. + +**Parameters:** + + +`Runnable`, `Callable` or a `Mapping` from keys to `Runnable` +objects or `Callable`s. + + +**Returns:** `RunnableAssign` + +A `Runnable` that merges the `dict` input with the output produced by the + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePassthrough.astream( + input: langchain_core.runnables.base.Other, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.base.Other] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePassthrough.atransform( + input: collections.abc.AsyncIterator[langchain_core.runnables.base.Other], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.base.Other] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePassthrough.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "schema", "runnable"]` + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePassthrough.invoke( + input: langchain_core.runnables.base.Other, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.runnables.base.Other +``` + + + + + + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePassthrough.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `True` as this class is serializable. + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePassthrough.stream( + input: langchain_core.runnables.base.Other, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.Iterator[langchain_core.runnables.base.Other] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePassthrough.transform( + input: collections.abc.Iterator[langchain_core.runnables.base.Other], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.Iterator[langchain_core.runnables.base.Other] +``` + + + + + + + + + + + + + + +```python +class langchain_core.runnables.passthrough.RunnablePick( + keys: str | list[str], + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [RunnableSerializable[dict[str, Any], Any]](/langchain-core/langchain_core/runnables/base#langchain_core-runnables-base-RunnableSerializable) + +`Runnable` that picks keys from `dict[str, Any]` inputs. + +`RunnablePick` class represents a `Runnable` that selectively picks keys from a +dictionary input. It allows you to specify one or more keys to extract +from the input dictionary. + +!!! note "Return Type Behavior" + The return type depends on the `keys` parameter: + + - When `keys` is a `str`: Returns the single value associated with that key + - When `keys` is a `list`: Returns a dictionary containing only the selected + keys + + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePick._ainvoke( + value: dict[str, typing.Any] +) -> typing.Any +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePick._atransform( + chunks: collections.abc.AsyncIterator[dict[str, typing.Any]] +) -> collections.abc.AsyncIterator[typing.Any] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePick._pick( + value: dict[str, typing.Any] +) -> typing.Any +``` + + + + + + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePick._transform( + chunks: collections.abc.Iterator[dict[str, typing.Any]] +) -> collections.abc.Iterator[typing.Any] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePick.ainvoke( + input: dict[str, typing.Any], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePick.astream( + input: dict[str, typing.Any], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[typing.Any] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePick.atransform( + input: collections.abc.AsyncIterator[dict[str, typing.Any]], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[typing.Any] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePick.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "schema", "runnable"]` + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePick.get_name( + suffix: str | None = None, + name: str | None = None +) -> str +``` + + + + + + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePick.invoke( + input: dict[str, typing.Any], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePick.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `True` as this class is serializable. + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePick.stream( + input: dict[str, typing.Any], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.Iterator[typing.Any] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.passthrough.RunnablePick.transform( + input: collections.abc.Iterator[dict[str, typing.Any]], + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> collections.abc.Iterator[typing.Any] +``` + + + + + + + + + + + + + + +```python +langchain_core.runnables.passthrough.aidentity( + x: langchain_core.runnables.base.Other +) -> langchain_core.runnables.base.Other +``` + + + + + + +async + +Async identity function. + +**Parameters:** + + +Input. + + +**Returns:** `Other` + +Output. + + + + + + + + +```python +langchain_core.runnables.passthrough.identity( + x: langchain_core.runnables.base.Other +) -> langchain_core.runnables.base.Other +``` + + + + + + +Identity function. + +**Parameters:** + + +Input. + + +**Returns:** `Other` + +Output. + + + + + + + + +```python +langchain_core.runnables.passthrough._graph_passthrough: RunnablePassthrough = RunnablePassthrough() +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/retry.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/retry.mdx new file mode 100644 index 0000000..fbe3874 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/retry.mdx @@ -0,0 +1,414 @@ +--- +layout: overview +slug: langchain-core/langchain_core/runnables/retry +title: langchain_core.runnables.retry +--- + +`Runnable` that retries a `Runnable` if it fails. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ExponentialJitterParams`](#langchain_core-runnables-retry-ExponentialJitterParams) | Parameters for `tenacity.wait_exponential_jitter`. | +| [`RunnableRetry`](#langchain_core-runnables-retry-RunnableRetry) | Retry a Runnable if it fails. | + +### Data + +[`T`](#langchain_core-runnables-retry-T) + +[`U`](#langchain_core-runnables-retry-U) + +### API + + + + + +```python +class langchain_core.runnables.retry.ExponentialJitterParams +``` + + + + + + +**Bases:** `typing.TypedDict` + +Parameters for `tenacity.wait_exponential_jitter`. + + +Base for exponential backoff. + + + +Initial wait. + + + +Random additional wait sampled from random.uniform(0, jitter). + + + +Maximum wait. + + + + + + + + +```python +class langchain_core.runnables.retry.RunnableRetry() +``` + + + + + + +**Bases:** [RunnableBindingBase[Input, Output]](/langchain-core/langchain_core/runnables/base#langchain_core-runnables-base-RunnableBindingBase) + +Retry a Runnable if it fails. + +RunnableRetry can be used to add retry logic to any object +that subclasses the base Runnable. + +Such retries are especially useful for network calls that may fail +due to transient errors. + +The RunnableRetry is implemented as a RunnableBinding. The easiest +way to use it is through the `.with_retry()` method on all Runnables. + +Example: +Here's an example that uses a RunnableLambda to raise an exception + + ```python + import time + + + def foo(input) -> None: + '''Fake function that raises an exception.''' + raise ValueError(f"Invoking foo failed. At time {time.time()}") + + + runnable = RunnableLambda(foo) + + runnable_with_retries = runnable.with_retry( + retry_if_exception_type=(ValueError,), # Retry only on ValueError + wait_exponential_jitter=True, # Add jitter to the exponential backoff + stop_after_attempt=2, # Try twice + exponential_jitter_params={"initial": 2}, # if desired, customize backoff + ) + + # The method invocation above is equivalent to the longer form below: + + runnable_with_retries = RunnableRetry( + bound=runnable, + retry_exception_types=(ValueError,), + max_attempt_number=2, + wait_exponential_jitter=True, + exponential_jitter_params={"initial": 2}, + ) + ``` + +This logic can be used to retry any Runnable, including a chain of Runnables, +but in general it's best practice to keep the scope of the retry as small as +possible. For example, if you have a chain of Runnables, you should only retry +the Runnable that is likely to fail, not the entire chain. + + + + + + +Parameters for `tenacity.wait_exponential_jitter`. Namely: `initial`, +`max`, `exp_base`, and `jitter` (all `float` values). + + + +The maximum number of attempts to retry the Runnable. + + + +The exception types to retry on. By default all exceptions are retried. + +In general you should only retry on exceptions that are likely to be +transient, such as network errors. + +Good exceptions to retry are all server errors (5xx) and selected client +errors (4xx) such as 429 Too Many Requests. + + + +Whether to add jitter to the exponential backoff. + + + + + +```python +langchain_core.runnables.retry.RunnableRetry._abatch( + inputs: list[langchain_core.runnables.utils.Input], + run_manager: list[langchain_core.callbacks.manager.AsyncCallbackManagerForChainRun], + config: list[langchain_core.runnables.config.RunnableConfig], + kwargs: typing.Any = {} +) -> list[langchain_core.runnables.utils.Output | Exception] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.retry.RunnableRetry._ainvoke( + input_: langchain_core.runnables.utils.Input, + run_manager: langchain_core.callbacks.manager.AsyncCallbackManagerForChainRun, + config: langchain_core.runnables.config.RunnableConfig, + kwargs: typing.Any = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.retry.RunnableRetry._async_retrying( + kwargs: typing.Any = {} +) -> tenacity.AsyncRetrying +``` + + + + + + + + + + + + +```python +langchain_core.runnables.retry.RunnableRetry._batch( + inputs: list[langchain_core.runnables.utils.Input], + run_manager: list[langchain_core.callbacks.manager.CallbackManagerForChainRun], + config: list[langchain_core.runnables.config.RunnableConfig], + kwargs: typing.Any = {} +) -> list[langchain_core.runnables.utils.Output | Exception] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.retry.RunnableRetry._invoke( + input_: langchain_core.runnables.utils.Input, + run_manager: langchain_core.callbacks.manager.CallbackManagerForChainRun, + config: langchain_core.runnables.config.RunnableConfig, + kwargs: typing.Any = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + + + + + + + +```python +langchain_core.runnables.retry.RunnableRetry._patch_config( + config: langchain_core.runnables.config.RunnableConfig, + run_manager: langchain_core.runnables.retry.T, + retry_state: tenacity.RetryCallState +) -> langchain_core.runnables.config.RunnableConfig +``` + + + + + + +staticmethod + + + + + + + +```python +langchain_core.runnables.retry.RunnableRetry._patch_config_list( + config: list[langchain_core.runnables.config.RunnableConfig], + run_manager: list[langchain_core.runnables.retry.T], + retry_state: tenacity.RetryCallState +) -> list[langchain_core.runnables.config.RunnableConfig] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.retry.RunnableRetry._sync_retrying( + kwargs: typing.Any = {} +) -> tenacity.Retrying +``` + + + + + + + + + + + + +```python +langchain_core.runnables.retry.RunnableRetry.abatch( + inputs: list[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | list[langchain_core.runnables.config.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any = {} +) -> list[langchain_core.runnables.utils.Output] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.retry.RunnableRetry.ainvoke( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.retry.RunnableRetry.batch( + inputs: list[langchain_core.runnables.utils.Input], + config: langchain_core.runnables.config.RunnableConfig | list[langchain_core.runnables.config.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any = {} +) -> list[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.retry.RunnableRetry.invoke( + input: langchain_core.runnables.utils.Input, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + + + + + + + + + +```python +langchain_core.runnables.retry.T = TypeVar('T', CallbackManagerForChainRun, AsyncCallbackManagerForChainRun) +``` + + + + + + + + + +```python +langchain_core.runnables.retry.U = TypeVar('U') +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/router.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/router.mdx new file mode 100644 index 0000000..5e6a68c --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/router.mdx @@ -0,0 +1,241 @@ +--- +layout: overview +slug: langchain-core/langchain_core/runnables/router +title: langchain_core.runnables.router +--- + +`Runnable` that routes to a set of `Runnable` objects. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`RouterInput`](#langchain_core-runnables-router-RouterInput) | Router input. | +| [`RouterRunnable`](#langchain_core-runnables-router-RouterRunnable) | `Runnable` that routes to a set of `Runnable` based on `Input['key']`. | + +### API + + + + + +```python +class langchain_core.runnables.router.RouterInput +``` + + + + + + +**Bases:** `typing.TypedDict` + +Router input. + + +The input to pass to the selected `Runnable`. + + + +The key to route on. + + + + + + + + +```python +class langchain_core.runnables.router.RouterRunnable( + runnables: collections.abc.Mapping[str, langchain_core.runnables.base.Runnable[typing.Any, langchain_core.runnables.utils.Output] | collections.abc.Callable[[Any], langchain_core.runnables.utils.Output]] +) +``` + + + + + + +**Bases:** [RunnableSerializable[RouterInput, Output]](/langchain-core/langchain_core/runnables/base#langchain_core-runnables-base-RunnableSerializable) + +`Runnable` that routes to a set of `Runnable` based on `Input['key']`. + +Returns the output of the selected Runnable. + + + + + + + + + + + + + + +```python +langchain_core.runnables.router.RouterRunnable.abatch( + inputs: list[langchain_core.runnables.router.RouterInput], + config: langchain_core.runnables.config.RunnableConfig | list[langchain_core.runnables.config.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any | None = {} +) -> list[langchain_core.runnables.utils.Output] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.router.RouterRunnable.ainvoke( + input: langchain_core.runnables.router.RouterInput, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.router.RouterRunnable.astream( + input: langchain_core.runnables.router.RouterInput, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.utils.Output] +``` + + + + + + +async + + + + + + + +```python +langchain_core.runnables.router.RouterRunnable.batch( + inputs: list[langchain_core.runnables.router.RouterInput], + config: langchain_core.runnables.config.RunnableConfig | list[langchain_core.runnables.config.RunnableConfig] | None = None, + return_exceptions: bool = False, + kwargs: typing.Any | None = {} +) -> list[langchain_core.runnables.utils.Output] +``` + + + + + + + + + + + + +```python +langchain_core.runnables.router.RouterRunnable.get_lc_namespace() -> list[str] +``` + + + + + + +classmethod + +Get the namespace of the LangChain object. + +**Returns:** `list[str]` + +`["langchain", "schema", "runnable"]` + + + + + + + +```python +langchain_core.runnables.router.RouterRunnable.invoke( + input: langchain_core.runnables.router.RouterInput, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> langchain_core.runnables.utils.Output +``` + + + + + + + + + + + + +```python +langchain_core.runnables.router.RouterRunnable.is_lc_serializable() -> bool +``` + + + + + + +classmethod + +Return `True` as this class is serializable. + + + + + + + +```python +langchain_core.runnables.router.RouterRunnable.stream( + input: langchain_core.runnables.router.RouterInput, + config: langchain_core.runnables.config.RunnableConfig | None = None, + kwargs: typing.Any | None = {} +) -> collections.abc.Iterator[langchain_core.runnables.utils.Output] +``` + + + + + + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/schema.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/schema.mdx new file mode 100644 index 0000000..f84aa7c --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/schema.mdx @@ -0,0 +1,243 @@ +--- +layout: overview +slug: langchain-core/langchain_core/runnables/schema +title: langchain_core.runnables.schema +--- + +Module contains typedefs that are used with `Runnable` objects. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseStreamEvent`](#langchain_core-runnables-schema-BaseStreamEvent) | Streaming event. | +| [`CustomStreamEvent`](#langchain_core-runnables-schema-CustomStreamEvent) | Custom stream event created by the user. | +| [`EventData`](#langchain_core-runnables-schema-EventData) | Data associated with a streaming event. | +| [`StandardStreamEvent`](#langchain_core-runnables-schema-StandardStreamEvent) | A standard stream event that follows LangChain convention for event data. | + +### Data + +[`StreamEvent`](#langchain_core-runnables-schema-StreamEvent) + +### API + + + + + +```python +class langchain_core.runnables.schema.BaseStreamEvent +``` + + + + + + +**Bases:** `typing.TypedDict` + +Streaming event. + +Schema of a streaming event which is produced from the `astream_events` method. + + +Event names are of the format: `on_[runnable_type]_(start|stream|end)`. + +Runnable types are one of: + +- **llm** - used by non chat models +- **chat_model** - used by chat models +- **prompt** -- e.g., `ChatPromptTemplate` +- **tool** -- from tools defined via `@tool` decorator or inheriting + from `Tool`/`BaseTool` +- **chain** - most `Runnable` objects are of this type + +Further, the events are categorized as one of: + +- **start** - when the `Runnable` starts +- **stream** - when the `Runnable` is streaming +- **end* - when the `Runnable` ends + +start, stream and end are associated with slightly different `data` payload. + +Please see the documentation for `EventData` for more details. + + + +Metadata associated with the `Runnable` that generated this event. + +Metadata can either be bound to a `Runnable` using + + `.with_config({"metadata": { "foo": "bar" }})` + +or passed at run time using + + `.astream_events(..., {"metadata": {"foo": "bar"}})`. + + + +A list of the parent IDs associated with this event. + +Root Events will have an empty list. + +For example, if a `Runnable` A calls `Runnable` B, then the event generated by +`Runnable` B will have `Runnable` A's ID in the `parent_ids` field. + +The order of the parent IDs is from the root parent to the immediate parent. + +Only supported as of v2 of the astream events API. v1 will return an empty list. + + + +An randomly generated ID to keep track of the execution of the given `Runnable`. + +Each child `Runnable` that gets invoked as part of the execution of a parent +`Runnable` is assigned its own unique ID. + + + +Tags associated with the `Runnable` that generated this event. + +Tags are always inherited from parent `Runnable` objects. + +Tags can either be bound to a `Runnable` using `.with_config({"tags": ["hello"]})` +or passed at run time using `.astream_events(..., {"tags": ["hello"]})`. + + + + + + + + +```python +class langchain_core.runnables.schema.CustomStreamEvent() +``` + + + + + + +**Bases:** [BaseStreamEvent](#langchain_core-runnables-schema-BaseStreamEvent) + +Custom stream event created by the user. + + + +The data associated with the event. Free form and can be anything. + + + +The event type. + + + +User defined name for the event. + + + + + + + +```python +class langchain_core.runnables.schema.EventData +``` + + + + + + +**Bases:** `typing.TypedDict` + +Data associated with a streaming event. + + +A streaming chunk from the output that generated the event. + +chunks support addition in general, and adding them up should result +in the output of the `Runnable` that generated the event. + + + +The error that occurred during the execution of the `Runnable`. + +This field is only available if the `Runnable` raised an exception. + +!!! version-added "Added in `langchain-core` 1.0.0" + + + +The input passed to the `Runnable` that generated the event. + +Inputs will sometimes be available at the *START* of the `Runnable`, and +sometimes at the *END* of the `Runnable`. + +If a `Runnable` is able to stream its inputs, then its input by definition +won't be known until the *END* of the `Runnable` when it has finished streaming +its inputs. + + + +The output of the `Runnable` that generated the event. + +Outputs will only be available at the *END* of the `Runnable`. + +For most `Runnable` objects, this field can be inferred from the `chunk` field, +though there might be some exceptions for special a cased `Runnable` (e.g., like +chat models), which may return more information. + + + +The tool call ID associated with the tool execution. + +This field is available for the `on_tool_error` event and can be used to +link errors to specific tool calls in stateless agent implementations. + + + + + + + + +```python +class langchain_core.runnables.schema.StandardStreamEvent() +``` + + + + + + +**Bases:** [BaseStreamEvent](#langchain_core-runnables-schema-BaseStreamEvent) + +A standard stream event that follows LangChain convention for event data. + + + +Event data. + +The contents of the event data depend on the event type. + + + +The name of the `Runnable` that generated the event. + + + + + + + +```python +langchain_core.runnables.schema.StreamEvent = StandardStreamEvent | CustomStreamEvent +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/utils.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/utils.mdx new file mode 100644 index 0000000..0a56a40 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/runnables/utils.mdx @@ -0,0 +1,1426 @@ +--- +layout: overview +slug: langchain-core/langchain_core/runnables/utils +title: langchain_core.runnables.utils +--- + +Utility code for `Runnable` objects. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AddableDict`](#langchain_core-runnables-utils-AddableDict) | Dictionary that can be added to another dictionary. | +| [`ConfigurableField`](#langchain_core-runnables-utils-ConfigurableField) | Field that can be configured by the user. | +| [`ConfigurableFieldMultiOption`](#langchain_core-runnables-utils-ConfigurableFieldMultiOption) | Field that can be configured by the user with multiple default values. | +| [`ConfigurableFieldSingleOption`](#langchain_core-runnables-utils-ConfigurableFieldSingleOption) | Field that can be configured by the user with a default value. | +| [`ConfigurableFieldSpec`](#langchain_core-runnables-utils-ConfigurableFieldSpec) | Field that can be configured by the user. It is a specification of a field. | +| [`FunctionNonLocals`](#langchain_core-runnables-utils-FunctionNonLocals) | Get the nonlocal variables accessed of a function. | +| [`GetLambdaSource`](#langchain_core-runnables-utils-GetLambdaSource) | Get the source code of a lambda function. | +| [`IsFunctionArgDict`](#langchain_core-runnables-utils-IsFunctionArgDict) | Check if the first argument of a function is a dict. | +| [`IsLocalDict`](#langchain_core-runnables-utils-IsLocalDict) | Check if a name is a local dict. | +| [`NonLocals`](#langchain_core-runnables-utils-NonLocals) | Get nonlocal variables accessed. | +| [`SupportsAdd`](#langchain_core-runnables-utils-SupportsAdd) | Protocol for objects that support addition. | +| [`_RootEventFilter`](#langchain_core-runnables-utils-_RootEventFilter) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`aadd`](#langchain_core-runnables-utils-aadd) | Asynchronously add a sequence of addable objects together. | +| [`accepts_config`](#langchain_core-runnables-utils-accepts_config) | Check if a callable accepts a config argument. | +| [`accepts_context`](#langchain_core-runnables-utils-accepts_context) | Check if a callable accepts a context argument. | +| [`accepts_run_manager`](#langchain_core-runnables-utils-accepts_run_manager) | Check if a callable accepts a run_manager argument. | +| [`add`](#langchain_core-runnables-utils-add) | Add a sequence of addable objects together. | +| [`asyncio_accepts_context`](#langchain_core-runnables-utils-asyncio_accepts_context) | Check if asyncio.create_task accepts a `context` arg. | +| [`coro_with_context`](#langchain_core-runnables-utils-coro_with_context) | Await a coroutine with a context. | +| [`gated_coro`](#langchain_core-runnables-utils-gated_coro) | Run a coroutine with a semaphore. | +| [`gather_with_concurrency`](#langchain_core-runnables-utils-gather_with_concurrency) | Gather coroutines with a limit on the number of concurrent coroutines. | +| [`get_function_first_arg_dict_keys`](#langchain_core-runnables-utils-get_function_first_arg_dict_keys) | Get the keys of the first argument of a function if it is a dict. | +| [`get_function_nonlocals`](#langchain_core-runnables-utils-get_function_nonlocals) | Get the nonlocal variables accessed by a function. | +| [`get_lambda_source`](#langchain_core-runnables-utils-get_lambda_source) | Get the source code of a lambda function. | +| [`get_unique_config_specs`](#langchain_core-runnables-utils-get_unique_config_specs) | Get the unique config specs from a sequence of config specs. | +| [`indent_lines_after_first`](#langchain_core-runnables-utils-indent_lines_after_first) | Indent all lines of text after the first line. | +| [`is_async_callable`](#langchain_core-runnables-utils-is_async_callable) | Check if a function is async. | +| [`is_async_generator`](#langchain_core-runnables-utils-is_async_generator) | Check if a function is an async generator. | + +### Data + +[`Addable`](#langchain_core-runnables-utils-Addable) + +[`AnyConfigurableField`](#langchain_core-runnables-utils-AnyConfigurableField) + +[`Input`](#langchain_core-runnables-utils-Input) + +[`Output`](#langchain_core-runnables-utils-Output) + +[`_T`](#langchain_core-runnables-utils-_T) + +[`_T_co`](#langchain_core-runnables-utils-_T_co) + +[`_T_contra`](#langchain_core-runnables-utils-_T_contra) + +### API + + + + + +```python +class langchain_core.runnables.utils.AddableDict() +``` + + + + + + +**Bases:** `dict[str, Any]` + +Dictionary that can be added to another dictionary. + + + + + + +```python +langchain_core.runnables.utils.AddableDict.__add__( + other: langchain_core.runnables.utils.AddableDict +) -> langchain_core.runnables.utils.AddableDict +``` + + + + + + +Add a dictionary to this dictionary. + +**Parameters:** + + +The other dictionary to add. + + +**Returns:** `AddableDict` + +A dictionary that is the result of adding the two dictionaries. + + + + + + + +```python +langchain_core.runnables.utils.AddableDict.__radd__( + other: langchain_core.runnables.utils.AddableDict +) -> langchain_core.runnables.utils.AddableDict +``` + + + + + + +Add this dictionary to another dictionary. + +**Parameters:** + + +The other dictionary to be added to. + + +**Returns:** `AddableDict` + +A dictionary that is the result of adding the two dictionaries. + + + + + + + + + +```python +class langchain_core.runnables.utils.ConfigurableField() +``` + + + + + + +**Bases:** `NamedTuple` + +Field that can be configured by the user. + + + +The annotation of the field. + + + +The description of the field. + + + +The unique identifier of the field. + + + +Whether the field is shared. + + + +The name of the field. + + + + + +```python +langchain_core.runnables.utils.ConfigurableField.__hash__() -> int +``` + + + + + + + + + + + + + + +```python +class langchain_core.runnables.utils.ConfigurableFieldMultiOption() +``` + + + + + + +**Bases:** `NamedTuple` + +Field that can be configured by the user with multiple default values. + + + +The default values for the field. + + + +The description of the field. + + + +The unique identifier of the field. + + + +Whether the field is shared. + + + +The name of the field. + + + +The options for the field. + + + + + +```python +langchain_core.runnables.utils.ConfigurableFieldMultiOption.__hash__() -> int +``` + + + + + + + + + + + + + + +```python +class langchain_core.runnables.utils.ConfigurableFieldSingleOption() +``` + + + + + + +**Bases:** `NamedTuple` + +Field that can be configured by the user with a default value. + + + +The default value for the field. + + + +The description of the field. + + + +The unique identifier of the field. + + + +Whether the field is shared. + + + +The name of the field. + + + +The options for the field. + + + + + +```python +langchain_core.runnables.utils.ConfigurableFieldSingleOption.__hash__() -> int +``` + + + + + + + + + + + + + + +```python +class langchain_core.runnables.utils.ConfigurableFieldSpec() +``` + + + + + + +**Bases:** `NamedTuple` + +Field that can be configured by the user. It is a specification of a field. + + + +The annotation of the field. + + + +The default value for the field. + + + +The dependencies of the field. + + + +The description of the field. + + + +The unique identifier of the field. + + + +Whether the field is shared. + + + +The name of the field. + + + + + + + +```python +class langchain_core.runnables.utils.FunctionNonLocals() +``` + + + + + + +**Bases:** `NodeVisitor` + +Get the nonlocal variables accessed of a function. + + + + + + + + +```python +langchain_core.runnables.utils.FunctionNonLocals.visit_AsyncFunctionDef( + node: ast.AsyncFunctionDef +) -> None +``` + + + + + + +Visit an async function definition. + +**Parameters:** + + +The node to visit. + + + + + + + + +```python +langchain_core.runnables.utils.FunctionNonLocals.visit_FunctionDef( + node: ast.FunctionDef +) -> None +``` + + + + + + +Visit a function definition. + +**Parameters:** + + +The node to visit. + + + + + + + + +```python +langchain_core.runnables.utils.FunctionNonLocals.visit_Lambda( + node: ast.Lambda +) -> None +``` + + + + + + +Visit a lambda function. + +**Parameters:** + + +The node to visit. + + + + + + + + + + +```python +class langchain_core.runnables.utils.GetLambdaSource() +``` + + + + + + +**Bases:** `NodeVisitor` + +Get the source code of a lambda function. + + + + + + + + + + + +```python +langchain_core.runnables.utils.GetLambdaSource.visit_Lambda( + node: ast.Lambda +) -> None +``` + + + + + + +Visit a lambda function. + +**Parameters:** + + +The node to visit. + + + + + + + + + + +```python +class langchain_core.runnables.utils.IsFunctionArgDict() +``` + + + + + + +**Bases:** `NodeVisitor` + +Check if the first argument of a function is a dict. + + + + + + + + +```python +langchain_core.runnables.utils.IsFunctionArgDict.visit_AsyncFunctionDef( + node: ast.AsyncFunctionDef +) -> None +``` + + + + + + +Visit an async function definition. + +**Parameters:** + + +The node to visit. + + + + + + + + +```python +langchain_core.runnables.utils.IsFunctionArgDict.visit_FunctionDef( + node: ast.FunctionDef +) -> None +``` + + + + + + +Visit a function definition. + +**Parameters:** + + +The node to visit. + + + + + + + + +```python +langchain_core.runnables.utils.IsFunctionArgDict.visit_Lambda( + node: ast.Lambda +) -> None +``` + + + + + + +Visit a lambda function. + +**Parameters:** + + +The node to visit. + + + + + + + + + + +```python +class langchain_core.runnables.utils.IsLocalDict( + name: str, + keys: set[str] +) +``` + + + + + + +**Bases:** `NodeVisitor` + +Check if a name is a local dict. + + + + + + +```python +langchain_core.runnables.utils.IsLocalDict.visit_Call( + node: ast.Call +) -> None +``` + + + + + + +Visit a call node. + +**Parameters:** + + +The node to visit. + + + + + + + + +```python +langchain_core.runnables.utils.IsLocalDict.visit_Subscript( + node: ast.Subscript +) -> None +``` + + + + + + +Visit a subscript node. + +**Parameters:** + + +The node to visit. + + + + + + + + + + +```python +class langchain_core.runnables.utils.NonLocals() +``` + + + + + + +**Bases:** `NodeVisitor` + +Get nonlocal variables accessed. + + + + + + + + + + + +```python +langchain_core.runnables.utils.NonLocals.visit_Attribute( + node: ast.Attribute +) -> None +``` + + + + + + +Visit an attribute node. + +**Parameters:** + + +The node to visit. + + + + + + + + +```python +langchain_core.runnables.utils.NonLocals.visit_Name( + node: ast.Name +) -> None +``` + + + + + + +Visit a name node. + +**Parameters:** + + +The node to visit. + + + + + + + + + + +```python +class langchain_core.runnables.utils.SupportsAdd() +``` + + + + + + +Protocol + +**Bases:** `Protocol[_T_contra, _T_co]` + +Protocol for objects that support addition. + + + + + + +```python +langchain_core.runnables.utils.SupportsAdd.__add__( + x: langchain_core.runnables.utils._T_contra +) -> langchain_core.runnables.utils._T_co +``` + + + + + + +Add the object to another object. + + + + + + + + + +```python +class langchain_core.runnables.utils._RootEventFilter( + include_names: collections.abc.Sequence[str] | None = None, + include_types: collections.abc.Sequence[str] | None = None, + include_tags: collections.abc.Sequence[str] | None = None, + exclude_names: collections.abc.Sequence[str] | None = None, + exclude_types: collections.abc.Sequence[str] | None = None, + exclude_tags: collections.abc.Sequence[str] | None = None +) +``` + + + + + + + + + + +```python +langchain_core.runnables.utils._RootEventFilter.include_event( + event: langchain_core.runnables.schema.StreamEvent, + root_type: str +) -> bool +``` + + + + + + +Determine whether to include an event. + + + + + + + + + +```python +langchain_core.runnables.utils.aadd( + addables: collections.abc.AsyncIterable[langchain_core.runnables.utils.Addable] +) -> langchain_core.runnables.utils.Addable | None +``` + + + + + + +async + +Asynchronously add a sequence of addable objects together. + +**Parameters:** + + +The addable objects to add. + + +**Returns:** `Addable | None` + +The result of adding the addable objects. + + + + + + + + +```python +langchain_core.runnables.utils.accepts_config( + callable: collections.abc.Callable[..., typing.Any] +) -> bool +``` + + + + + + +Check if a callable accepts a config argument. + +**Parameters:** + + +The callable to check. + + +**Returns:** `bool` + +`True` if the callable accepts a config argument, `False` otherwise. + + + + + + + + +```python +langchain_core.runnables.utils.accepts_context( + callable: collections.abc.Callable[..., typing.Any] +) -> bool +``` + + + + + + +Check if a callable accepts a context argument. + +**Parameters:** + + +The callable to check. + + +**Returns:** `bool` + +`True` if the callable accepts a context argument, `False` otherwise. + + + + + + + + +```python +langchain_core.runnables.utils.accepts_run_manager( + callable: collections.abc.Callable[..., typing.Any] +) -> bool +``` + + + + + + +Check if a callable accepts a run_manager argument. + +**Parameters:** + + +The callable to check. + + +**Returns:** `bool` + +`True` if the callable accepts a run_manager argument, `False` otherwise. + + + + + + + + +```python +langchain_core.runnables.utils.add( + addables: collections.abc.Iterable[langchain_core.runnables.utils.Addable] +) -> langchain_core.runnables.utils.Addable | None +``` + + + + + + +Add a sequence of addable objects together. + +**Parameters:** + + +The addable objects to add. + + +**Returns:** `Addable | None` + +The result of adding the addable objects. + + + + + + + + +```python +langchain_core.runnables.utils.asyncio_accepts_context() -> bool +``` + + + + + + +Check if asyncio.create_task accepts a `context` arg. + +**Returns:** `bool` + +True if `asyncio.create_task` accepts a context argument, `False` otherwise. + + + + + + + + +```python +langchain_core.runnables.utils.coro_with_context( + coro: collections.abc.Awaitable[langchain_core.runnables.utils._T], + context: contextvars.Context, + create_task: bool = False +) -> collections.abc.Awaitable[langchain_core.runnables.utils._T] +``` + + + + + + +Await a coroutine with a context. + +**Parameters:** + + +The coroutine to await. + + + +The context to use. + + + +Whether to create a task. + + +**Returns:** `Awaitable[_T]` + +The coroutine with the context. + + + + + + + + +```python +langchain_core.runnables.utils.gated_coro( + semaphore: asyncio.Semaphore, + coro: collections.abc.Coroutine +) -> typing.Any +``` + + + + + + +async + +Run a coroutine with a semaphore. + +**Parameters:** + + +The semaphore to use. + + + +The coroutine to run. + + +**Returns:** `Any` + +The result of the coroutine. + + + + + + + + +```python +langchain_core.runnables.utils.gather_with_concurrency( + n: int | None, + coros: collections.abc.Coroutine = () +) -> list +``` + + + + + + +async + +Gather coroutines with a limit on the number of concurrent coroutines. + +**Parameters:** + + +The number of coroutines to run concurrently. + + + +The coroutines to run. + + +**Returns:** `list` + +The results of the coroutines. + + + + + + + + +```python +langchain_core.runnables.utils.get_function_first_arg_dict_keys( + func: collections.abc.Callable +) -> list[str] | None +``` + + + + + + +Get the keys of the first argument of a function if it is a dict. + +**Parameters:** + + +The function to check. + + +**Returns:** `list[str] | None` + +The keys of the first argument if it is a dict, None otherwise. + + + + + + + + +```python +langchain_core.runnables.utils.get_function_nonlocals( + func: collections.abc.Callable +) -> list[typing.Any] +``` + + + + + + +Get the nonlocal variables accessed by a function. + +**Parameters:** + + +The function to check. + + +**Returns:** `list[Any]` + +The nonlocal variables accessed by the function. + + + + + + + + +```python +langchain_core.runnables.utils.get_lambda_source( + func: collections.abc.Callable +) -> str | None +``` + + + + + + +Get the source code of a lambda function. + +**Parameters:** + + +a Callable that can be a lambda function. + + +**Returns:** `str | None` + +the source code of the lambda function. + + + + + + + + +```python +langchain_core.runnables.utils.get_unique_config_specs( + specs: collections.abc.Iterable[langchain_core.runnables.utils.ConfigurableFieldSpec] +) -> list[langchain_core.runnables.utils.ConfigurableFieldSpec] +``` + + + + + + +Get the unique config specs from a sequence of config specs. + +**Parameters:** + + +The config specs. + + +**Returns:** `list[ConfigurableFieldSpec]` + +The unique config specs. + +**Raises:** + +- `ValueError`: If the runnable sequence contains conflicting config specs. + + + + + + + + +```python +langchain_core.runnables.utils.indent_lines_after_first( + text: str, + prefix: str +) -> str +``` + + + + + + +Indent all lines of text after the first line. + +**Parameters:** + + +The text to indent. + + + +Used to determine the number of spaces to indent. + + +**Returns:** `str` + +The indented text. + + + + + + + + +```python +langchain_core.runnables.utils.is_async_callable( + func: typing.Any +) -> typing.TypeGuard[collections.abc.Callable[..., collections.abc.Awaitable]] +``` + + + + + + +Check if a function is async. + +**Parameters:** + + +The function to check. + + +**Returns:** `TypeGuard[Callable[..., Awaitable]]` + +`True` if the function is async, `False` otherwise. + + + + + + + + +```python +langchain_core.runnables.utils.is_async_generator( + func: typing.Any +) -> typing.TypeGuard[collections.abc.Callable[..., collections.abc.AsyncIterator]] +``` + + + + + + +Check if a function is an async generator. + +**Parameters:** + + +The function to check. + + +**Returns:** `TypeGuard[Callable[..., AsyncIterator]]` + +`True` if the function is an async generator, `False` otherwise. + + + + + + + + +```python +langchain_core.runnables.utils.Addable = TypeVar('Addable', bound=(SupportsAdd[Any, Any])) +``` + + + + + + + + + +```python +langchain_core.runnables.utils.AnyConfigurableField = ConfigurableField | ConfigurableFieldSingleOption | ConfigurableFieldMultiOption +``` + + + + + + + + + +```python +langchain_core.runnables.utils.Input = TypeVar('Input', contravariant=True) +``` + + + + + + + + + +```python +langchain_core.runnables.utils.Output = TypeVar('Output', covariant=True) +``` + + + + + + + + + +```python +langchain_core.runnables.utils._T = TypeVar('_T') +``` + + + + + + + + + +```python +langchain_core.runnables.utils._T_co = TypeVar('_T_co', covariant=True) +``` + + + + + + + + + +```python +langchain_core.runnables.utils._T_contra = TypeVar('_T_contra', contravariant=True) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/stores.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/stores.mdx new file mode 100644 index 0000000..362690e --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/stores.mdx @@ -0,0 +1,657 @@ +--- +layout: overview +slug: langchain-core/langchain_core/stores +title: langchain_core.stores +--- + +**Store** implements the key-value stores and storage helpers. + +Module provides implementations of various key-value stores that conform +to a simple key-value interface. + +The primary goal of these storages is to support implementation of caching. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseStore`](#langchain_core-stores-BaseStore) | Abstract interface for a key-value store. | +| [`InMemoryBaseStore`](#langchain_core-stores-InMemoryBaseStore) | In-memory implementation of the `BaseStore` using a dictionary. | +| [`InMemoryByteStore`](#langchain_core-stores-InMemoryByteStore) | In-memory store for bytes. | +| [`InMemoryStore`](#langchain_core-stores-InMemoryStore) | In-memory store for any type of data. | +| [`InvalidKeyException`](#langchain_core-stores-InvalidKeyException) | Raised when a key is invalid; e.g., uses incorrect characters. | + +### Data + +[`ByteStore`](#langchain_core-stores-ByteStore) + +[`K`](#langchain_core-stores-K) + +[`V`](#langchain_core-stores-V) + +### API + + + + + +```python +class langchain_core.stores.BaseStore() +``` + + + + + + +Abstract + +**Bases:** `Generic[K, V]` + +Abstract interface for a key-value store. + +This is an interface that's meant to abstract away the details of different +key-value stores. It provides a simple interface for getting, setting, and deleting +key-value pairs. + +The basic methods are `mget`, `mset`, and `mdelete` for getting, setting, and +deleting multiple key-value pairs at once. The `yield_keys` method is used to +iterate over keys that match a given prefix. + +The async versions of these methods are also provided, which are meant to be used in +async contexts. The async methods are named with an `a` prefix, e.g., `amget`, +`amset`, `amdelete`, and `ayield_keys`. + +By default, the `amget`, `amset`, `amdelete`, and `ayield_keys` methods are +implemented using the synchronous methods. If the store can natively support async +operations, it should override these methods. + +By design the methods only accept batches of keys and values, and not single keys or +values. This is done to force user code to work with batches which will usually be +more efficient by saving on round trips to the store. + +**Examples:** + + + +```python +from langchain.storage import BaseStore + + +class MyInMemoryStore(BaseStore[str, int]): + def __init__(self) -> None: + self.store: dict[str, int] = {} + + def mget(self, keys: Sequence[str]) -> list[int | None]: + return [self.store.get(key) for key in keys] + + def mset(self, key_value_pairs: Sequence[tuple[str, int]]) -> None: + for key, value in key_value_pairs: + self.store[key] = value + + def mdelete(self, keys: Sequence[str]) -> None: + for key in keys: + if key in self.store: + del self.store[key] + + def yield_keys(self, prefix: str | None = None) -> Iterator[str]: + if prefix is None: + yield from self.store.keys() + else: + for key in self.store.keys(): + if key.startswith(prefix): + yield key +``` + + + + + + + + +```python +langchain_core.stores.BaseStore.amdelete( + keys: collections.abc.Sequence[langchain_core.stores.K] +) -> None +``` + + + + + + +async + +Async delete the given keys and their associated values. + +**Parameters:** + + +A sequence of keys to delete. + + + + + + + + +```python +langchain_core.stores.BaseStore.amget( + keys: collections.abc.Sequence[langchain_core.stores.K] +) -> list[langchain_core.stores.V | None] +``` + + + + + + +async + +Async get the values associated with the given keys. + +**Parameters:** + + +A sequence of keys. + + +**Returns:** `list[V | None]` + +A sequence of optional values associated with the keys. +If a key is not found, the corresponding value will be `None`. + + + + + + + +```python +langchain_core.stores.BaseStore.amset( + key_value_pairs: collections.abc.Sequence[tuple[langchain_core.stores.K, langchain_core.stores.V]] +) -> None +``` + + + + + + +async + +Async set the values for the given keys. + +**Parameters:** + + +A sequence of key-value pairs. + + + + + + + + +```python +langchain_core.stores.BaseStore.ayield_keys( + prefix: str | None = None +) -> collections.abc.AsyncIterator[langchain_core.stores.K] | collections.abc.AsyncIterator[str] +``` + + + + + + +async + +Async get an iterator over keys that match the given prefix. + +**Parameters:** + + +The prefix to match. + + + + + + + + +```python +langchain_core.stores.BaseStore.mdelete( + keys: collections.abc.Sequence[langchain_core.stores.K] +) -> None +``` + + + + + + +abstract + +Delete the given keys and their associated values. + +**Parameters:** + + +A sequence of keys to delete. + + + + + + + + +```python +langchain_core.stores.BaseStore.mget( + keys: collections.abc.Sequence[langchain_core.stores.K] +) -> list[langchain_core.stores.V | None] +``` + + + + + + +abstract + +Get the values associated with the given keys. + +**Parameters:** + + +A sequence of keys. + + +**Returns:** `list[V | None]` + +A sequence of optional values associated with the keys. +If a key is not found, the corresponding value will be `None`. + + + + + + + +```python +langchain_core.stores.BaseStore.mset( + key_value_pairs: collections.abc.Sequence[tuple[langchain_core.stores.K, langchain_core.stores.V]] +) -> None +``` + + + + + + +abstract + +Set the values for the given keys. + +**Parameters:** + + +A sequence of key-value pairs. + + + + + + + + +```python +langchain_core.stores.BaseStore.yield_keys( + prefix: str | None = None +) -> collections.abc.Iterator[langchain_core.stores.K] | collections.abc.Iterator[str] +``` + + + + + + +abstract + +Get an iterator over keys that match the given prefix. + +**Parameters:** + + +The prefix to match. + + + + + + + + + + +```python +class langchain_core.stores.InMemoryBaseStore() +``` + + + + + + +**Bases:** [BaseStore[str, V]](#langchain_core-stores-BaseStore), `Generic[V]` + +In-memory implementation of the `BaseStore` using a dictionary. + + + + + + + + +```python +langchain_core.stores.InMemoryBaseStore.amdelete( + keys: collections.abc.Sequence[str] +) -> None +``` + + + + + + +async + + + + + + + +```python +langchain_core.stores.InMemoryBaseStore.amget( + keys: collections.abc.Sequence[str] +) -> list[langchain_core.stores.V | None] +``` + + + + + + +async + + + + + + + +```python +langchain_core.stores.InMemoryBaseStore.amset( + key_value_pairs: collections.abc.Sequence[tuple[str, langchain_core.stores.V]] +) -> None +``` + + + + + + +async + + + + + + + +```python +langchain_core.stores.InMemoryBaseStore.ayield_keys( + prefix: str | None = None +) -> collections.abc.AsyncIterator[str] +``` + + + + + + +async + +Async get an async iterator over keys that match the given prefix. + +**Parameters:** + + +The prefix to match. + + + + + + + + +```python +langchain_core.stores.InMemoryBaseStore.mdelete( + keys: collections.abc.Sequence[str] +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.stores.InMemoryBaseStore.mget( + keys: collections.abc.Sequence[str] +) -> list[langchain_core.stores.V | None] +``` + + + + + + + + + + + + +```python +langchain_core.stores.InMemoryBaseStore.mset( + key_value_pairs: collections.abc.Sequence[tuple[str, langchain_core.stores.V]] +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.stores.InMemoryBaseStore.yield_keys( + prefix: str | None = None +) -> collections.abc.Iterator[str] +``` + + + + + + +Get an iterator over keys that match the given prefix. + +**Parameters:** + + +The prefix to match. + + + + + + + + + + +```python +class langchain_core.stores.InMemoryByteStore() +``` + + + + + + +**Bases:** [InMemoryBaseStore[bytes]](#langchain_core-stores-InMemoryBaseStore) + +In-memory store for bytes. + +**Examples:** + + + +```python +from langchain.storage import InMemoryByteStore + +store = InMemoryByteStore() +store.mset([("key1", b"value1"), ("key2", b"value2")]) +store.mget(["key1", "key2"]) +# [b'value1', b'value2'] +store.mdelete(["key1"]) +list(store.yield_keys()) +# ['key2'] +list(store.yield_keys(prefix="k")) +# ['key2'] +``` + + + + + + + + + + +```python +class langchain_core.stores.InMemoryStore() +``` + + + + + + +**Bases:** [InMemoryBaseStore[Any]](#langchain_core-stores-InMemoryBaseStore) + +In-memory store for any type of data. + +**Examples:** + + + +```python +from langchain.storage import InMemoryStore + +store = InMemoryStore() +store.mset([("key1", "value1"), ("key2", "value2")]) +store.mget(["key1", "key2"]) +# ['value1', 'value2'] +store.mdelete(["key1"]) +list(store.yield_keys()) +# ['key2'] +list(store.yield_keys(prefix="k")) +# ['key2'] +``` + + + + + + + + + + +```python +class langchain_core.stores.InvalidKeyException() +``` + + + + + + +Exception + +**Bases:** [LangChainException](/langchain-core/langchain_core/exceptions#langchain_core-exceptions-LangChainException) + +Raised when a key is invalid; e.g., uses incorrect characters. + + + + + + + + +```python +langchain_core.stores.ByteStore = BaseStore[str, bytes] +``` + + + + + + + + + +```python +langchain_core.stores.K = TypeVar('K') +``` + + + + + + + + + +```python +langchain_core.stores.V = TypeVar('V') +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/structured_query.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/structured_query.mdx new file mode 100644 index 0000000..2a07f8f --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/structured_query.mdx @@ -0,0 +1,426 @@ +--- +layout: overview +slug: langchain-core/langchain_core/structured_query +title: langchain_core.structured_query +--- + +Internal representation of a structured query language. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`Comparator`](#langchain_core-structured_query-Comparator) | Enumerator of the comparison operators. | +| [`Comparison`](#langchain_core-structured_query-Comparison) | Comparison to a value. | +| [`Expr`](#langchain_core-structured_query-Expr) | Base class for all expressions. | +| [`FilterDirective`](#langchain_core-structured_query-FilterDirective) | Filtering expression. | +| [`Operation`](#langchain_core-structured_query-Operation) | Logical operation over other directives. | +| [`Operator`](#langchain_core-structured_query-Operator) | Enumerator of the operations. | +| [`StructuredQuery`](#langchain_core-structured_query-StructuredQuery) | Structured query. | +| [`Visitor`](#langchain_core-structured_query-Visitor) | Defines interface for IR translation using a visitor pattern. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_to_snake_case`](#langchain_core-structured_query-_to_snake_case) | Convert a name into snake_case. | + +### API + + + + + +```python +class langchain_core.structured_query.Comparator +``` + + + + + + +**Bases:** `enum.Enum` + +Enumerator of the comparison operators. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class langchain_core.structured_query.Comparison( + comparator: langchain_core.structured_query.Comparator, + attribute: str, + value: typing.Any, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [FilterDirective](#langchain_core-structured_query-FilterDirective) + +Comparison to a value. + + + +The attribute to compare. + + + +The comparator to use. + + + +The value to compare to. + + + + + + + +```python +class langchain_core.structured_query.Expr() +``` + + + + + + +**Bases:** `BaseModel` + +Base class for all expressions. + + + + + + +```python +langchain_core.structured_query.Expr.accept( + visitor: langchain_core.structured_query.Visitor +) -> typing.Any +``` + + + + + + +Accept a visitor. + +**Parameters:** + + +visitor to accept. + + +**Returns:** `Any` + +result of visiting. + + + + + + + + + +```python +class langchain_core.structured_query.FilterDirective() +``` + + + + + + +Abstract + +**Bases:** [Expr](#langchain_core-structured_query-Expr) + +Filtering expression. + + + + + + + + +```python +class langchain_core.structured_query.Operation( + operator: langchain_core.structured_query.Operator, + arguments: list[langchain_core.structured_query.FilterDirective], + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [FilterDirective](#langchain_core-structured_query-FilterDirective) + +Logical operation over other directives. + + + +The arguments to the operator. + + + +The operator to use. + + + + + + + +```python +class langchain_core.structured_query.Operator +``` + + + + + + +**Bases:** `enum.Enum` + +Enumerator of the operations. + + + + + + + + + + + + + + + + +```python +class langchain_core.structured_query.StructuredQuery( + query: str, + filter: langchain_core.structured_query.FilterDirective | None, + limit: int | None = None, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [Expr](#langchain_core-structured_query-Expr) + +Structured query. + + + +Filtering expression. + + + +Limit on the number of results. + + + +Query string. + + + + + + + +```python +class langchain_core.structured_query.Visitor() +``` + + + + + + +Abstract + +Defines interface for IR translation using a visitor pattern. + + + +Allowed comparators for the visitor. + + + +Allowed operators for the visitor. + + + + + +```python +langchain_core.structured_query.Visitor._validate_func( + func: langchain_core.structured_query.Operator | langchain_core.structured_query.Comparator +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.structured_query.Visitor.visit_comparison( + comparison: langchain_core.structured_query.Comparison +) -> typing.Any +``` + + + + + + +abstract + +Translate a Comparison. + +**Parameters:** + + +Comparison to translate. + + + + + + + + +```python +langchain_core.structured_query.Visitor.visit_operation( + operation: langchain_core.structured_query.Operation +) -> typing.Any +``` + + + + + + +abstract + +Translate an Operation. + +**Parameters:** + + +Operation to translate. + + + + + + + + +```python +langchain_core.structured_query.Visitor.visit_structured_query( + structured_query: langchain_core.structured_query.StructuredQuery +) -> typing.Any +``` + + + + + + +abstract + +Translate a StructuredQuery. + +**Parameters:** + + +StructuredQuery to translate. + + + + + + + + + + +```python +langchain_core.structured_query._to_snake_case( + name: str +) -> str +``` + + + + + + +Convert a name into snake_case. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/sys_info.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/sys_info.mdx new file mode 100644 index 0000000..d0c7600 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/sys_info.mdx @@ -0,0 +1,64 @@ +--- +layout: overview +slug: langchain-core/langchain_core/sys_info +title: langchain_core.sys_info +--- + +Print information about the system and langchain packages for debugging purposes. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_sub_deps`](#langchain_core-sys_info-_get_sub_deps) | Get any specified sub-dependencies. | +| [`print_sys_info`](#langchain_core-sys_info-print_sys_info) | Print information about the environment for debugging purposes. | + +### API + + + + + +```python +langchain_core.sys_info._get_sub_deps( + packages: collections.abc.Sequence[str] +) -> list[str] +``` + + + + + + +Get any specified sub-dependencies. + + + + + + + + +```python +langchain_core.sys_info.print_sys_info( + additional_pkgs: collections.abc.Sequence[str] = () +) -> None +``` + + + + + + +Print information about the environment for debugging purposes. + +**Parameters:** + + +Additional packages to include in the output. + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools.mdx new file mode 100644 index 0000000..e70857b --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools.mdx @@ -0,0 +1,94 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tools +title: langchain_core.tools +--- + +Tools are classes that an Agent uses to interact with the world. + +Each tool has a description. Agent uses the description to choose the righ tool for the +job. + +## Submodules + +- **[`langchain_core.tools.base`](/langchain-core/langchain_core/tools/base)** +- **[`langchain_core.tools.convert`](/langchain-core/langchain_core/tools/convert)** +- **[`langchain_core.tools.render`](/langchain-core/langchain_core/tools/render)** +- **[`langchain_core.tools.retriever`](/langchain-core/langchain_core/tools/retriever)** +- **[`langchain_core.tools.simple`](/langchain-core/langchain_core/tools/simple)** +- **[`langchain_core.tools.structured`](/langchain-core/langchain_core/tools/structured)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`__dir__`](#langchain_core-tools-__dir__) | - | +| [`__getattr__`](#langchain_core-tools-__getattr__) | - | + +### Data + +[`__all__`](#langchain_core-tools-__all__) + +[`_dynamic_imports`](#langchain_core-tools-_dynamic_imports) + +### API + + + + + +```python +langchain_core.tools.__dir__() -> list[str] +``` + + + + + + + + + + + + + +```python +langchain_core.tools.__getattr__( + attr_name: str +) -> object +``` + + + + + + + + + + + + + +```python +langchain_core.tools.__all__ = ('FILTERED_ARGS', 'ArgsSchema', 'BaseTool', 'BaseToolkit', 'InjectedToolArg', 'I... +``` + + + + + + + + + +```python +langchain_core.tools._dynamic_imports = {'FILTERED_ARGS': 'base', 'ArgsSchema': 'base', 'BaseTool': 'base', 'BaseToolkit... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/base.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/base.mdx new file mode 100644 index 0000000..0e5a701 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/base.mdx @@ -0,0 +1,1764 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tools/base +title: langchain_core.tools.base +--- + +Base classes and utilities for LangChain tools. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseTool`](#langchain_core-tools-base-BaseTool) | Base class for all LangChain tools. | +| [`BaseToolkit`](#langchain_core-tools-base-BaseToolkit) | Base class for toolkits containing related tools. | +| [`InjectedToolArg`](#langchain_core-tools-base-InjectedToolArg) | Annotation for tool arguments that are injected at runtime. | +| [`InjectedToolCallId`](#langchain_core-tools-base-InjectedToolCallId) | Annotation for injecting the tool call ID. | +| [`SchemaAnnotationError`](#langchain_core-tools-base-SchemaAnnotationError) | Raised when `args_schema` is missing or has an incorrect type annotation. | +| [`ToolException`](#langchain_core-tools-base-ToolException) | Exception thrown when a tool execution error occurs. | +| [`_DirectlyInjectedToolArg`](#langchain_core-tools-base-_DirectlyInjectedToolArg) | Annotation for tool arguments that are injected at runtime. | +| [`_SchemaConfig`](#langchain_core-tools-base-_SchemaConfig) | Configuration for Pydantic models generated from function signatures. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_format_output`](#langchain_core-tools-base-_format_output) | Format tool output as a `ToolMessage` if appropriate. | +| [`_function_annotations_are_pydantic_v1`](#langchain_core-tools-base-_function_annotations_are_pydantic_v1) | Check if all Pydantic annotations in a function are from v1. | +| [`_get_annotation_description`](#langchain_core-tools-base-_get_annotation_description) | Extract description from an `Annotated` type. | +| [`_get_filtered_args`](#langchain_core-tools-base-_get_filtered_args) | Get filtered arguments from a function's signature. | +| [`_get_runnable_config_param`](#langchain_core-tools-base-_get_runnable_config_param) | Find the parameter name for `RunnableConfig` in a function. | +| [`_get_type_hints`](#langchain_core-tools-base-_get_type_hints) | Get type hints from a function, handling partial functions. | +| [`_handle_tool_error`](#langchain_core-tools-base-_handle_tool_error) | Handle tool execution errors based on the configured flag. | +| [`_handle_validation_error`](#langchain_core-tools-base-_handle_validation_error) | Handle validation errors based on the configured flag. | +| [`_infer_arg_descriptions`](#langchain_core-tools-base-_infer_arg_descriptions) | Infer argument descriptions from function docstring and annotations. | +| [`_is_annotated_type`](#langchain_core-tools-base-_is_annotated_type) | Check if a type is an `Annotated` type. | +| [`_is_directly_injected_arg_type`](#langchain_core-tools-base-_is_directly_injected_arg_type) | Check if a type annotation indicates a directly injected argument. | +| [`_is_injected_arg_type`](#langchain_core-tools-base-_is_injected_arg_type) | Check if a type annotation indicates an injected argument. | +| [`_is_message_content_block`](#langchain_core-tools-base-_is_message_content_block) | Check if object is a valid message content block. | +| [`_is_message_content_type`](#langchain_core-tools-base-_is_message_content_type) | Check if object is valid message content format. | +| [`_is_pydantic_annotation`](#langchain_core-tools-base-_is_pydantic_annotation) | Check if a type annotation is a Pydantic model. | +| [`_is_tool_call`](#langchain_core-tools-base-_is_tool_call) | Check if the input is a tool call dictionary. | +| [`_parse_python_function_docstring`](#langchain_core-tools-base-_parse_python_function_docstring) | Parse function and argument descriptions from a docstring. | +| [`_prep_run_args`](#langchain_core-tools-base-_prep_run_args) | Prepare arguments for tool execution. | +| [`_replace_type_vars`](#langchain_core-tools-base-_replace_type_vars) | Replace `TypeVar`s in a type annotation with concrete types. | +| [`_stringify`](#langchain_core-tools-base-_stringify) | Convert content to string, preferring JSON format. | +| [`_validate_docstring_args_against_annotations`](#langchain_core-tools-base-_validate_docstring_args_against_annotations) | Validate that docstring arguments match function annotations. | +| [`create_schema_from_function`](#langchain_core-tools-base-create_schema_from_function) | Create a Pydantic schema from a function's signature. | +| [`get_all_basemodel_annotations`](#langchain_core-tools-base-get_all_basemodel_annotations) | Get all annotations from a Pydantic `BaseModel` and its parents. | + +### Data + +[`ArgsSchema`](#langchain_core-tools-base-ArgsSchema) + +[`FILTERED_ARGS`](#langchain_core-tools-base-FILTERED_ARGS) + +[`TOOL_MESSAGE_BLOCK_TYPES`](#langchain_core-tools-base-TOOL_MESSAGE_BLOCK_TYPES) + +[`_EMPTY_SET`](#langchain_core-tools-base-_EMPTY_SET) + +[`_logger`](#langchain_core-tools-base-_logger) + +### API + + + + + +```python +class langchain_core.tools.base.BaseTool( + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [RunnableSerializable[str | dict | ToolCall, Any]](/langchain-core/langchain_core/runnables/base#langchain_core-runnables-base-RunnableSerializable) + +Base class for all LangChain tools. + +This abstract class defines the interface that all LangChain tools must implement. + +Tools are components that can be called by agents to perform specific actions. + + + + + + +Get the tool's input arguments schema. + + + +Pydantic model class to validate and parse the tool's input arguments. + +Args schema should be either: + +- A subclass of `pydantic.BaseModel`. +- A subclass of `pydantic.v1.BaseModel` if accessing v1 namespace in pydantic 2 +- A JSON schema dict + + + +Callbacks to be called during tool execution. + + + +Used to tell the model how/when/why to use the tool. + +You can provide few-shot examples as a part of the description. + + + +Optional provider-specific extra fields for the tool. + +This is used to pass provider-specific configuration that doesn't fit into +standard tool fields. + + + +Handle the content of the `ToolException` thrown. + + + +Handle the content of the `ValidationError` thrown. + + + +Check if the tool accepts only a single input argument. + + + +Optional metadata associated with the tool. + +This metadata will be associated with each call to this tool, +and passed as arguments to the handlers defined in `callbacks`. + +You can use these to, e.g., identify a specific instance of a tool with its usecase. + + + + + + +The unique name of the tool that clearly communicates its purpose. + + + +The tool response format. + +If `'content'` then the output of the tool is interpreted as the contents of a +`ToolMessage`. If `'content_and_artifact'` then the output is expected to be a +two-tuple corresponding to the `(content, artifact)` of a `ToolMessage`. + + + +Whether to return the tool's output directly. + +Setting this to `True` means that after the tool is called, the `AgentExecutor` will +stop looping. + + + +Optional list of tags associated with the tool. + +These tags will be associated with each call to this tool, +and passed as arguments to the handlers defined in `callbacks`. + +You can use these to, e.g., identify a specific instance of a tool with its use +case. + + + +Get the schema for tool calls, excluding injected arguments. + + + +Whether to log the tool's progress. + + + + + +```python +langchain_core.tools.base.BaseTool.__init_subclass__( + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Validate the tool class definition during subclass creation. + +**Parameters:** + + +Additional keyword arguments passed to the parent class. + + +**Raises:** + +- `SchemaAnnotationError`: If `args_schema` has incorrect type annotation. + + + + + + + +```python +langchain_core.tools.base.BaseTool._arun( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +async + +Use the tool asynchronously. + +Add `run_manager: AsyncCallbackManagerForToolRun | None = None` to child +implementations to enable tracing. + +**Returns:** `Any` + +The result of the tool execution. + + + + + + + +```python +langchain_core.tools.base.BaseTool._filter_injected_args( + tool_input: dict +) -> dict +``` + + + + + + +Filter out injected tool arguments from the input dictionary. + +Injected arguments are those annotated with `InjectedToolArg` or its +subclasses, or arguments in `FILTERED_ARGS` like `run_manager` and callbacks. + +**Parameters:** + + +The tool input dictionary to filter. + + +**Returns:** `dict` + +A filtered dictionary with injected arguments removed. + + + + + + + +```python +langchain_core.tools.base.BaseTool._parse_input( + tool_input: str | dict, + tool_call_id: str | None +) -> str | dict[str, typing.Any] +``` + + + + + + +Parse and validate tool input using the args schema. + +**Parameters:** + + +The raw input to the tool. + + + +The ID of the tool call, if available. + + +**Returns:** `str | dict[str, Any]` + +The parsed and validated input. + +**Raises:** + +- `ValueError`: If `string` input is provided with JSON schema `args_schema`. +- `ValueError`: If `InjectedToolCallId` is required but `tool_call_id` is not +provided. +- `TypeError`: If `args_schema` is not a Pydantic `BaseModel` or dict. + + + + + + + +```python +langchain_core.tools.base.BaseTool._run( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +abstract + +Use the tool. + +Add `run_manager: CallbackManagerForToolRun | None = None` to child +implementations to enable tracing. + +**Returns:** `Any` + +The result of the tool execution. + + + + + + + +```python +langchain_core.tools.base.BaseTool._to_args_and_kwargs( + tool_input: str | dict, + tool_call_id: str | None +) -> tuple[tuple, dict] +``` + + + + + + +Convert tool input to positional and keyword arguments. + +**Parameters:** + + +The input to the tool. + + + +The ID of the tool call, if available. + + +**Returns:** `tuple[tuple, dict]` + +A tuple of `(positional_args, keyword_args)` for the tool. + +**Raises:** + +- `TypeError`: If the tool input type is invalid. + + + + + + + +```python +langchain_core.tools.base.BaseTool.ainvoke( + input: str | dict | langchain_core.messages.tool.ToolCall, + config: langchain_core.runnables.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +async + + + + + + + +```python +langchain_core.tools.base.BaseTool.arun( + tool_input: str | dict, + verbose: bool | None = None, + start_color: str | None = 'green', + color: str | None = 'green', + callbacks: langchain_core.callbacks.Callbacks = None, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + run_name: str | None = None, + run_id: uuid.UUID | None = None, + config: langchain_core.runnables.RunnableConfig | None = None, + tool_call_id: str | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +async + +Run the tool asynchronously. + +**Parameters:** + + +The input to the tool. + + + +Whether to log the tool's progress. + + + +The color to use when starting the tool. + + + +The color to use when ending the tool. + + + +Callbacks to be called during tool execution. + + + +Optional list of tags associated with the tool. + + + +Optional metadata associated with the tool. + + + +The name of the run. + + + +The id of the run. + + + +The configuration for the tool. + + + +The id of the tool call. + + + +Keyword arguments to be passed to tool callbacks + + +**Returns:** `Any` + +The output of the tool. + +**Raises:** + +- `ToolException`: If an error occurs during tool execution. + + + + + + + +```python +langchain_core.tools.base.BaseTool.get_input_schema( + config: langchain_core.runnables.RunnableConfig | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + +The tool's input schema. + +**Parameters:** + + +The configuration for the tool. + + +**Returns:** `type[BaseModel]` + +The input schema for the tool. + + + + + + + +```python +langchain_core.tools.base.BaseTool.invoke( + input: str | dict | langchain_core.messages.tool.ToolCall, + config: langchain_core.runnables.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + + + + + + + +```python +langchain_core.tools.base.BaseTool.run( + tool_input: str | dict[str, typing.Any], + verbose: bool | None = None, + start_color: str | None = 'green', + color: str | None = 'green', + callbacks: langchain_core.callbacks.Callbacks = None, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + run_name: str | None = None, + run_id: uuid.UUID | None = None, + config: langchain_core.runnables.RunnableConfig | None = None, + tool_call_id: str | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Run the tool. + +**Parameters:** + + +The input to the tool. + + + +Whether to log the tool's progress. + + + +The color to use when starting the tool. + + + +The color to use when ending the tool. + + + +Callbacks to be called during tool execution. + + + +Optional list of tags associated with the tool. + + + +Optional metadata associated with the tool. + + + +The name of the run. + + + +The id of the run. + + + +The configuration for the tool. + + + +The id of the tool call. + + + +Keyword arguments to be passed to tool callbacks (event handler) + + +**Returns:** `Any` + +The output of the tool. + +**Raises:** + +- `ToolException`: If an error occurs during tool execution. + + + + + + + + + +```python +class langchain_core.tools.base.BaseToolkit() +``` + + + + + + +Abstract + +**Bases:** `BaseModel` + +Base class for toolkits containing related tools. + +A toolkit is a collection of related tools that can be used together to accomplish a +specific task or work with a particular system. + + + + + + +```python +langchain_core.tools.base.BaseToolkit.get_tools() -> list[langchain_core.tools.base.BaseTool] +``` + + + + + + +abstract + +Get all tools in the toolkit. + +**Returns:** `list[BaseTool]` + +List of tools contained in this toolkit. + + + + + + + + + +```python +class langchain_core.tools.base.InjectedToolArg() +``` + + + + + + +Annotation for tool arguments that are injected at runtime. + +Tool arguments annotated with this class are not included in the tool +schema sent to language models and are instead injected during execution. + + + + + + + + +```python +class langchain_core.tools.base.InjectedToolCallId() +``` + + + + + + +**Bases:** [InjectedToolArg](#langchain_core-tools-base-InjectedToolArg) + +Annotation for injecting the tool call ID. + +This annotation is used to mark a tool parameter that should receive the tool call +ID at runtime. + + + +```python +from typing import Annotated +from langchain_core.messages import ToolMessage +from langchain_core.tools import tool, InjectedToolCallId + +@tool +def foo( + x: int, tool_call_id: Annotated[str, InjectedToolCallId] +) -> ToolMessage: + """Return x.""" + return ToolMessage( + str(x), + artifact=x, + name="foo", + tool_call_id=tool_call_id + ) +``` + + + + + + + + + + +```python +class langchain_core.tools.base.SchemaAnnotationError() +``` + + + + + + +**Bases:** `TypeError` + +Raised when `args_schema` is missing or has an incorrect type annotation. + + + + + + + + +```python +class langchain_core.tools.base.ToolException() +``` + + + + + + +Exception + +**Bases:** `Exception` + +Exception thrown when a tool execution error occurs. + +This exception allows tools to signal errors without stopping the agent. + +The error is handled according to the tool's `handle_tool_error` setting, and the +result is returned as an observation to the agent. + + + + + + + + +```python +class langchain_core.tools.base._DirectlyInjectedToolArg() +``` + + + + + + +Annotation for tool arguments that are injected at runtime. + +Injected via direct type annotation, rather than annotated metadata. + +For example, `ToolRuntime` is a directly injected argument. + +Note the direct annotation rather than the verbose alternative: +`Annotated[ToolRuntime, InjectedRuntime]` + + + +```python +from langchain_core.tools import tool, ToolRuntime + + +@tool +def foo(x: int, runtime: ToolRuntime) -> str: + # use runtime.state, runtime.context, runtime.store, etc. + ... +``` + + + + + + + + + + +```python +class langchain_core.tools.base._SchemaConfig() +``` + + + + + + +Configuration for Pydantic models generated from function signatures. + + + +Whether to allow arbitrary types in the model. + + + +Whether to allow extra fields in the model. + + + + + + + +```python +langchain_core.tools.base._format_output( + content: typing.Any, + artifact: typing.Any, + tool_call_id: str | None, + name: str, + status: str +) -> langchain_core.messages.tool.ToolOutputMixin | typing.Any +``` + + + + + + +Format tool output as a `ToolMessage` if appropriate. + +**Parameters:** + + +The main content of the tool output. + + + +Any artifact data from the tool. + + + +The ID of the tool call. + + + +The name of the tool. + + + +The execution status. + + +**Returns:** `ToolOutputMixin | Any` + +The formatted output, either as a `ToolMessage` or the original content. + + + + + + + + +```python +langchain_core.tools.base._function_annotations_are_pydantic_v1( + signature: inspect.Signature, + func: collections.abc.Callable +) -> bool +``` + + + + + + +Check if all Pydantic annotations in a function are from v1. + +**Parameters:** + + +The function signature to check. + + + +The function being checked. + + +**Returns:** `bool` + +True if all Pydantic annotations are from v1, `False` otherwise. + +**Raises:** + +- `NotImplementedError`: If the function contains mixed v1 and v2 annotations. + + + + + + + + +```python +langchain_core.tools.base._get_annotation_description( + arg_type: type +) -> str | None +``` + + + + + + +Extract description from an `Annotated` type. + +Checks for string annotations and `FieldInfo` objects with descriptions. + +**Parameters:** + + +The type to extract description from. + + +**Returns:** `str | None` + +The description string if found, `None` otherwise. + + + + + + + + +```python +langchain_core.tools.base._get_filtered_args( + inferred_model: type[pydantic.BaseModel], + func: collections.abc.Callable, + filter_args: collections.abc.Sequence[str], + include_injected: bool = True +) -> dict +``` + + + + + + +Get filtered arguments from a function's signature. + +**Parameters:** + + +The Pydantic model inferred from the function. + + + +The function to extract arguments from. + + + +Arguments to exclude from the result. + + + +Whether to include injected arguments. + + +**Returns:** `dict` + +Dictionary of filtered arguments with their schema definitions. + + + + + + + + +```python +langchain_core.tools.base._get_runnable_config_param( + func: collections.abc.Callable +) -> str | None +``` + + + + + + +Find the parameter name for `RunnableConfig` in a function. + +**Parameters:** + + +The function to check. + + +**Returns:** `str | None` + +The parameter name for `RunnableConfig`, or `None` if not found. + + + + + + + + +```python +langchain_core.tools.base._get_type_hints( + func: collections.abc.Callable +) -> dict[str, type] | None +``` + + + + + + +Get type hints from a function, handling partial functions. + +**Parameters:** + + +The function to get type hints from. + + +**Returns:** `dict[str, type] | None` + +`dict` of type hints, or `None` if extraction fails. + + + + + + + + +```python +langchain_core.tools.base._handle_tool_error( + e: langchain_core.tools.base.ToolException, + flag: typing.Literal[True] | str | collections.abc.Callable[[ToolException], str] | None +) -> str +``` + + + + + + +Handle tool execution errors based on the configured flag. + +**Parameters:** + + +The tool exception that occurred. + + + +How to handle the error (`bool`, `str`, or `Callable`). + + +**Returns:** `str` + +The error message to return. + +**Raises:** + +- `ValueError`: If the flag type is unexpected. + + + + + + + + +```python +langchain_core.tools.base._handle_validation_error( + e: pydantic.ValidationError | pydantic.v1.ValidationError, + flag: typing.Literal[True] | str | collections.abc.Callable[[ValidationError | ValidationErrorV1], str] +) -> str +``` + + + + + + +Handle validation errors based on the configured flag. + +**Parameters:** + + +The validation error that occurred. + + + +How to handle the error (`bool`, `str`, or `Callable`). + + +**Returns:** `str` + +The error message to return. + +**Raises:** + +- `ValueError`: If the flag type is unexpected. + + + + + + + + +```python +langchain_core.tools.base._infer_arg_descriptions( + fn: collections.abc.Callable, + parse_docstring: bool = False, + error_on_invalid_docstring: bool = False +) -> tuple[str, dict] +``` + + + + + + +Infer argument descriptions from function docstring and annotations. + +**Parameters:** + + +The function to infer descriptions from. + + + +Whether to parse the docstring for descriptions. + + + +Whether to raise error on invalid docstring. + + +**Returns:** `tuple[str, dict]` + +A tuple containing the function description and argument descriptions. + + + + + + + + +```python +langchain_core.tools.base._is_annotated_type( + typ: type[typing.Any] +) -> bool +``` + + + + + + +Check if a type is an `Annotated` type. + +**Parameters:** + + +The type to check. + + +**Returns:** `bool` + +`True` if the type is an `Annotated` type, `False` otherwise. + + + + + + + + +```python +langchain_core.tools.base._is_directly_injected_arg_type( + type_: typing.Any +) -> bool +``` + + + + + + +Check if a type annotation indicates a directly injected argument. + +This is currently only used for `ToolRuntime`. + +Checks if either the annotation itself is a subclass of `_DirectlyInjectedToolArg` +or the origin of the annotation is a subclass of `_DirectlyInjectedToolArg`. + +For example, `ToolRuntime` or `ToolRuntime[ContextT, StateT]` would both return +`True`. + + + + + + + + +```python +langchain_core.tools.base._is_injected_arg_type( + type_: type | typing.TypeVar, + injected_type: type[langchain_core.tools.base.InjectedToolArg] | None = None +) -> bool +``` + + + + + + +Check if a type annotation indicates an injected argument. + +**Parameters:** + + +The type annotation to check. + + + +The specific injected type to check for. + + +**Returns:** `bool` + +`True` if the type is an injected argument, `False` otherwise. + + + + + + + + +```python +langchain_core.tools.base._is_message_content_block( + obj: typing.Any +) -> bool +``` + + + + + + +Check if object is a valid message content block. + +Validates content blocks for OpenAI or Anthropic format. + +**Parameters:** + + +The object to check. + + +**Returns:** `bool` + +`True` if the object is a valid content block, `False` otherwise. + + + + + + + + +```python +langchain_core.tools.base._is_message_content_type( + obj: typing.Any +) -> bool +``` + + + + + + +Check if object is valid message content format. + +Validates content for OpenAI or Anthropic format tool messages. + +**Parameters:** + + +The object to check. + + +**Returns:** `bool` + +`True` if the object is valid message content, `False` otherwise. + + + + + + + + +```python +langchain_core.tools.base._is_pydantic_annotation( + annotation: typing.Any, + pydantic_version: str = 'v2' +) -> bool +``` + + + + + + +Check if a type annotation is a Pydantic model. + +**Parameters:** + + +The type annotation to check. + + + +The Pydantic version to check against (`'v1'` or `'v2'`). + + +**Returns:** `bool` + +`True` if the annotation is a Pydantic model, `False` otherwise. + + + + + + + + +```python +langchain_core.tools.base._is_tool_call( + x: typing.Any +) -> bool +``` + + + + + + +Check if the input is a tool call dictionary. + +**Parameters:** + + +The input to check. + + +**Returns:** `bool` + +`True` if the input is a tool call, `False` otherwise. + + + + + + + + +```python +langchain_core.tools.base._parse_python_function_docstring( + function: collections.abc.Callable, + annotations: dict, + error_on_invalid_docstring: bool = False +) -> tuple[str, dict] +``` + + + + + + +Parse function and argument descriptions from a docstring. + +Assumes the function docstring follows Google Python style guide. + +**Parameters:** + + +The function to parse the docstring from. + + + +Type annotations for the function parameters. + + + +Whether to raise an error on invalid docstring. + + +**Returns:** `tuple[str, dict]` + +A tuple containing the function description and argument descriptions. + + + + + + + + +```python +langchain_core.tools.base._prep_run_args( + value: str | dict | langchain_core.messages.tool.ToolCall, + config: langchain_core.runnables.RunnableConfig | None, + kwargs: typing.Any = {} +) -> tuple[str | dict, dict] +``` + + + + + + +Prepare arguments for tool execution. + +**Parameters:** + + +The input value (`str`, `dict`, or `ToolCall`). + + + +The runnable configuration. + + + +Additional keyword arguments. + + +**Returns:** `tuple[str | dict, dict]` + +A tuple of `(tool_input, run_kwargs)`. + + + + + + + + +```python +langchain_core.tools.base._replace_type_vars( + type_: type | typing.TypeVar, + generic_map: dict[typing.TypeVar, type] | None = None, + default_to_bound: bool = True +) -> type | typing.TypeVar +``` + + + + + + +Replace `TypeVar`s in a type annotation with concrete types. + +**Parameters:** + + +The type annotation to process. + + + +Mapping of `TypeVar`s to concrete types. + + + +Whether to use `TypeVar` bounds as defaults. + + +**Returns:** `type | TypeVar` + +The type with `TypeVar`s replaced. + + + + + + + + +```python +langchain_core.tools.base._stringify( + content: typing.Any +) -> str +``` + + + + + + +Convert content to string, preferring JSON format. + +**Parameters:** + + +The content to stringify. + + +**Returns:** `str` + +String representation of the content. + + + + + + + + +```python +langchain_core.tools.base._validate_docstring_args_against_annotations( + arg_descriptions: dict, + annotations: dict +) -> None +``` + + + + + + +Validate that docstring arguments match function annotations. + +**Parameters:** + + +Arguments described in the docstring. + + + +Type annotations from the function signature. + + +**Raises:** + +- `ValueError`: If a docstring argument is not found in function signature. + + + + + + + + +```python +langchain_core.tools.base.create_schema_from_function( + model_name: str, + func: collections.abc.Callable, + filter_args: collections.abc.Sequence[str] | None = None, + parse_docstring: bool = False, + error_on_invalid_docstring: bool = False, + include_injected: bool = True +) -> type[pydantic.BaseModel] +``` + + + + + + +Create a Pydantic schema from a function's signature. + +**Parameters:** + + +Name to assign to the generated Pydantic schema. + + + +Function to generate the schema from. + + + +Optional list of arguments to exclude from the schema. + +Defaults to `FILTERED_ARGS`. + + + +Whether to parse the function's docstring for descriptions +for each argument. + + + +If `parse_docstring` is provided, configure +whether to raise `ValueError` on invalid Google Style docstrings. + + + +Whether to include injected arguments in the schema. + +Defaults to `True`, since we want to include them in the schema when +*validating* tool inputs. + + +**Returns:** `type[BaseModel]` + +A Pydantic model with the same arguments as the function. + + + + + + + + +```python +langchain_core.tools.base.get_all_basemodel_annotations( + cls: langchain_core.utils.pydantic.TypeBaseModel | typing.Any, + default_to_bound: bool = True +) -> dict[str, type | typing.TypeVar] +``` + + + + + + +Get all annotations from a Pydantic `BaseModel` and its parents. + +**Parameters:** + + +The Pydantic `BaseModel` class. + + + +Whether to default to the bound of a `TypeVar` if it exists. + + +**Returns:** `dict[str, type | TypeVar]` + +`dict` of field names to their type annotations. + + + + + + + + +```python +langchain_core.tools.base.ArgsSchema = TypeBaseModel | dict[str, Any] +``` + + + + + + + + + +```python +langchain_core.tools.base.FILTERED_ARGS = ('run_manager', 'callbacks') +``` + + + + + + + + + +```python +langchain_core.tools.base.TOOL_MESSAGE_BLOCK_TYPES = ('text', 'image_url', 'image', 'json', 'search_result', 'custom_tool_call_output... +``` + + + + + + + + + +```python +langchain_core.tools.base._EMPTY_SET: frozenset[str] = frozenset() +``` + + + + + + + + + +```python +langchain_core.tools.base._logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/convert.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/convert.mdx new file mode 100644 index 0000000..a12d0c0 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/convert.mdx @@ -0,0 +1,348 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tools/convert +title: langchain_core.tools.convert +--- + +Convert functions and runnables to tools. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_description_from_runnable`](#langchain_core-tools-convert-_get_description_from_runnable) | Generate a placeholder description of a `Runnable`. | +| [`_get_schema_from_runnable_and_arg_types`](#langchain_core-tools-convert-_get_schema_from_runnable_and_arg_types) | Infer `args_schema` for tool. | +| [`convert_runnable_to_tool`](#langchain_core-tools-convert-convert_runnable_to_tool) | Convert a `Runnable` into a `BaseTool`. | +| [`tool`](#langchain_core-tools-convert-tool) | Convert Python functions and `Runnables` to LangChain tools. | + +### API + + + + + +```python +langchain_core.tools.convert._get_description_from_runnable( + runnable: langchain_core.runnables.Runnable +) -> str +``` + + + + + + +Generate a placeholder description of a `Runnable`. + + + + + + + + +```python +langchain_core.tools.convert._get_schema_from_runnable_and_arg_types( + runnable: langchain_core.runnables.Runnable, + name: str, + arg_types: dict[str, type] | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + +Infer `args_schema` for tool. + + + + + + + + +```python +langchain_core.tools.convert.convert_runnable_to_tool( + runnable: langchain_core.runnables.Runnable, + args_schema: type[pydantic.BaseModel] | None = None, + name: str | None = None, + description: str | None = None, + arg_types: dict[str, type] | None = None +) -> langchain_core.tools.base.BaseTool +``` + + + + + + +Convert a `Runnable` into a `BaseTool`. + +**Parameters:** + + +The `Runnable` to convert. + + + +The schema for the tool's input arguments. + + + +The name of the tool. + + + +The description of the tool. + + + +The types of the arguments. + + +**Returns:** `BaseTool` + +The tool. + + + + + + + + +```python +langchain_core.tools.convert.tool( + name_or_callable: str | collections.abc.Callable | None = None, + runnable: langchain_core.runnables.Runnable | None = None, + args: typing.Any = (), + description: str | None = None, + return_direct: bool = False, + args_schema: langchain_core.tools.base.ArgsSchema | None = None, + infer_schema: bool = True, + response_format: typing.Literal['content', 'content_and_artifact'] = 'content', + parse_docstring: bool = False, + error_on_invalid_docstring: bool = True, + extras: dict[str, typing.Any] | None = None +) -> langchain_core.tools.base.BaseTool | collections.abc.Callable[[Callable | Runnable], langchain_core.tools.base.BaseTool] +``` + + + + + + +Convert Python functions and `Runnables` to LangChain tools. + +Can be used as a decorator with or without arguments to create tools from functions. + +Functions can have any signature - the tool will automatically infer input schemas +unless disabled. + +!!! note "Requirements" + + - Functions should have type hints for proper schema inference. + - Functions may accept multiple arguments and return types are flexible; + outputs will be serialized if needed. + - When using with `Runnable`, a string name must be provided. + +**Parameters:** + + +Optional name of the tool or the `Callable` to be +converted to a tool. + +Overrides the function's name. + +Must be provided as a positional argument. + + + +Optional `Runnable` to convert to a tool. + +Must be provided as a positional argument. + + + +Optional description for the tool. + +Precedence for the tool description value is as follows: + +- This `description` argument (used even if docstring and/or `args_schema` + are provided) +- Tool function docstring (used even if `args_schema` is provided) +- `args_schema` description (used only if `description` and docstring are + not provided) + + + +Extra positional arguments. + +Must be empty. + + + +Whether to return directly from the tool rather than continuing +the agent loop. + + + +Optional argument schema for user to specify. + + + +Whether to infer the schema of the arguments from the function's +signature. + +This also makes the resultant tool accept a dictionary input to its `run()` +function. + + + +The tool response format. + +If `'content'`, then the output of the tool is interpreted as the contents +of a `ToolMessage`. + +If `'content_and_artifact'`, then the output is expected to be a two-tuple +corresponding to the `(content, artifact)` of a `ToolMessage`. + + + +If `infer_schema` and `parse_docstring`, will attempt to +parse parameter descriptions from Google Style function docstrings. + + + +If `parse_docstring` is provided, configure +whether to raise `ValueError` on invalid Google Style docstrings. + + + +Optional provider-specific extra fields for the tool. + +Used to pass configuration that doesn't fit into standard tool fields. +Chat models should process known extras when constructing model payloads. + +!!! example + + For example, Anthropic-specific fields like `cache_control`, + `defer_loading`, or `input_examples`. + + +**Returns:** `BaseTool | Callable[[Callable | Runnable], BaseTool]` + +The tool. + +**Raises:** + +- `ValueError`: If too many positional arguments are provided (e.g. violating the +`*args` constraint). +- `ValueError`: If a `Runnable` is provided without a string name. When using `tool` +with a `Runnable`, a `str` name must be provided as the `name_or_callable`. +- `ValueError`: If the first argument is not a string or callable with +a `__name__` attribute. +- `ValueError`: If the function does not have a docstring and description +is not provided and `infer_schema` is `False`. +- `ValueError`: If `parse_docstring` is `True` and the function has an invalid +Google-style docstring and `error_on_invalid_docstring` is True. +- `ValueError`: If a `Runnable` is provided that does not have an object schema. + +**Examples:** + + + +```python +@tool +def search_api(query: str) -> str: + # Searches the API for the query. + return + + +@tool("search", return_direct=True) +def search_api(query: str) -> str: + # Searches the API for the query. + return + + +@tool(response_format="content_and_artifact") +def search_api(query: str) -> tuple[str, dict]: + return "partial json of results", {"full": "object of results"} +``` + +Parse Google-style docstrings: + +```python +@tool(parse_docstring=True) +def foo(bar: str, baz: int) -> str: + """The foo. + + Args: + bar: The bar. + baz: The baz. + """ + return bar + +foo.args_schema.model_json_schema() +``` + +```python +{ + "title": "foo", + "description": "The foo.", + "type": "object", + "properties": { + "bar": { + "title": "Bar", + "description": "The bar.", + "type": "string", + }, + "baz": { + "title": "Baz", + "description": "The baz.", + "type": "integer", + }, + }, + "required": ["bar", "baz"], +} +``` + +Note that parsing by default will raise `ValueError` if the docstring is +considered invalid. A docstring is considered invalid if it contains arguments +not in the function signature, or is unable to be parsed into a summary and +`'Args:'` blocks. Examples below: + +```python +# No args section +def invalid_docstring_1(bar: str, baz: int) -> str: + """The foo.""" + return bar + +# Improper whitespace between summary and args section +def invalid_docstring_2(bar: str, baz: int) -> str: + """The foo. + Args: + bar: The bar. + baz: The baz. + """ + return bar + +# Documented args absent from function signature +def invalid_docstring_3(bar: str, baz: int) -> str: + """The foo. + + Args: + banana: The bar. + monkey: The baz. + """ + return bar +``` + + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/render.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/render.mdx new file mode 100644 index 0000000..0d95b93 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/render.mdx @@ -0,0 +1,116 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tools/render +title: langchain_core.tools.render +--- + +Utilities to render tools. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`render_text_description`](#langchain_core-tools-render-render_text_description) | Render the tool name and description in plain text. | +| [`render_text_description_and_args`](#langchain_core-tools-render-render_text_description_and_args) | Render the tool name, description, and args in plain text. | + +### Data + +[`ToolsRenderer`](#langchain_core-tools-render-ToolsRenderer) + +### API + + + + + +```python +langchain_core.tools.render.render_text_description( + tools: list[langchain_core.tools.base.BaseTool] +) -> str +``` + + + + + + +Render the tool name and description in plain text. + +Output will be in the format of: + + + +```python +search: This tool is used for search +calculator: This tool is used for math +``` + + + +**Parameters:** + + +The tools to render. + + +**Returns:** `str` + +The rendered text. + + + + + + + + +```python +langchain_core.tools.render.render_text_description_and_args( + tools: list[langchain_core.tools.base.BaseTool] +) -> str +``` + + + + + + +Render the tool name, description, and args in plain text. + +Output will be in the format of: + + + +```python +search: This tool is used for search, args: {"query": {"type": "string"}} +calculator: This tool is used for math, args: {"expression": {"type": "string"}} +``` + + + +**Parameters:** + + +The tools to render. + + +**Returns:** `str` + +The rendered text. + + + + + + + + +```python +langchain_core.tools.render.ToolsRenderer = Callable[[list[BaseTool]], str] +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/retriever.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/retriever.mdx new file mode 100644 index 0000000..97330d5 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/retriever.mdx @@ -0,0 +1,110 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tools/retriever +title: langchain_core.tools.retriever +--- + +Retriever tool. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`RetrieverInput`](#langchain_core-tools-retriever-RetrieverInput) | Input to the retriever. | + +### Functions + +| Name | Description | +|------|-------------| +| [`create_retriever_tool`](#langchain_core-tools-retriever-create_retriever_tool) | Create a tool to do retrieval of documents. | + +### API + + + + + +```python +class langchain_core.tools.retriever.RetrieverInput() +``` + + + + + + +**Bases:** `BaseModel` + +Input to the retriever. + + + + + + + + + + +```python +langchain_core.tools.retriever.create_retriever_tool( + retriever: langchain_core.retrievers.BaseRetriever, + name: str, + description: str, + document_prompt: langchain_core.prompts.BasePromptTemplate | None = None, + document_separator: str = '\n\n', + response_format: typing.Literal['content', 'content_and_artifact'] = 'content' +) -> langchain_core.tools.structured.StructuredTool +``` + + + + + + +Create a tool to do retrieval of documents. + +**Parameters:** + + +The retriever to use for the retrieval + + + +The name for the tool. + +This will be passed to the language model, so should be unique and somewhat +descriptive. + + + +The description for the tool. + +This will be passed to the language model, so should be descriptive. + + + +The prompt to use for the document. + + + +The separator to use between documents. + + + +The tool response format. + +If `'content'` then the output of the tool is interpreted as the contents of +a `ToolMessage`. If `'content_and_artifact'` then the output is expected to +be a two-tuple corresponding to the `(content, artifact)` of a `ToolMessage` +(artifact being a list of documents in this case). + + +**Returns:** `StructuredTool` + +Tool class to pass to an agent. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/simple.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/simple.mdx new file mode 100644 index 0000000..9021e46 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/simple.mdx @@ -0,0 +1,270 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tools/simple +title: langchain_core.tools.simple +--- + +Tool that takes in function or coroutine directly. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`Tool`](#langchain_core-tools-simple-Tool) | Tool that takes in function or coroutine directly. | + +### API + + + + + +```python +class langchain_core.tools.simple.Tool( + name: str, + func: collections.abc.Callable | None, + description: str, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [BaseTool](/langchain-core/langchain_core/tools/base#langchain_core-tools-base-BaseTool) + +Tool that takes in function or coroutine directly. + + + +The tool's input arguments. + + + +The asynchronous version of the function. + + + + + + +The function to run when the tool is called. + + + + + +```python +langchain_core.tools.simple.Tool._arun( + args: typing.Any = (), + config: langchain_core.runnables.RunnableConfig, + run_manager: langchain_core.callbacks.AsyncCallbackManagerForToolRun | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +async + +Use the tool asynchronously. + +**Parameters:** + + +Positional arguments to pass to the tool + + + +Configuration for the run + + + +Optional callback manager to use for the run + + + +Keyword arguments to pass to the tool + + +**Returns:** `Any` + +The result of the tool execution + + + + + + + +```python +langchain_core.tools.simple.Tool._run( + args: typing.Any = (), + config: langchain_core.runnables.RunnableConfig, + run_manager: langchain_core.callbacks.CallbackManagerForToolRun | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Use the tool. + +**Parameters:** + + +Positional arguments to pass to the tool + + + +Configuration for the run + + + +Optional callback manager to use for the run + + + +Keyword arguments to pass to the tool + + +**Returns:** `Any` + +The result of the tool execution + + + + + + + +```python +langchain_core.tools.simple.Tool._to_args_and_kwargs( + tool_input: str | dict, + tool_call_id: str | None +) -> tuple[tuple, dict] +``` + + + + + + +Convert tool input to Pydantic model. + +**Parameters:** + + +The input to the tool. + + + +The ID of the tool call. + + +**Returns:** `tuple[tuple, dict]` + +The Pydantic model args and kwargs. + +**Raises:** + +- `ToolException`: If the tool input is invalid. + + + + + + + +```python +langchain_core.tools.simple.Tool.ainvoke( + input: str | dict | langchain_core.messages.ToolCall, + config: langchain_core.runnables.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +async + + + + + + + +```python +langchain_core.tools.simple.Tool.from_function( + func: collections.abc.Callable | None, + name: str, + description: str, + return_direct: bool = False, + args_schema: langchain_core.tools.base.ArgsSchema | None = None, + coroutine: collections.abc.Callable[..., collections.abc.Awaitable[typing.Any]] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.tools.simple.Tool +``` + + + + + + +classmethod + +Initialize tool from a function. + +**Parameters:** + + +The function to create the tool from. + + + +The name of the tool. + + + +The description of the tool. + + + +Whether to return the output directly. + + + +The schema of the tool's input arguments. + + + +The asynchronous version of the function. + + + +Additional arguments to pass to the tool. + + +**Returns:** `Tool` + +The tool. + +**Raises:** + +- `ValueError`: If the function is not provided. + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/structured.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/structured.mdx new file mode 100644 index 0000000..d218efb --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tools/structured.mdx @@ -0,0 +1,304 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tools/structured +title: langchain_core.tools.structured +--- + +Structured tool. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`StructuredTool`](#langchain_core-tools-structured-StructuredTool) | Tool that can operate on any number of inputs. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_filter_schema_args`](#langchain_core-tools-structured-_filter_schema_args) | - | + +### API + + + + + +```python +class langchain_core.tools.structured.StructuredTool() +``` + + + + + + +**Bases:** [BaseTool](/langchain-core/langchain_core/tools/base#langchain_core-tools-base-BaseTool) + +Tool that can operate on any number of inputs. + + + + + + +The input arguments' schema. + + + +The asynchronous version of the function. + + + + + + +The function to run when the tool is called. + + + + + +```python +langchain_core.tools.structured.StructuredTool._arun( + args: typing.Any = (), + config: langchain_core.runnables.RunnableConfig, + run_manager: langchain_core.callbacks.AsyncCallbackManagerForToolRun | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +async + +Use the tool asynchronously. + +**Parameters:** + + +Positional arguments to pass to the tool + + + +Configuration for the run + + + +Optional callback manager to use for the run + + + +Keyword arguments to pass to the tool + + +**Returns:** `Any` + +The result of the tool execution + + + + + + + +```python +langchain_core.tools.structured.StructuredTool._run( + args: typing.Any = (), + config: langchain_core.runnables.RunnableConfig, + run_manager: langchain_core.callbacks.CallbackManagerForToolRun | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +Use the tool. + +**Parameters:** + + +Positional arguments to pass to the tool + + + +Configuration for the run + + + +Optional callback manager to use for the run + + + +Keyword arguments to pass to the tool + + +**Returns:** `Any` + +The result of the tool execution + + + + + + + +```python +langchain_core.tools.structured.StructuredTool.ainvoke( + input: str | dict | langchain_core.messages.ToolCall, + config: langchain_core.runnables.RunnableConfig | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +async + + + + + + + +```python +langchain_core.tools.structured.StructuredTool.from_function( + func: collections.abc.Callable | None = None, + coroutine: collections.abc.Callable[..., collections.abc.Awaitable[typing.Any]] | None = None, + name: str | None = None, + description: str | None = None, + return_direct: bool = False, + args_schema: langchain_core.tools.base.ArgsSchema | None = None, + infer_schema: bool = True, + response_format: typing.Literal['content', 'content_and_artifact'] = 'content', + parse_docstring: bool = False, + error_on_invalid_docstring: bool = False, + kwargs: typing.Any = {} +) -> langchain_core.tools.structured.StructuredTool +``` + + + + + + +classmethod + +Create tool from a given function. + +A classmethod that helps to create a tool from a function. + +**Parameters:** + + +The function from which to create a tool. + + + +The async function from which to create a tool. + + + +The name of the tool. + +Defaults to the function name. + + + +The description of the tool. + +Defaults to the function docstring. + + + +Whether to return the result directly or as a callback. + + + +The schema of the tool's input arguments. + + + +Whether to infer the schema from the function's signature. + + + +The tool response format. + +If `'content'` then the output of the tool is interpreted as the +contents of a `ToolMessage`. If `'content_and_artifact'` then the output +is expected to be a two-tuple corresponding to the `(content, artifact)` +of a `ToolMessage`. + + + +If `infer_schema` and `parse_docstring`, will attempt +to parse parameter descriptions from Google Style function docstrings. + + + +if `parse_docstring` is provided, configure +whether to raise `ValueError` on invalid Google Style docstrings. + + + +Additional arguments to pass to the tool + + +**Returns:** `StructuredTool` + +The tool. + +**Raises:** + +- `ValueError`: If the function is not provided. +- `ValueError`: If the function does not have a docstring and description +is not provided. +- `TypeError`: If the `args_schema` is not a `BaseModel` or dict. + +**Examples:** + + + +```python +def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b +tool = StructuredTool.from_function(add) +tool.run(1, 2) # 3 +``` + + + + + + + + + + + +```python +langchain_core.tools.structured._filter_schema_args( + func: collections.abc.Callable +) -> list[str] +``` + + + + + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers.mdx new file mode 100644 index 0000000..5999c27 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers.mdx @@ -0,0 +1,99 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tracers +title: langchain_core.tracers +--- + +Tracers are classes for tracing runs. + +## Submodules + +- **[`langchain_core.tracers._compat`](/langchain-core/langchain_core/tracers/_compat)** +- **[`langchain_core.tracers._streaming`](/langchain-core/langchain_core/tracers/_streaming)** +- **[`langchain_core.tracers.base`](/langchain-core/langchain_core/tracers/base)** +- **[`langchain_core.tracers.context`](/langchain-core/langchain_core/tracers/context)** +- **[`langchain_core.tracers.core`](/langchain-core/langchain_core/tracers/core)** +- **[`langchain_core.tracers.evaluation`](/langchain-core/langchain_core/tracers/evaluation)** +- **[`langchain_core.tracers.event_stream`](/langchain-core/langchain_core/tracers/event_stream)** +- **[`langchain_core.tracers.langchain`](/langchain-core/langchain_core/tracers/langchain)** +- **[`langchain_core.tracers.log_stream`](/langchain-core/langchain_core/tracers/log_stream)** +- **[`langchain_core.tracers.memory_stream`](/langchain-core/langchain_core/tracers/memory_stream)** +- **[`langchain_core.tracers.root_listeners`](/langchain-core/langchain_core/tracers/root_listeners)** +- **[`langchain_core.tracers.run_collector`](/langchain-core/langchain_core/tracers/run_collector)** +- **[`langchain_core.tracers.schemas`](/langchain-core/langchain_core/tracers/schemas)** +- **[`langchain_core.tracers.stdout`](/langchain-core/langchain_core/tracers/stdout)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`__dir__`](#langchain_core-tracers-__dir__) | - | +| [`__getattr__`](#langchain_core-tracers-__getattr__) | - | + +### Data + +[`__all__`](#langchain_core-tracers-__all__) + +[`_dynamic_imports`](#langchain_core-tracers-_dynamic_imports) + +### API + + + + + +```python +langchain_core.tracers.__dir__() -> list[str] +``` + + + + + + + + + + + + + +```python +langchain_core.tracers.__getattr__( + attr_name: str +) -> object +``` + + + + + + + + + + + + + +```python +langchain_core.tracers.__all__ = ('BaseTracer', 'ConsoleCallbackHandler', 'EvaluatorCallbackHandler', 'LangChainT... +``` + + + + + + + + + +```python +langchain_core.tracers._dynamic_imports = {'BaseTracer': 'base', 'EvaluatorCallbackHandler': 'evaluation', 'LangChainTrace... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/_compat.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/_compat.mdx new file mode 100644 index 0000000..f4fd3eb --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/_compat.mdx @@ -0,0 +1,229 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tracers/_compat +title: langchain_core.tracers._compat +--- + +Compatibility helpers for Pydantic v1/v2 with langsmith `Run` objects. + +!!! note + + The generic helpers (`pydantic_to_dict`, `pydantic_copy`) detect Pydanti version + based on the langsmith `Run` model. They're intended for langsmith objects (`Run`, + `Example`) which migrate together. + +For general Pydantic v1/v2 handling, see `langchain_core.utils.pydantic`. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`pydantic_copy`](#langchain_core-tracers-_compat-pydantic_copy) | Copy any Pydantic model, compatible with both v1 and v2. | +| [`pydantic_to_dict`](#langchain_core-tracers-_compat-pydantic_to_dict) | Convert any Pydantic model to dict, compatible with both v1 and v2. | +| [`run_construct`](#langchain_core-tracers-_compat-run_construct) | Construct run without validation, compatible with both Pydantic v1 and v2. | +| [`run_copy`](#langchain_core-tracers-_compat-run_copy) | Copy run, compatible with both Pydantic v1 and v2. | +| [`run_to_dict`](#langchain_core-tracers-_compat-run_to_dict) | Convert run to dict, compatible with both Pydantic v1 and v2. | + +### Data + +[`T`](#langchain_core-tracers-_compat-T) + +[`_RUN_IS_PYDANTIC_V2`](#langchain_core-tracers-_compat-_RUN_IS_PYDANTIC_V2) + +### API + + + + + +```python +langchain_core.tracers._compat.pydantic_copy( + obj: langchain_core.tracers._compat.T, + kwargs: typing.Any = {} +) -> langchain_core.tracers._compat.T +``` + + + + + + +Copy any Pydantic model, compatible with both v1 and v2. + +**Parameters:** + + +The Pydantic model to copy. + + + +Additional arguments passed to `model_copy`/`copy`. + + +**Returns:** `T` + +A copy of the model. + + + + + + + + +```python +langchain_core.tracers._compat.pydantic_to_dict( + obj: typing.Any, + kwargs: typing.Any = {} +) -> dict[str, typing.Any] +``` + + + + + + +Convert any Pydantic model to dict, compatible with both v1 and v2. + +**Parameters:** + + +The Pydantic model to convert. + + + +Additional arguments passed to `model_dump`/`dict`. + + +**Returns:** `dict[str, Any]` + +Dictionary representation of the model. + + + + + + + + +```python +langchain_core.tracers._compat.run_construct( + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Construct run without validation, compatible with both Pydantic v1 and v2. + +**Parameters:** + + +Fields to set on the run. + + +**Returns:** `Run` + +A new `Run` instance constructed without validation. + + + + + + + + +```python +langchain_core.tracers._compat.run_copy( + run: langchain_core.tracers.schemas.Run, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Copy run, compatible with both Pydantic v1 and v2. + +**Parameters:** + + +The run to copy. + + + +Additional arguments passed to `model_copy`/`copy`. + + +**Returns:** `Run` + +A copy of the run. + + + + + + + + +```python +langchain_core.tracers._compat.run_to_dict( + run: langchain_core.tracers.schemas.Run, + kwargs: typing.Any = {} +) -> dict[str, typing.Any] +``` + + + + + + +Convert run to dict, compatible with both Pydantic v1 and v2. + +**Parameters:** + + +The run to convert. + + + +Additional arguments passed to `model_dump`/`dict`. + + +**Returns:** `dict[str, Any]` + +Dictionary representation of the run. + + + + + + + + +```python +langchain_core.tracers._compat.T = TypeVar('T') +``` + + + + + + + + + +```python +langchain_core.tracers._compat._RUN_IS_PYDANTIC_V2 = hasattr(Run, 'model_dump') +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/_streaming.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/_streaming.mdx new file mode 100644 index 0000000..3e3d3e8 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/_streaming.mdx @@ -0,0 +1,115 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tracers/_streaming +title: langchain_core.tracers._streaming +--- + +Internal tracers used for `stream_log` and `astream` events implementations. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`_StreamingCallbackHandler`](#langchain_core-tracers-_streaming-_StreamingCallbackHandler) | Types for streaming callback handlers. | + +### Data + +[`T`](#langchain_core-tracers-_streaming-T) + +[`__all__`](#langchain_core-tracers-_streaming-__all__) + +### API + + + + + +```python +class langchain_core.tracers._streaming._StreamingCallbackHandler() +``` + + + + + + +Protocol + +**Bases:** `Protocol[T]` + +Types for streaming callback handlers. + +This is a common mixin that the callback handlers for both astream events and +astream log inherit from. + +The `tap_output_aiter` method is invoked in some contexts to produce callbacks for +intermediate results. + + + + + + +```python +langchain_core.tracers._streaming._StreamingCallbackHandler.tap_output_aiter( + run_id: uuid.UUID, + output: collections.abc.AsyncIterator[langchain_core.tracers._streaming.T] +) -> collections.abc.AsyncIterator[langchain_core.tracers._streaming.T] +``` + + + + + + +Used for internal astream_log and astream events implementations. + + + + + + + +```python +langchain_core.tracers._streaming._StreamingCallbackHandler.tap_output_iter( + run_id: uuid.UUID, + output: collections.abc.Iterator[langchain_core.tracers._streaming.T] +) -> collections.abc.Iterator[langchain_core.tracers._streaming.T] +``` + + + + + + +Used for internal astream_log and astream events implementations. + + + + + + + + + +```python +langchain_core.tracers._streaming.T = typing.TypeVar('T') +``` + + + + + + + + + +```python +langchain_core.tracers._streaming.__all__ = ['_StreamingCallbackHandler'] +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/base.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/base.mdx new file mode 100644 index 0000000..568c5e0 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/base.mdx @@ -0,0 +1,1678 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tracers/base +title: langchain_core.tracers.base +--- + +Base interfaces for tracing runs. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AsyncBaseTracer`](#langchain_core-tracers-base-AsyncBaseTracer) | Async base interface for tracers. | +| [`BaseTracer`](#langchain_core-tracers-base-BaseTracer) | Base interface for tracers. | + +### Data + +[`logger`](#langchain_core-tracers-base-logger) + +### API + + + + + +```python +class langchain_core.tracers.base.AsyncBaseTracer() +``` + + + + + + +Abstract + +**Bases:** [_TracerCore](/langchain-core/langchain_core/tracers/core#langchain_core-tracers-core-_TracerCore), [AsyncCallbackHandler](/langchain-core/langchain_core/callbacks/base#langchain_core-callbacks-base-AsyncCallbackHandler) + +Async base interface for tracers. + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer._end_trace( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async + +End a trace for a run. + +Ending a trace will run concurrently with each `_on_[run_type]_end` method. +No `_on_[run_type]_end` callback should depend on operations in `_end_trace`. + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer._on_chain_end( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async + +Process the Chain Run. + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer._on_chain_error( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async + +Process the Chain Run upon error. + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer._on_chain_start( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async + +Process the Chain Run upon start. + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer._on_chat_model_start( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async + +Process the Chat Model Run upon start. + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer._on_llm_end( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async + +Process the LLM Run. + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer._on_llm_error( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async + +Process the LLM Run upon error. + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer._on_llm_new_token( + run: langchain_core.tracers.schemas.Run, + token: str, + chunk: langchain_core.outputs.GenerationChunk | langchain_core.outputs.ChatGenerationChunk | None +) -> None +``` + + + + + + +async + +Process new LLM token. + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer._on_llm_start( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async + +Process the LLM Run upon start. + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer._on_retriever_end( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async + +Process the Retriever Run. + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer._on_retriever_error( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async + +Process the Retriever Run upon error. + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer._on_retriever_start( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async + +Process the Retriever Run upon start. + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer._on_run_create( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async + +Process a run upon creation. + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer._on_run_update( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async + +Process a run upon update. + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer._on_tool_end( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async + +Process the Tool Run. + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer._on_tool_error( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async + +Process the Tool Run upon error. + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer._on_tool_start( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async + +Process the Tool Run upon start. + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer._persist_run( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async abstract + +Persist a run. + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer._start_trace( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async + +Start a trace for a run. + +Starting a trace will run concurrently with each `_on_[run_type]_start` method. +No `_on_[run_type]_start` callback should depend on operations in +`_start_trace`. + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer.on_chain_end( + outputs: dict[str, typing.Any], + run_id: uuid.UUID, + inputs: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer.on_chain_error( + error: BaseException, + inputs: dict[str, typing.Any] | None = None, + run_id: uuid.UUID, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer.on_chain_start( + serialized: dict[str, typing.Any], + inputs: dict[str, typing.Any], + run_id: uuid.UUID, + tags: list[str] | None = None, + parent_run_id: uuid.UUID | None = None, + metadata: dict[str, typing.Any] | None = None, + run_type: str | None = None, + name: str | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer.on_chat_model_start( + serialized: dict[str, typing.Any], + messages: list[list[langchain_core.messages.BaseMessage]], + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + name: str | None = None, + kwargs: typing.Any = {} +) -> typing.Any +``` + + + + + + +async + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer.on_llm_end( + response: langchain_core.outputs.LLMResult, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer.on_llm_error( + error: BaseException, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer.on_llm_new_token( + token: str, + chunk: langchain_core.outputs.GenerationChunk | langchain_core.outputs.ChatGenerationChunk | None = None, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer.on_llm_start( + serialized: dict[str, typing.Any], + prompts: list[str], + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer.on_retriever_end( + documents: collections.abc.Sequence[langchain_core.documents.Document], + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer.on_retriever_error( + error: BaseException, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer.on_retriever_start( + serialized: dict[str, typing.Any], + query: str, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + name: str | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer.on_retry( + retry_state: tenacity.RetryCallState, + run_id: uuid.UUID, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer.on_tool_end( + output: typing.Any, + run_id: uuid.UUID, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer.on_tool_error( + error: BaseException, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + + + + + + + +```python +langchain_core.tracers.base.AsyncBaseTracer.on_tool_start( + serialized: dict[str, typing.Any], + input_str: str, + run_id: uuid.UUID, + tags: list[str] | None = None, + parent_run_id: uuid.UUID | None = None, + metadata: dict[str, typing.Any] | None = None, + name: str | None = None, + inputs: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + + + + + + + + + +```python +class langchain_core.tracers.base.BaseTracer() +``` + + + + + + +Abstract + +**Bases:** [_TracerCore](/langchain-core/langchain_core/tracers/core#langchain_core-tracers-core-_TracerCore), [BaseCallbackHandler](/langchain-core/langchain_core/callbacks/base#langchain_core-callbacks-base-BaseCallbackHandler) + +Base interface for tracers. + + + + + + +```python +langchain_core.tracers.base.BaseTracer.__copy__() -> langchain_core.tracers.base.BaseTracer +``` + + + + + + +Return self. + + + + + + + +```python +langchain_core.tracers.base.BaseTracer.__deepcopy__( + memo: dict +) -> langchain_core.tracers.base.BaseTracer +``` + + + + + + +Return self. + + + + + + + +```python +langchain_core.tracers.base.BaseTracer._end_trace( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +End a trace for a run. + + + + + + + +```python +langchain_core.tracers.base.BaseTracer._persist_run( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +abstract + +Persist a run. + + + + + + + +```python +langchain_core.tracers.base.BaseTracer._start_trace( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +Start a trace for a run. + + + + + + + +```python +langchain_core.tracers.base.BaseTracer.on_chain_end( + outputs: dict[str, typing.Any], + run_id: uuid.UUID, + inputs: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +End a trace for a chain run. + +**Parameters:** + + +The outputs for the chain. + + + +The run ID. + + + +The inputs for the chain. + + + +Additional arguments. + + +**Returns:** `Run` + +The run. + + + + + + + +```python +langchain_core.tracers.base.BaseTracer.on_chain_error( + error: BaseException, + inputs: dict[str, typing.Any] | None = None, + run_id: uuid.UUID, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Handle an error for a chain run. + +**Parameters:** + + +The error. + + + +The inputs for the chain. + + + +The run ID. + + + +Additional arguments. + + +**Returns:** `Run` + +The run. + + + + + + + +```python +langchain_core.tracers.base.BaseTracer.on_chain_start( + serialized: dict[str, typing.Any], + inputs: dict[str, typing.Any], + run_id: uuid.UUID, + tags: list[str] | None = None, + parent_run_id: uuid.UUID | None = None, + metadata: dict[str, typing.Any] | None = None, + run_type: str | None = None, + name: str | None = None, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Start a trace for a chain run. + +**Parameters:** + + +The serialized chain. + + + +The inputs for the chain. + + + +The run ID. + + + +The tags for the run. + + + +The parent run ID. + + + +The metadata for the run. + + + +The type of the run. + + + +The name of the run. + + + +Additional arguments. + + +**Returns:** `Run` + +The run. + + + + + + + +```python +langchain_core.tracers.base.BaseTracer.on_chat_model_start( + serialized: dict[str, typing.Any], + messages: list[list[langchain_core.messages.BaseMessage]], + run_id: uuid.UUID, + tags: list[str] | None = None, + parent_run_id: uuid.UUID | None = None, + metadata: dict[str, typing.Any] | None = None, + name: str | None = None, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Start a trace for an LLM run. + +**Parameters:** + + +The serialized model. + + + +The messages to start the chat with. + + + +The run ID. + + + +The tags for the run. + + + +The parent run ID. + + + +The metadata for the run. + + + +The name of the run. + + + +Additional arguments. + + +**Returns:** `Run` + +The run. + + + + + + + +```python +langchain_core.tracers.base.BaseTracer.on_llm_end( + response: langchain_core.outputs.LLMResult, + run_id: uuid.UUID, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +End a trace for an LLM run. + +**Parameters:** + + +The response. + + + +The run ID. + + + +Additional arguments. + + +**Returns:** `Run` + +The run. + + + + + + + +```python +langchain_core.tracers.base.BaseTracer.on_llm_error( + error: BaseException, + run_id: uuid.UUID, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Handle an error for an LLM run. + +**Parameters:** + + +The error. + + + +The run ID. + + + +Additional arguments. + + +**Returns:** `Run` + +The run. + + + + + + + +```python +langchain_core.tracers.base.BaseTracer.on_llm_new_token( + token: str, + chunk: langchain_core.outputs.GenerationChunk | langchain_core.outputs.ChatGenerationChunk | None = None, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Run on new LLM token. + +Only available when streaming is enabled. + +**Parameters:** + + +The token. + + + +The chunk. + + + +The run ID. + + + +The parent run ID. + + + +Additional arguments. + + +**Returns:** `Run` + +The run. + + + + + + + +```python +langchain_core.tracers.base.BaseTracer.on_llm_start( + serialized: dict[str, typing.Any], + prompts: list[str], + run_id: uuid.UUID, + tags: list[str] | None = None, + parent_run_id: uuid.UUID | None = None, + metadata: dict[str, typing.Any] | None = None, + name: str | None = None, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Start a trace for an LLM run. + +**Parameters:** + + +The serialized model. + + + +The prompts to start the LLM with. + + + +The run ID. + + + +The tags for the run. + + + +The parent run ID. + + + +The metadata for the run. + + + +The name of the run. + + + +Additional arguments. + + +**Returns:** `Run` + +The run. + + + + + + + +```python +langchain_core.tracers.base.BaseTracer.on_retriever_end( + documents: collections.abc.Sequence[langchain_core.documents.Document], + run_id: uuid.UUID, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Run when the `Retriever` ends running. + +**Parameters:** + + +The documents. + + + +The run ID. + + + +Additional arguments. + + +**Returns:** `Run` + +The run. + + + + + + + +```python +langchain_core.tracers.base.BaseTracer.on_retriever_error( + error: BaseException, + run_id: uuid.UUID, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Run when `Retriever` errors. + +**Parameters:** + + +The error. + + + +The run ID. + + + +Additional arguments. + + +**Returns:** `Run` + +The run. + + + + + + + +```python +langchain_core.tracers.base.BaseTracer.on_retriever_start( + serialized: dict[str, typing.Any], + query: str, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + name: str | None = None, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Run when the `Retriever` starts running. + +**Parameters:** + + +The serialized retriever. + + + +The query. + + + +The run ID. + + + +The parent run ID. + + + +The tags for the run. + + + +The metadata for the run. + + + +The name of the run. + + + +Additional arguments. + + +**Returns:** `Run` + +The run. + + + + + + + +```python +langchain_core.tracers.base.BaseTracer.on_retry( + retry_state: tenacity.RetryCallState, + run_id: uuid.UUID, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Run on retry. + +**Parameters:** + + +The retry state. + + + +The run ID. + + + +Additional arguments. + + +**Returns:** `Run` + +The run. + + + + + + + +```python +langchain_core.tracers.base.BaseTracer.on_tool_end( + output: typing.Any, + run_id: uuid.UUID, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +End a trace for a tool run. + +**Parameters:** + + +The output for the tool. + + + +The run ID. + + + +Additional arguments. + + +**Returns:** `Run` + +The run. + + + + + + + +```python +langchain_core.tracers.base.BaseTracer.on_tool_error( + error: BaseException, + run_id: uuid.UUID, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Handle an error for a tool run. + +**Parameters:** + + +The error. + + + +The run ID. + + + +Additional arguments. + + +**Returns:** `Run` + +The run. + + + + + + + +```python +langchain_core.tracers.base.BaseTracer.on_tool_start( + serialized: dict[str, typing.Any], + input_str: str, + run_id: uuid.UUID, + tags: list[str] | None = None, + parent_run_id: uuid.UUID | None = None, + metadata: dict[str, typing.Any] | None = None, + name: str | None = None, + inputs: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Start a trace for a tool run. + +**Parameters:** + + +The serialized tool. + + + +The input string. + + + +The run ID. + + + +The tags for the run. + + + +The parent run ID. + + + +The metadata for the run. + + + +The name of the run. + + + +The inputs for the tool. + + + +Additional arguments. + + +**Returns:** `Run` + +The run. + + + + + + + + + +```python +langchain_core.tracers.base.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/context.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/context.mdx new file mode 100644 index 0000000..76282dd --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/context.mdx @@ -0,0 +1,239 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tracers/context +title: langchain_core.tracers.context +--- + +Context management for tracers. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_trace_callbacks`](#langchain_core-tracers-context-_get_trace_callbacks) | - | +| [`_get_tracer_project`](#langchain_core-tracers-context-_get_tracer_project) | - | +| [`_tracing_v2_is_enabled`](#langchain_core-tracers-context-_tracing_v2_is_enabled) | - | +| [`collect_runs`](#langchain_core-tracers-context-collect_runs) | Collect all run traces in context. | +| [`register_configure_hook`](#langchain_core-tracers-context-register_configure_hook) | Register a configure hook. | +| [`tracing_v2_enabled`](#langchain_core-tracers-context-tracing_v2_enabled) | Instruct LangChain to log all runs in context to LangSmith. | + +### Data + +[`_configure_hooks`](#langchain_core-tracers-context-_configure_hooks) + +[`run_collector_var`](#langchain_core-tracers-context-run_collector_var) + +[`tracing_callback_var`](#langchain_core-tracers-context-tracing_callback_var) + +[`tracing_v2_callback_var`](#langchain_core-tracers-context-tracing_v2_callback_var) + +### API + + + + + +```python +langchain_core.tracers.context._get_trace_callbacks( + project_name: str | None = None, + example_id: str | uuid.UUID | None = None, + callback_manager: langchain_core.callbacks.manager.CallbackManager | langchain_core.callbacks.manager.AsyncCallbackManager | None = None +) -> langchain_core.callbacks.base.Callbacks +``` + + + + + + + + + + + + + +```python +langchain_core.tracers.context._get_tracer_project() -> str +``` + + + + + + + + + + + + + +```python +langchain_core.tracers.context._tracing_v2_is_enabled() -> bool | typing.Literal['local'] +``` + + + + + + + + + + + + + +```python +langchain_core.tracers.context.collect_runs() -> collections.abc.Generator[langchain_core.tracers.run_collector.RunCollectorCallbackHandler, None, None] +``` + + + + + + +Collect all run traces in context. + + + + + + + + +```python +langchain_core.tracers.context.register_configure_hook( + context_var: contextvars.ContextVar[typing.Any | None], + inheritable: bool, + handle_class: type[langchain_core.callbacks.base.BaseCallbackHandler] | None = None, + env_var: str | None = None +) -> None +``` + + + + + + +Register a configure hook. + +**Parameters:** + + +The context variable. + + + +Whether the context variable is inheritable. + + + +The callback handler class. + + + +The environment variable. + + +**Raises:** + +- `ValueError`: If `env_var` is set, `handle_class` must also be set to a non-`None` +value. + + + + + + + + +```python +langchain_core.tracers.context.tracing_v2_enabled( + project_name: str | None = None, + example_id: str | uuid.UUID | None = None, + tags: list[str] | None = None, + client: langsmith.Client | None = None +) -> collections.abc.Generator[langchain_core.tracers.langchain.LangChainTracer, None, None] +``` + + + + + + +Instruct LangChain to log all runs in context to LangSmith. + +**Parameters:** + + +The name of the project. + +Defaults to `'default'`. + + + +The ID of the example. + + + +The tags to add to the run. + + + +The client of the langsmith. + + + + + + + + + +```python +langchain_core.tracers.context._configure_hooks: list[tuple[ContextVar[BaseCallbackHandler | None], bool, type[BaseCallbackHandler] | None, str | None]] = [] +``` + + + + + + + + + +```python +langchain_core.tracers.context.run_collector_var: ContextVar[RunCollectorCallbackHandler | None] = ContextVar('run_collector', default=None) +``` + + + + + + + + + +```python +langchain_core.tracers.context.tracing_callback_var: Any = None +``` + + + + + + + + + +```python +langchain_core.tracers.context.tracing_v2_callback_var: ContextVar[LangChainTracer | None] = ContextVar('tracing_callback_v2', default=None) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/core.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/core.mdx new file mode 100644 index 0000000..2a1102d --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/core.mdx @@ -0,0 +1,1017 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tracers/core +title: langchain_core.tracers.core +--- + +Utilities for the root listener. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`_TracerCore`](#langchain_core-tracers-core-_TracerCore) | Abstract base class for tracers. | + +### Data + +[`SCHEMA_FORMAT_TYPE`](#langchain_core-tracers-core-SCHEMA_FORMAT_TYPE) + +[`logger`](#langchain_core-tracers-core-logger) + +### API + + + + + +```python +class langchain_core.tracers.core._TracerCore( + _schema_format: typing.Literal['original', 'streaming_events', 'original+chat'] = 'original', + kwargs: typing.Any = {} +) +``` + + + + + + +Abstract + +Abstract base class for tracers. + +This class provides common methods, and reusable methods for tracers. + + + + + + +Map of run ID to (trace_id, dotted_order). Cleared when tracer GCed. + + + +Map of run ID to run. Cleared on run end. + + + + + +```python +langchain_core.tracers.core._TracerCore.__copy__() -> langchain_core.tracers.core._TracerCore +``` + + + + + + +Return self copied. + + + + + + + +```python +langchain_core.tracers.core._TracerCore.__deepcopy__( + memo: dict +) -> langchain_core.tracers.core._TracerCore +``` + + + + + + +Return self deepcopied. + + + + + + + +```python +langchain_core.tracers.core._TracerCore._add_child_run( + parent_run: langchain_core.tracers.schemas.Run, + child_run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +staticmethod + +Add child run to a chain run or tool run. + + + + + + + +```python +langchain_core.tracers.core._TracerCore._complete_chain_run( + outputs: dict[str, typing.Any], + run_id: uuid.UUID, + inputs: dict[str, typing.Any] | None = None +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Update a chain run with outputs and end time. + + + + + + + +```python +langchain_core.tracers.core._TracerCore._complete_llm_run( + response: langchain_core.outputs.LLMResult, + run_id: uuid.UUID +) -> langchain_core.tracers.schemas.Run +``` + + + + + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._complete_retrieval_run( + documents: collections.abc.Sequence[langchain_core.documents.Document], + run_id: uuid.UUID +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Update a retrieval run with outputs and end time. + + + + + + + +```python +langchain_core.tracers.core._TracerCore._complete_tool_run( + output: dict[str, typing.Any], + run_id: uuid.UUID +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Update a tool run with outputs and end time. + + + + + + + +```python +langchain_core.tracers.core._TracerCore._create_chain_run( + serialized: dict[str, typing.Any], + inputs: dict[str, typing.Any], + run_id: uuid.UUID, + tags: list[str] | None = None, + parent_run_id: uuid.UUID | None = None, + metadata: dict[str, typing.Any] | None = None, + run_type: str | None = None, + name: str | None = None, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Create a chain Run. + + + + + + + +```python +langchain_core.tracers.core._TracerCore._create_chat_model_run( + serialized: dict[str, typing.Any], + messages: list[list[langchain_core.messages.BaseMessage]], + run_id: uuid.UUID, + tags: list[str] | None = None, + parent_run_id: uuid.UUID | None = None, + metadata: dict[str, typing.Any] | None = None, + name: str | None = None, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Create a chat model run. + + + + + + + +```python +langchain_core.tracers.core._TracerCore._create_llm_run( + serialized: dict[str, typing.Any], + prompts: list[str], + run_id: uuid.UUID, + tags: list[str] | None = None, + parent_run_id: uuid.UUID | None = None, + metadata: dict[str, typing.Any] | None = None, + name: str | None = None, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Create a llm run. + + + + + + + +```python +langchain_core.tracers.core._TracerCore._create_retrieval_run( + serialized: dict[str, typing.Any], + query: str, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + name: str | None = None, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Create a retrieval run. + + + + + + + +```python +langchain_core.tracers.core._TracerCore._create_tool_run( + serialized: dict[str, typing.Any], + input_str: str, + run_id: uuid.UUID, + tags: list[str] | None = None, + parent_run_id: uuid.UUID | None = None, + metadata: dict[str, typing.Any] | None = None, + name: str | None = None, + inputs: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Create a tool run. + + + + + + + +```python +langchain_core.tracers.core._TracerCore._end_trace( + run: langchain_core.tracers.schemas.Run +) -> collections.abc.Coroutine[typing.Any, typing.Any, None] | None +``` + + + + + + +End a trace for a run. + +**Parameters:** + + +The run. + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._errored_chain_run( + error: BaseException, + inputs: dict[str, typing.Any] | None, + run_id: uuid.UUID +) -> langchain_core.tracers.schemas.Run +``` + + + + + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._errored_llm_run( + error: BaseException, + run_id: uuid.UUID, + response: langchain_core.outputs.LLMResult | None = None +) -> langchain_core.tracers.schemas.Run +``` + + + + + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._errored_retrieval_run( + error: BaseException, + run_id: uuid.UUID +) -> langchain_core.tracers.schemas.Run +``` + + + + + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._errored_tool_run( + error: BaseException, + run_id: uuid.UUID +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Update a tool run with error and end time. + + + + + + + +```python +langchain_core.tracers.core._TracerCore._get_chain_inputs( + inputs: typing.Any +) -> typing.Any +``` + + + + + + +Get the inputs for a chain run. + + + + + + + +```python +langchain_core.tracers.core._TracerCore._get_chain_outputs( + outputs: typing.Any +) -> typing.Any +``` + + + + + + +Get the outputs for a chain run. + + + + + + + +```python +langchain_core.tracers.core._TracerCore._get_run( + run_id: uuid.UUID, + run_type: str | set[str] | None = None +) -> langchain_core.tracers.schemas.Run +``` + + + + + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._get_stacktrace( + error: BaseException +) -> str +``` + + + + + + +staticmethod + +Get the stacktrace of the parent error. + + + + + + + +```python +langchain_core.tracers.core._TracerCore._llm_run_with_retry_event( + retry_state: tenacity.RetryCallState, + run_id: uuid.UUID +) -> langchain_core.tracers.schemas.Run +``` + + + + + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._llm_run_with_token_event( + token: str, + run_id: uuid.UUID, + chunk: langchain_core.outputs.GenerationChunk | langchain_core.outputs.ChatGenerationChunk | None = None, + parent_run_id: uuid.UUID | None = None +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Append token event to LLM run and return the run. + + + + + + + +```python +langchain_core.tracers.core._TracerCore._on_chain_end( + run: langchain_core.tracers.schemas.Run +) -> collections.abc.Coroutine[typing.Any, typing.Any, None] | None +``` + + + + + + +Process the Chain Run. + +**Parameters:** + + +The chain run. + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._on_chain_error( + run: langchain_core.tracers.schemas.Run +) -> collections.abc.Coroutine[typing.Any, typing.Any, None] | None +``` + + + + + + +Process the Chain Run upon error. + +**Parameters:** + + +The chain run. + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._on_chain_start( + run: langchain_core.tracers.schemas.Run +) -> collections.abc.Coroutine[typing.Any, typing.Any, None] | None +``` + + + + + + +Process the Chain Run upon start. + +**Parameters:** + + +The chain run. + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._on_chat_model_start( + run: langchain_core.tracers.schemas.Run +) -> collections.abc.Coroutine[typing.Any, typing.Any, None] | None +``` + + + + + + +Process the Chat Model Run upon start. + +**Parameters:** + + +The chat model run. + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._on_llm_end( + run: langchain_core.tracers.schemas.Run +) -> collections.abc.Coroutine[typing.Any, typing.Any, None] | None +``` + + + + + + +Process the LLM Run. + +**Parameters:** + + +The LLM run. + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._on_llm_error( + run: langchain_core.tracers.schemas.Run +) -> collections.abc.Coroutine[typing.Any, typing.Any, None] | None +``` + + + + + + +Process the LLM Run upon error. + +**Parameters:** + + +The LLM run. + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._on_llm_new_token( + run: langchain_core.tracers.schemas.Run, + token: str, + chunk: langchain_core.outputs.GenerationChunk | langchain_core.outputs.ChatGenerationChunk | None +) -> collections.abc.Coroutine[typing.Any, typing.Any, None] | None +``` + + + + + + +Process new LLM token. + +**Parameters:** + + +The LLM run. + + + +The new token. + + + +Optional chunk. + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._on_llm_start( + run: langchain_core.tracers.schemas.Run +) -> collections.abc.Coroutine[typing.Any, typing.Any, None] | None +``` + + + + + + +Process the LLM Run upon start. + +**Parameters:** + + +The LLM run. + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._on_retriever_end( + run: langchain_core.tracers.schemas.Run +) -> collections.abc.Coroutine[typing.Any, typing.Any, None] | None +``` + + + + + + +Process the Retriever Run. + +**Parameters:** + + +The retriever run. + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._on_retriever_error( + run: langchain_core.tracers.schemas.Run +) -> collections.abc.Coroutine[typing.Any, typing.Any, None] | None +``` + + + + + + +Process the Retriever Run upon error. + +**Parameters:** + + +The retriever run. + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._on_retriever_start( + run: langchain_core.tracers.schemas.Run +) -> collections.abc.Coroutine[typing.Any, typing.Any, None] | None +``` + + + + + + +Process the Retriever Run upon start. + +**Parameters:** + + +The retriever run. + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._on_run_create( + run: langchain_core.tracers.schemas.Run +) -> collections.abc.Coroutine[typing.Any, typing.Any, None] | None +``` + + + + + + +Process a run upon creation. + +**Parameters:** + + +The created run. + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._on_run_update( + run: langchain_core.tracers.schemas.Run +) -> collections.abc.Coroutine[typing.Any, typing.Any, None] | None +``` + + + + + + +Process a run upon update. + +**Parameters:** + + +The updated run. + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._on_tool_end( + run: langchain_core.tracers.schemas.Run +) -> collections.abc.Coroutine[typing.Any, typing.Any, None] | None +``` + + + + + + +Process the Tool Run. + +**Parameters:** + + +The tool run. + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._on_tool_error( + run: langchain_core.tracers.schemas.Run +) -> collections.abc.Coroutine[typing.Any, typing.Any, None] | None +``` + + + + + + +Process the Tool Run upon error. + +**Parameters:** + + +The tool run. + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._on_tool_start( + run: langchain_core.tracers.schemas.Run +) -> collections.abc.Coroutine[typing.Any, typing.Any, None] | None +``` + + + + + + +Process the Tool Run upon start. + +**Parameters:** + + +The tool run. + + + + + + + + +```python +langchain_core.tracers.core._TracerCore._persist_run( + run: langchain_core.tracers.schemas.Run +) -> collections.abc.Coroutine[typing.Any, typing.Any, None] | None +``` + + + + + + +abstract + +Persist a run. + + + + + + + +```python +langchain_core.tracers.core._TracerCore._start_trace( + run: langchain_core.tracers.schemas.Run +) -> collections.abc.Coroutine[typing.Any, typing.Any, None] | None +``` + + + + + + + + + + + + + + +```python +langchain_core.tracers.core.SCHEMA_FORMAT_TYPE = Literal['original', 'streaming_events'] +``` + + + + + + + + + +```python +langchain_core.tracers.core.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/evaluation.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/evaluation.mdx new file mode 100644 index 0000000..6676d63 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/evaluation.mdx @@ -0,0 +1,233 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tracers/evaluation +title: langchain_core.tracers.evaluation +--- + +A tracer that runs evaluators over completed runs. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`EvaluatorCallbackHandler`](#langchain_core-tracers-evaluation-EvaluatorCallbackHandler) | Tracer that runs a run evaluator whenever a run is persisted. | + +### Functions + +| Name | Description | +|------|-------------| +| [`wait_for_all_evaluators`](#langchain_core-tracers-evaluation-wait_for_all_evaluators) | Wait for all tracers to finish. | + +### Data + +[`_TRACERS`](#langchain_core-tracers-evaluation-_TRACERS) + +[`logger`](#langchain_core-tracers-evaluation-logger) + +### API + + + + + +```python +class langchain_core.tracers.evaluation.EvaluatorCallbackHandler( + evaluators: collections.abc.Sequence[langsmith.RunEvaluator], + client: langsmith.Client | None = None, + example_id: uuid.UUID | str | None = None, + skip_unfinished: bool = True, + project_name: str | None = 'evaluators', + max_concurrency: int | None = None, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [BaseTracer](/langchain-core/langchain_core/tracers/base#langchain_core-tracers-base-BaseTracer) + +Tracer that runs a run evaluator whenever a run is persisted. + + + +The LangSmith client instance used for evaluating the runs. + + + +The example ID associated with the runs. + + + +The thread pool executor used for running the evaluators. + + + +The set of futures representing the running evaluators. + + + + + + + + + + + + + + +```python +langchain_core.tracers.evaluation.EvaluatorCallbackHandler._evaluate_in_project( + run: langchain_core.tracers.schemas.Run, + evaluator: langsmith.RunEvaluator +) -> None +``` + + + + + + +Evaluate the run in the project. + +**Parameters:** + + +The run to be evaluated. + + + +The evaluator to use for evaluating the run. + + + + + + + + +```python +langchain_core.tracers.evaluation.EvaluatorCallbackHandler._log_evaluation_feedback( + evaluator_response: langsmith.evaluation.evaluator.EvaluationResult | langsmith.evaluation.evaluator.EvaluationResults, + run: langchain_core.tracers.schemas.Run, + source_run_id: uuid.UUID | None = None +) -> list[langsmith.evaluation.evaluator.EvaluationResult] +``` + + + + + + + + + + + + +```python +langchain_core.tracers.evaluation.EvaluatorCallbackHandler._persist_run( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +Run the evaluator on the run. + +**Parameters:** + + +The run to be evaluated. + + + + + + + + +```python +langchain_core.tracers.evaluation.EvaluatorCallbackHandler._select_eval_results( + results: langsmith.evaluation.evaluator.EvaluationResult | langsmith.evaluation.evaluator.EvaluationResults +) -> list[langsmith.evaluation.evaluator.EvaluationResult] +``` + + + + + + +staticmethod + + + + + + + +```python +langchain_core.tracers.evaluation.EvaluatorCallbackHandler.wait_for_futures() -> None +``` + + + + + + +Wait for all futures to complete. + + + + + + + + + +```python +langchain_core.tracers.evaluation.wait_for_all_evaluators() -> None +``` + + + + + + +Wait for all tracers to finish. + + + + + + + + +```python +langchain_core.tracers.evaluation._TRACERS: WeakSet[EvaluatorCallbackHandler] = weakref.WeakSet() +``` + + + + + + + + + +```python +langchain_core.tracers.evaluation.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/event_stream.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/event_stream.mdx new file mode 100644 index 0000000..1b2c9a0 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/event_stream.mdx @@ -0,0 +1,777 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tracers/event_stream +title: langchain_core.tracers.event_stream +--- + +Internal tracer to power the event stream API. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`RunInfo`](#langchain_core-tracers-event_stream-RunInfo) | Information about a run. | +| [`_AstreamEventsCallbackHandler`](#langchain_core-tracers-event_stream-_AstreamEventsCallbackHandler) | An implementation of an async callback handler for astream events. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_assign_name`](#langchain_core-tracers-event_stream-_assign_name) | Assign a name to a run. | +| [`_astream_events_implementation_v1`](#langchain_core-tracers-event_stream-_astream_events_implementation_v1) | - | +| [`_astream_events_implementation_v2`](#langchain_core-tracers-event_stream-_astream_events_implementation_v2) | Implementation of the astream events API for v2 runnables. | + +### Data + +[`T`](#langchain_core-tracers-event_stream-T) + +[`logger`](#langchain_core-tracers-event_stream-logger) + +### API + + + + + +```python +class langchain_core.tracers.event_stream.RunInfo +``` + + + + + + +**Bases:** `typing.TypedDict` + +Information about a run. + +This is used to keep track of the metadata associated with a run. + + +The inputs to the run. + + + +The metadata associated with the run. + + + +The name of the run. + + + +The ID of the parent run. + + + +The type of the run. + + + +The tags associated with the run. + + + +The tool call ID associated with the run. + + + + + + + + +```python +class langchain_core.tracers.event_stream._AstreamEventsCallbackHandler( + args: typing.Any = (), + include_names: collections.abc.Sequence[str] | None = None, + include_types: collections.abc.Sequence[str] | None = None, + include_tags: collections.abc.Sequence[str] | None = None, + exclude_names: collections.abc.Sequence[str] | None = None, + exclude_types: collections.abc.Sequence[str] | None = None, + exclude_tags: collections.abc.Sequence[str] | None = None, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [AsyncCallbackHandler](/langchain-core/langchain_core/callbacks/base#langchain_core-callbacks-base-AsyncCallbackHandler), [_StreamingCallbackHandler](/langchain-core/langchain_core/tracers/_streaming#langchain_core-tracers-_streaming-_StreamingCallbackHandler) + +An implementation of an async callback handler for astream events. + + + + + + + + + + + + + + + + + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler.__aiter__() -> collections.abc.AsyncIterator[typing.Any] +``` + + + + + + +Iterate over the receive stream. + +**Returns:** `AsyncIterator[Any]` + +An async iterator over the receive stream. + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler.__copy__() -> langchain_core.tracers.event_stream._AstreamEventsCallbackHandler +``` + + + + + + +Return self. + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler.__deepcopy__( + memo: dict +) -> langchain_core.tracers.event_stream._AstreamEventsCallbackHandler +``` + + + + + + +Return self. + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler._get_parent_ids( + run_id: uuid.UUID +) -> list[str] +``` + + + + + + +Get the parent IDs of a run (non-recursively) cast to strings. + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler._get_tool_run_info_with_inputs( + run_id: uuid.UUID +) -> tuple[langchain_core.tracers.event_stream.RunInfo, typing.Any] +``` + + + + + + +Get run info for a tool and extract inputs, with validation. + +**Parameters:** + + +The run ID of the tool. + + +**Returns:** `tuple[RunInfo, Any]` + +A tuple of `(run_info, inputs)`. + +**Raises:** + +- `AssertionError`: If the run ID is a tool call and does not have inputs. + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler._send( + event: langchain_core.runnables.schema.StreamEvent, + event_type: str +) -> None +``` + + + + + + +Send an event to the stream. + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler._write_run_start_info( + run_id: uuid.UUID, + tags: list[str] | None, + metadata: dict[str, typing.Any] | None, + parent_run_id: uuid.UUID | None, + name_: str, + run_type: str, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +Update the run info. + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler.on_chain_end( + outputs: dict[str, typing.Any], + run_id: uuid.UUID, + inputs: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +End a trace for a chain run. + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler.on_chain_start( + serialized: dict[str, typing.Any], + inputs: dict[str, typing.Any], + run_id: uuid.UUID, + tags: list[str] | None = None, + parent_run_id: uuid.UUID | None = None, + metadata: dict[str, typing.Any] | None = None, + run_type: str | None = None, + name: str | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Start a trace for a chain run. + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler.on_chat_model_start( + serialized: dict[str, typing.Any], + messages: list[list[langchain_core.messages.BaseMessage]], + run_id: uuid.UUID, + tags: list[str] | None = None, + parent_run_id: uuid.UUID | None = None, + metadata: dict[str, typing.Any] | None = None, + name: str | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Start a trace for a chat model run. + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler.on_custom_event( + name: str, + data: typing.Any, + run_id: uuid.UUID, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Generate a custom astream event. + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler.on_llm_end( + response: langchain_core.outputs.LLMResult, + run_id: uuid.UUID, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +End a trace for a model run. + +For both chat models and non-chat models (legacy text-completion LLMs). + +**Raises:** + +- `ValueError`: If the run type is not `'llm'` or `'chat_model'`. + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler.on_llm_new_token( + token: str, + chunk: langchain_core.outputs.GenerationChunk | langchain_core.outputs.ChatGenerationChunk | None = None, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run on new output token. + +Only available when streaming is enabled. + +For both chat models and non-chat models (legacy text-completion LLMs). + +**Raises:** + +- `ValueError`: If the run type is not `llm` or `chat_model`. +- `AssertionError`: If the run ID is not found in the run map. + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler.on_llm_start( + serialized: dict[str, typing.Any], + prompts: list[str], + run_id: uuid.UUID, + tags: list[str] | None = None, + parent_run_id: uuid.UUID | None = None, + metadata: dict[str, typing.Any] | None = None, + name: str | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Start a trace for a (non-chat model) LLM run. + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler.on_retriever_end( + documents: collections.abc.Sequence[langchain_core.documents.Document], + run_id: uuid.UUID, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when `Retriever` ends running. + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler.on_retriever_start( + serialized: dict[str, typing.Any], + query: str, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, typing.Any] | None = None, + name: str | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when `Retriever` starts running. + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler.on_tool_end( + output: typing.Any, + run_id: uuid.UUID, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +End a trace for a tool run. + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler.on_tool_error( + error: BaseException, + run_id: uuid.UUID, + parent_run_id: uuid.UUID | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Run when tool errors. + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler.on_tool_start( + serialized: dict[str, typing.Any], + input_str: str, + run_id: uuid.UUID, + tags: list[str] | None = None, + parent_run_id: uuid.UUID | None = None, + metadata: dict[str, typing.Any] | None = None, + name: str | None = None, + inputs: dict[str, typing.Any] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + +Start a trace for a tool run. + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler.tap_output_aiter( + run_id: uuid.UUID, + output: collections.abc.AsyncIterator[langchain_core.tracers.event_stream.T] +) -> collections.abc.AsyncIterator[langchain_core.tracers.event_stream.T] +``` + + + + + + +async + +Tap the output aiter. + +This method is used to tap the output of a `Runnable` that produces an async +iterator. It is used to generate stream events for the output of the `Runnable`. + +**Parameters:** + + +The ID of the run. + + + +The output of the `Runnable`. + + + + + + + + +```python +langchain_core.tracers.event_stream._AstreamEventsCallbackHandler.tap_output_iter( + run_id: uuid.UUID, + output: collections.abc.Iterator[langchain_core.tracers.event_stream.T] +) -> collections.abc.Iterator[langchain_core.tracers.event_stream.T] +``` + + + + + + +Tap the output iter. + +**Parameters:** + + +The ID of the run. + + + +The output of the `Runnable`. + + + + + + + + + + +```python +langchain_core.tracers.event_stream._assign_name( + name: str | None, + serialized: dict[str, typing.Any] | None +) -> str +``` + + + + + + +Assign a name to a run. + + + + + + + + +```python +langchain_core.tracers.event_stream._astream_events_implementation_v1( + runnable: langchain_core.runnables.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output], + value: typing.Any, + config: langchain_core.runnables.RunnableConfig | None = None, + include_names: collections.abc.Sequence[str] | None = None, + include_types: collections.abc.Sequence[str] | None = None, + include_tags: collections.abc.Sequence[str] | None = None, + exclude_names: collections.abc.Sequence[str] | None = None, + exclude_types: collections.abc.Sequence[str] | None = None, + exclude_tags: collections.abc.Sequence[str] | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.schema.StandardStreamEvent] +``` + + + + + + +async + + + + + + + + +```python +langchain_core.tracers.event_stream._astream_events_implementation_v2( + runnable: langchain_core.runnables.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output], + value: typing.Any, + config: langchain_core.runnables.RunnableConfig | None = None, + include_names: collections.abc.Sequence[str] | None = None, + include_types: collections.abc.Sequence[str] | None = None, + include_tags: collections.abc.Sequence[str] | None = None, + exclude_names: collections.abc.Sequence[str] | None = None, + exclude_types: collections.abc.Sequence[str] | None = None, + exclude_tags: collections.abc.Sequence[str] | None = None, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[langchain_core.runnables.schema.StandardStreamEvent] +``` + + + + + + +async + +Implementation of the astream events API for v2 runnables. + + + + + + + + +```python +langchain_core.tracers.event_stream.T = TypeVar('T') +``` + + + + + + + + + +```python +langchain_core.tracers.event_stream.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/langchain.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/langchain.mdx new file mode 100644 index 0000000..c2e34d2 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/langchain.mdx @@ -0,0 +1,704 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tracers/langchain +title: langchain_core.tracers.langchain +--- + +A tracer implementation that records to LangChain endpoint. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`LangChainTracer`](#langchain_core-tracers-langchain-LangChainTracer) | Implementation of the `SharedTracer` that `POSTS` to the LangChain endpoint. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_executor`](#langchain_core-tracers-langchain-_get_executor) | Get the executor. | +| [`_get_usage_metadata_from_generations`](#langchain_core-tracers-langchain-_get_usage_metadata_from_generations) | Extract and aggregate `usage_metadata` from generations. | +| [`get_client`](#langchain_core-tracers-langchain-get_client) | Get the client. | +| [`log_error_once`](#langchain_core-tracers-langchain-log_error_once) | Log an error once. | +| [`wait_for_all_tracers`](#langchain_core-tracers-langchain-wait_for_all_tracers) | Wait for all tracers to finish. | + +### Data + +[`_EXECUTOR`](#langchain_core-tracers-langchain-_EXECUTOR) + +[`_LOGGED`](#langchain_core-tracers-langchain-_LOGGED) + +[`logger`](#langchain_core-tracers-langchain-logger) + +### API + + + + + +```python +class langchain_core.tracers.langchain.LangChainTracer( + example_id: uuid.UUID | str | None = None, + project_name: str | None = None, + client: langsmith.Client | None = None, + tags: list[str] | None = None, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [BaseTracer](/langchain-core/langchain_core/tracers/base#langchain_core-tracers-base-BaseTracer) + +Implementation of the `SharedTracer` that `POSTS` to the LangChain endpoint. + + + + + + + + + + + + + + + + + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer._get_tags( + run: langchain_core.tracers.schemas.Run +) -> list[str] +``` + + + + + + +Get combined tags for a run. + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer._llm_run_with_token_event( + token: str, + run_id: uuid.UUID, + chunk: langchain_core.outputs.GenerationChunk | langchain_core.outputs.ChatGenerationChunk | None = None, + parent_run_id: uuid.UUID | None = None +) -> langchain_core.tracers.schemas.Run +``` + + + + + + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer._on_chain_end( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +Process the Chain Run. + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer._on_chain_error( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +Process the Chain Run upon error. + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer._on_chain_start( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +Process the Chain Run upon start. + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer._on_chat_model_start( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +Persist an LLM run. + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer._on_llm_end( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +Process the LLM Run. + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer._on_llm_error( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +Process the LLM Run upon error. + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer._on_llm_start( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +Persist an LLM run. + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer._on_retriever_end( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +Process the Retriever Run. + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer._on_retriever_error( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +Process the Retriever Run upon error. + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer._on_retriever_start( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +Process the Retriever Run upon start. + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer._on_tool_end( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +Process the Tool Run. + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer._on_tool_error( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +Process the Tool Run upon error. + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer._on_tool_start( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +Process the Tool Run upon start. + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer._persist_run( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer._persist_run_single( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +Persist a run. + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer._start_trace( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer._update_run_single( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +staticmethod + +Update a run. + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer.get_run_url() -> str +``` + + + + + + +Get the LangSmith root run URL. + +**Returns:** `str` + +The LangSmith root run URL. + +**Raises:** + +- `ValueError`: If no traced run is found. +- `ValueError`: If the run URL cannot be found. + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer.on_chat_model_start( + serialized: dict[str, typing.Any], + messages: list[list[langchain_core.messages.BaseMessage]], + run_id: uuid.UUID, + tags: list[str] | None = None, + parent_run_id: uuid.UUID | None = None, + metadata: dict[str, typing.Any] | None = None, + name: str | None = None, + kwargs: typing.Any = {} +) -> langchain_core.tracers.schemas.Run +``` + + + + + + +Start a trace for an LLM run. + +**Parameters:** + + +The serialized model. + + + +The messages. + + + +The run ID. + + + +The tags. + + + +The parent run ID. + + + +The metadata. + + + +The name. + + + +Additional keyword arguments. + + +**Returns:** `Run` + +The run. + + + + + + + +```python +langchain_core.tracers.langchain.LangChainTracer.wait_for_futures() -> None +``` + + + + + + +Wait for the given futures to complete. + + + + + + + + + +```python +langchain_core.tracers.langchain._get_executor() -> concurrent.futures.ThreadPoolExecutor +``` + + + + + + +Get the executor. + + + + + + + + +```python +langchain_core.tracers.langchain._get_usage_metadata_from_generations( + generations: list[list[dict[str, typing.Any]]] +) -> langchain_core.messages.ai.UsageMetadata | None +``` + + + + + + +Extract and aggregate `usage_metadata` from generations. + +Iterates through generations to find and aggregate all `usage_metadata` found in +messages. This is typically present in chat model outputs. + +**Parameters:** + + +List of generation batches, where each batch is a list of +generation dicts that may contain a `'message'` key with `'usage_metadata'`. + + +**Returns:** `UsageMetadata | None` + +The aggregated `usage_metadata` dict if found, otherwise `None`. + + + + + + + + +```python +langchain_core.tracers.langchain.get_client() -> langsmith.Client +``` + + + + + + +Get the client. + +**Returns:** `Client` + +The LangSmith client. + + + + + + + + +```python +langchain_core.tracers.langchain.log_error_once( + method: str, + exception: Exception +) -> None +``` + + + + + + +Log an error once. + +**Parameters:** + + +The method that raised the exception. + + + +The exception that was raised. + + + + + + + + + +```python +langchain_core.tracers.langchain.wait_for_all_tracers() -> None +``` + + + + + + +Wait for all tracers to finish. + + + + + + + + +```python +langchain_core.tracers.langchain._EXECUTOR: ThreadPoolExecutor | None = None +``` + + + + + + + + + +```python +langchain_core.tracers.langchain._LOGGED = set() +``` + + + + + + + + + +```python +langchain_core.tracers.langchain.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/log_stream.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/log_stream.mdx new file mode 100644 index 0000000..82d598d --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/log_stream.mdx @@ -0,0 +1,759 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tracers/log_stream +title: langchain_core.tracers.log_stream +--- + +Tracer that streams run logs to a stream. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`LogEntry`](#langchain_core-tracers-log_stream-LogEntry) | A single entry in the run log. | +| [`LogStreamCallbackHandler`](#langchain_core-tracers-log_stream-LogStreamCallbackHandler) | Tracer that streams run logs to a stream. | +| [`RunLog`](#langchain_core-tracers-log_stream-RunLog) | Run log. | +| [`RunLogPatch`](#langchain_core-tracers-log_stream-RunLogPatch) | Patch to the run log. | +| [`RunState`](#langchain_core-tracers-log_stream-RunState) | State of the run. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_astream_log_implementation`](#langchain_core-tracers-log_stream-_astream_log_implementation) | Implementation of astream_log for a given runnable. | +| [`_get_standardized_inputs`](#langchain_core-tracers-log_stream-_get_standardized_inputs) | Extract standardized inputs from a `Run`. | +| [`_get_standardized_outputs`](#langchain_core-tracers-log_stream-_get_standardized_outputs) | Extract standardized output from a run. | + +### Data + +[`T`](#langchain_core-tracers-log_stream-T) + +### API + + + + + +```python +class langchain_core.tracers.log_stream.LogEntry +``` + + + + + + +**Bases:** `typing.TypedDict` + +A single entry in the run log. + + +ISO-8601 timestamp of when the run ended. + +Only available after the run has finished. + + + +Final output of this run. + +Only available after the run has finished successfully. + + + +ID of the sub-run. + + + +Inputs to this run. Not available currently via `astream_log`. + + + +Key-value pairs of metadata for the run. + + + +Name of the object being run. + + + +ISO-8601 timestamp of when the run started. + + + +List of output chunks streamed by this run, if available. + + + +List of LLM tokens streamed by this run, if applicable. + + + +List of tags for the run. + + + +Type of the object being run, eg. prompt, chain, llm, etc. + + + + + + + + +```python +class langchain_core.tracers.log_stream.LogStreamCallbackHandler( + auto_close: bool = True, + include_names: collections.abc.Sequence[str] | None = None, + include_types: collections.abc.Sequence[str] | None = None, + include_tags: collections.abc.Sequence[str] | None = None, + exclude_names: collections.abc.Sequence[str] | None = None, + exclude_types: collections.abc.Sequence[str] | None = None, + exclude_tags: collections.abc.Sequence[str] | None = None, + _schema_format: typing.Literal['original', 'streaming_events'] = 'streaming_events' +) +``` + + + + + + +**Bases:** [BaseTracer](/langchain-core/langchain_core/tracers/base#langchain_core-tracers-base-BaseTracer), [_StreamingCallbackHandler](/langchain-core/langchain_core/tracers/_streaming#langchain_core-tracers-_streaming-_StreamingCallbackHandler) + +Tracer that streams run logs to a stream. + + + + + + + + + + + + + + + + + + + + + + + +```python +langchain_core.tracers.log_stream.LogStreamCallbackHandler.__aiter__() -> collections.abc.AsyncIterator[langchain_core.tracers.log_stream.RunLogPatch] +``` + + + + + + +Iterate over the stream of run logs. + +**Returns:** `AsyncIterator[RunLogPatch]` + +An async iterator over the run log patches. + + + + + + + +```python +langchain_core.tracers.log_stream.LogStreamCallbackHandler._on_llm_new_token( + run: langchain_core.tracers.schemas.Run, + token: str, + chunk: langchain_core.outputs.GenerationChunk | langchain_core.outputs.ChatGenerationChunk | None +) -> None +``` + + + + + + +Process new LLM token. + + + + + + + +```python +langchain_core.tracers.log_stream.LogStreamCallbackHandler._on_run_create( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +Start a run. + + + + + + + +```python +langchain_core.tracers.log_stream.LogStreamCallbackHandler._on_run_update( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +Finish a `Run`. + + + + + + + +```python +langchain_core.tracers.log_stream.LogStreamCallbackHandler._persist_run( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.tracers.log_stream.LogStreamCallbackHandler.include_run( + run: langchain_core.tracers.schemas.Run +) -> bool +``` + + + + + + +Check if a `Run` should be included in the log. + +**Parameters:** + + +The `Run` to check. + + +**Returns:** `bool` + +`True` if the `Run` should be included, `False` otherwise. + + + + + + + +```python +langchain_core.tracers.log_stream.LogStreamCallbackHandler.send( + ops: dict[str, typing.Any] = () +) -> bool +``` + + + + + + +Send a patch to the stream, return `False` if the stream is closed. + +**Parameters:** + + +The operations to send to the stream. + + +**Returns:** `bool` + +`True` if the patch was sent successfully, `False` if the stream is closed. + + + + + + + +```python +langchain_core.tracers.log_stream.LogStreamCallbackHandler.tap_output_aiter( + run_id: uuid.UUID, + output: collections.abc.AsyncIterator[langchain_core.tracers.log_stream.T] +) -> collections.abc.AsyncIterator[langchain_core.tracers.log_stream.T] +``` + + + + + + +async + +Tap an output async iterator to stream its values to the log. + +**Parameters:** + + +The ID of the run. + + + +The output async iterator. + + + + + + + + +```python +langchain_core.tracers.log_stream.LogStreamCallbackHandler.tap_output_iter( + run_id: uuid.UUID, + output: collections.abc.Iterator[langchain_core.tracers.log_stream.T] +) -> collections.abc.Iterator[langchain_core.tracers.log_stream.T] +``` + + + + + + +Tap an output iterator to stream its values to the log. + +**Parameters:** + + +The ID of the run. + + + +The output iterator. + + + + + + + + + + +```python +class langchain_core.tracers.log_stream.RunLog( + ops: dict[str, typing.Any] = (), + state: langchain_core.tracers.log_stream.RunState +) +``` + + + + + + +**Bases:** [RunLogPatch](#langchain_core-tracers-log_stream-RunLogPatch) + +Run log. + + + + + + +```python +langchain_core.tracers.log_stream.RunLog.__add__( + other: langchain_core.tracers.log_stream.RunLogPatch | typing.Any +) -> langchain_core.tracers.log_stream.RunLog +``` + + + + + + +Combine two `RunLog` objects. + +**Parameters:** + + +The other `RunLog` or `RunLogPatch` to combine with. + + +**Returns:** `RunLog` + +A new `RunLog` representing the combination of the two. + +**Raises:** + +- `TypeError`: If the other object is not a `RunLog` or `RunLogPatch`. + + + + + + + +```python +langchain_core.tracers.log_stream.RunLog.__eq__( + other: object +) -> bool +``` + + + + + + +Check if two `RunLog`s are equal. + +**Parameters:** + + +The other `RunLog` to compare to. + + +**Returns:** `bool` + +`True` if the `RunLog`s are equal, `False` otherwise. + + + + + + + +```python +langchain_core.tracers.log_stream.RunLog.__repr__() -> str +``` + + + + + + + + + + + + + + +```python +class langchain_core.tracers.log_stream.RunLogPatch( + ops: dict[str, typing.Any] = () +) +``` + + + + + + +Patch to the run log. + + + +List of `JSONPatch` operations, which describe how to create the run state +from an empty dict. + +This is the minimal representation of the log, designed to be serialized as JSON and +sent over the wire to reconstruct the log on the other side. Reconstruction of the +state can be done with any JSONPatch-compliant library, see https://jsonpatch.com +for more information. + + + + + +```python +langchain_core.tracers.log_stream.RunLogPatch.__add__( + other: langchain_core.tracers.log_stream.RunLogPatch | typing.Any +) -> langchain_core.tracers.log_stream.RunLog +``` + + + + + + +Combine two `RunLogPatch` instances. + +**Parameters:** + + +The other `RunLogPatch` to combine with. + + +**Returns:** `RunLog` + +A new `RunLog` representing the combination of the two. + +**Raises:** + +- `TypeError`: If the other object is not a `RunLogPatch`. + + + + + + + +```python +langchain_core.tracers.log_stream.RunLogPatch.__eq__( + other: object +) -> bool +``` + + + + + + + + + + + + +```python +langchain_core.tracers.log_stream.RunLogPatch.__repr__() -> str +``` + + + + + + + + + + + + + + +```python +class langchain_core.tracers.log_stream.RunState +``` + + + + + + +**Bases:** `typing.TypedDict` + +State of the run. + + +Final output of the run, usually the result of aggregating (`+`) streamed_output. + +Updated throughout the run when supported by the `Runnable`. + + + +ID of the run. + + + +Map of run names to sub-runs. + +If filters were supplied, this list will contain only the runs that matched the +filters. + + + +Name of the object being run. + + + +List of output chunks streamed by `Runnable.stream()` + + + +Type of the object being run, e.g. prompt, chain, llm, etc. + + + + + + + + +```python +langchain_core.tracers.log_stream._astream_log_implementation( + runnable: langchain_core.runnables.Runnable[langchain_core.runnables.utils.Input, langchain_core.runnables.utils.Output], + value: typing.Any, + config: langchain_core.runnables.RunnableConfig | None = None, + stream: langchain_core.tracers.log_stream.LogStreamCallbackHandler, + diff: bool = True, + with_streamed_output_list: bool = True, + kwargs: typing.Any = {} +) -> collections.abc.AsyncIterator[langchain_core.tracers.log_stream.RunLogPatch] | collections.abc.AsyncIterator[langchain_core.tracers.log_stream.RunLog] +``` + + + + + + +async + +Implementation of astream_log for a given runnable. + +The implementation has been factored out (at least temporarily) as both +`astream_log` and `astream_events` rely on it. + +**Parameters:** + + +The runnable to run in streaming mode. + + + +The input to the runnable. + + + +The config to pass to the runnable. + + + +The stream to send the run logs to. + + + +Whether to yield run log patches (`True`) or full run logs (`False`). + + + +Whether to include a list of all streamed outputs in +each patch. If `False`, only the final output will be included in the +patches. + + + +Additional keyword arguments to pass to the `Runnable`. + + +**Raises:** + +- `ValueError`: If the callbacks in the config are of an unexpected type. + + + + + + + + +```python +langchain_core.tracers.log_stream._get_standardized_inputs( + run: langchain_core.tracers.schemas.Run, + schema_format: typing.Literal['original', 'streaming_events'] +) -> typing.Any +``` + + + + + + +Extract standardized inputs from a `Run`. + +Standardizes the inputs based on the type of the runnable used. + +**Parameters:** + + +`Run` object + + + +The schema format to use. + + +**Returns:** `Any` + +Valid inputs are only dict. By conventions, inputs always represented invocation +using named arguments. `None` means that the input is not yet known! + + + + + + + + +```python +langchain_core.tracers.log_stream._get_standardized_outputs( + run: langchain_core.tracers.schemas.Run, + schema_format: typing.Literal['original', 'streaming_events', 'original+chat'] +) -> typing.Any | None +``` + + + + + + +Extract standardized output from a run. + +Standardizes the outputs based on the type of the runnable used. + +**Parameters:** + + +the run object. + + + +The schema format to use. + + +**Returns:** `Any | None` + +An output if returned, otherwise `None`. + + + + + + + + +```python +langchain_core.tracers.log_stream.T = TypeVar('T') +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/memory_stream.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/memory_stream.mdx new file mode 100644 index 0000000..39a1ab9 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/memory_stream.mdx @@ -0,0 +1,282 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tracers/memory_stream +title: langchain_core.tracers.memory_stream +--- + +Module implements a memory stream for communication between two co-routines. + +This module provides a way to communicate between two co-routines using a memory +channel. The writer and reader can be in the same event loop or in different event +loops. When they're in different event loops, they will also be in different threads. + +Useful in situations when there's a mix of synchronous and asynchronous used in the +code. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`_MemoryStream`](#langchain_core-tracers-memory_stream-_MemoryStream) | Stream data from a writer to a reader even if they are in different threads. | +| [`_ReceiveStream`](#langchain_core-tracers-memory_stream-_ReceiveStream) | - | +| [`_SendStream`](#langchain_core-tracers-memory_stream-_SendStream) | - | + +### Data + +[`T`](#langchain_core-tracers-memory_stream-T) + +### API + + + + + +```python +class langchain_core.tracers.memory_stream._MemoryStream( + loop: asyncio.AbstractEventLoop +) +``` + + + + + + +**Bases:** `Generic[T]` + +Stream data from a writer to a reader even if they are in different threads. + +Uses asyncio queues to communicate between two co-routines. This implementation +should work even if the writer and reader co-routines belong to two different event +loops (e.g. one running from an event loop in the main thread and the other running +in an event loop in a background thread). + +This implementation is meant to be used with a single writer and a single reader. + +This is an internal implementation to LangChain. Do not use it directly. + + + + + + + + + + + +```python +langchain_core.tracers.memory_stream._MemoryStream.get_receive_stream() -> langchain_core.tracers.memory_stream._ReceiveStream[langchain_core.tracers.memory_stream.T] +``` + + + + + + +Get a reader for the channel. + +**Returns:** `_ReceiveStream[T]` + +The reader for the channel. + + + + + + + +```python +langchain_core.tracers.memory_stream._MemoryStream.get_send_stream() -> langchain_core.tracers.memory_stream._SendStream[langchain_core.tracers.memory_stream.T] +``` + + + + + + +Get a writer for the channel. + +**Returns:** `_SendStream[T]` + +The writer for the channel. + + + + + + + + + +```python +class langchain_core.tracers.memory_stream._ReceiveStream( + queue: asyncio.Queue, + done: object +) +``` + + + + + + +**Bases:** `Generic[T]` + + + + + +```python +langchain_core.tracers.memory_stream._ReceiveStream.__aiter__() -> collections.abc.AsyncIterator[langchain_core.tracers.memory_stream.T] +``` + + + + + + +async + + + + + + + + + +```python +class langchain_core.tracers.memory_stream._SendStream( + reader_loop: asyncio.AbstractEventLoop, + queue: asyncio.Queue, + done: object +) +``` + + + + + + +**Bases:** `Generic[T]` + + + + + +```python +langchain_core.tracers.memory_stream._SendStream.aclose() -> None +``` + + + + + + +async + +Async schedule the done object write the queue using the original loop. + + + + + + + +```python +langchain_core.tracers.memory_stream._SendStream.close() -> None +``` + + + + + + +Schedule the done object write the queue using the original loop. + +This is a non-blocking call. + +**Raises:** + +- `RuntimeError`: If the event loop is already closed when trying to write to +the queue. + + + + + + + +```python +langchain_core.tracers.memory_stream._SendStream.send( + item: langchain_core.tracers.memory_stream.T +) -> None +``` + + + + + + +async + +Schedule the item to be written to the queue using the original loop. + +This is a coroutine that can be awaited. + +**Parameters:** + + +The item to write to the queue. + + + + + + + + +```python +langchain_core.tracers.memory_stream._SendStream.send_nowait( + item: langchain_core.tracers.memory_stream.T +) -> None +``` + + + + + + +Schedule the item to be written to the queue using the original loop. + +This is a non-blocking call. + +**Parameters:** + + +The item to write to the queue. + + +**Raises:** + +- `RuntimeError`: If the event loop is already closed when trying to write to +the queue. + + + + + + + + + +```python +langchain_core.tracers.memory_stream.T = TypeVar('T') +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/root_listeners.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/root_listeners.mdx new file mode 100644 index 0000000..55f6a50 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/root_listeners.mdx @@ -0,0 +1,218 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tracers/root_listeners +title: langchain_core.tracers.root_listeners +--- + +Tracers that call listeners. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AsyncRootListenersTracer`](#langchain_core-tracers-root_listeners-AsyncRootListenersTracer) | Async tracer that calls listeners on run start, end, and error. | +| [`RootListenersTracer`](#langchain_core-tracers-root_listeners-RootListenersTracer) | Tracer that calls listeners on run start, end, and error. | + +### Data + +[`AsyncListener`](#langchain_core-tracers-root_listeners-AsyncListener) + +[`Listener`](#langchain_core-tracers-root_listeners-Listener) + +### API + + + + + +```python +class langchain_core.tracers.root_listeners.AsyncRootListenersTracer( + config: langchain_core.runnables.config.RunnableConfig, + on_start: langchain_core.tracers.root_listeners.AsyncListener | None, + on_end: langchain_core.tracers.root_listeners.AsyncListener | None, + on_error: langchain_core.tracers.root_listeners.AsyncListener | None +) +``` + + + + + + +**Bases:** [AsyncBaseTracer](/langchain-core/langchain_core/tracers/base#langchain_core-tracers-base-AsyncBaseTracer) + +Async tracer that calls listeners on run start, end, and error. + + + +Whether to log a warning if the parent is missing. + + + + + + + + +```python +langchain_core.tracers.root_listeners.AsyncRootListenersTracer._on_run_create( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async + + + + + + + +```python +langchain_core.tracers.root_listeners.AsyncRootListenersTracer._on_run_update( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async + + + + + + + +```python +langchain_core.tracers.root_listeners.AsyncRootListenersTracer._persist_run( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +async + + + + + + + + + +```python +class langchain_core.tracers.root_listeners.RootListenersTracer( + config: langchain_core.runnables.config.RunnableConfig, + on_start: langchain_core.tracers.root_listeners.Listener | None, + on_end: langchain_core.tracers.root_listeners.Listener | None, + on_error: langchain_core.tracers.root_listeners.Listener | None +) +``` + + + + + + +**Bases:** [BaseTracer](/langchain-core/langchain_core/tracers/base#langchain_core-tracers-base-BaseTracer) + +Tracer that calls listeners on run start, end, and error. + + + +Whether to log a warning if the parent is missing. + + + + + + + + +```python +langchain_core.tracers.root_listeners.RootListenersTracer._on_run_create( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.tracers.root_listeners.RootListenersTracer._on_run_update( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.tracers.root_listeners.RootListenersTracer._persist_run( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + + + + + + + + + +```python +langchain_core.tracers.root_listeners.AsyncListener = Callable[[Run], Awaitable[None]] | Callable[[Run, RunnableConfig], Awaitable[Non... +``` + + + + + + + + + +```python +langchain_core.tracers.root_listeners.Listener = Callable[[Run], None] | Callable[[Run, RunnableConfig], None] +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/run_collector.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/run_collector.mdx new file mode 100644 index 0000000..9f0e8c0 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/run_collector.mdx @@ -0,0 +1,75 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tracers/run_collector +title: langchain_core.tracers.run_collector +--- + +A tracer that collects all nested runs in a list. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`RunCollectorCallbackHandler`](#langchain_core-tracers-run_collector-RunCollectorCallbackHandler) | Tracer that collects all nested runs in a list. | + +### API + + + + + +```python +class langchain_core.tracers.run_collector.RunCollectorCallbackHandler( + example_id: uuid.UUID | str | None = None, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [BaseTracer](/langchain-core/langchain_core/tracers/base#langchain_core-tracers-base-BaseTracer) + +Tracer that collects all nested runs in a list. + +This tracer is useful for inspection and evaluation purposes. + + + + + + + + + + + + + + +```python +langchain_core.tracers.run_collector.RunCollectorCallbackHandler._persist_run( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + +Persist a run by adding it to the `traced_runs` list. + +**Parameters:** + + +The run to be persisted. + + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/schemas.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/schemas.mdx new file mode 100644 index 0000000..035e92d --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/schemas.mdx @@ -0,0 +1,41 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tracers/schemas +title: langchain_core.tracers.schemas +--- + +Schemas for tracers. + +## Module Contents + +### Data + +[`Run`](#langchain_core-tracers-schemas-Run) + +[`__all__`](#langchain_core-tracers-schemas-__all__) + +### API + + + + + +```python +langchain_core.tracers.schemas.Run = RunTree +``` + + + + + + + + + +```python +langchain_core.tracers.schemas.__all__ = ['Run'] +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/stdout.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/stdout.mdx new file mode 100644 index 0000000..9af1c84 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/tracers/stdout.mdx @@ -0,0 +1,387 @@ +--- +layout: overview +slug: langchain-core/langchain_core/tracers/stdout +title: langchain_core.tracers.stdout +--- + +Tracers that print to the console. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ConsoleCallbackHandler`](#langchain_core-tracers-stdout-ConsoleCallbackHandler) | Tracer that prints to the console. | +| [`FunctionCallbackHandler`](#langchain_core-tracers-stdout-FunctionCallbackHandler) | Tracer that calls a function with a single str parameter. | + +### Functions + +| Name | Description | +|------|-------------| +| [`elapsed`](#langchain_core-tracers-stdout-elapsed) | Get the elapsed time of a run. | +| [`try_json_stringify`](#langchain_core-tracers-stdout-try_json_stringify) | Try to stringify an object to JSON. | + +### Data + +[`MILLISECONDS_IN_SECOND`](#langchain_core-tracers-stdout-MILLISECONDS_IN_SECOND) + +### API + + + + + +```python +class langchain_core.tracers.stdout.ConsoleCallbackHandler( + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [FunctionCallbackHandler](#langchain_core-tracers-stdout-FunctionCallbackHandler) + +Tracer that prints to the console. + + + + + + + + + + +```python +class langchain_core.tracers.stdout.FunctionCallbackHandler( + function: collections.abc.Callable[[str], None], + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [BaseTracer](/langchain-core/langchain_core/tracers/base#langchain_core-tracers-base-BaseTracer) + +Tracer that calls a function with a single str parameter. + + + +The name of the tracer. + +This is used to identify the tracer in the logs. + + + + + +```python +langchain_core.tracers.stdout.FunctionCallbackHandler._on_chain_end( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.tracers.stdout.FunctionCallbackHandler._on_chain_error( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.tracers.stdout.FunctionCallbackHandler._on_chain_start( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.tracers.stdout.FunctionCallbackHandler._on_llm_end( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.tracers.stdout.FunctionCallbackHandler._on_llm_error( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.tracers.stdout.FunctionCallbackHandler._on_llm_start( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.tracers.stdout.FunctionCallbackHandler._on_tool_end( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.tracers.stdout.FunctionCallbackHandler._on_tool_error( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.tracers.stdout.FunctionCallbackHandler._on_tool_start( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.tracers.stdout.FunctionCallbackHandler._persist_run( + run: langchain_core.tracers.schemas.Run +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.tracers.stdout.FunctionCallbackHandler.get_breadcrumbs( + run: langchain_core.tracers.schemas.Run +) -> str +``` + + + + + + +Get the breadcrumbs of a run. + +**Parameters:** + + +The run to get the breadcrumbs of. + + +**Returns:** `str` + +A string with the breadcrumbs of the run. + + + + + + + +```python +langchain_core.tracers.stdout.FunctionCallbackHandler.get_parents( + run: langchain_core.tracers.schemas.Run +) -> list[langchain_core.tracers.schemas.Run] +``` + + + + + + +Get the parents of a run. + +**Parameters:** + + +The run to get the parents of. + + +**Returns:** `list[Run]` + +A list of parent runs. + + + + + + + + + +```python +langchain_core.tracers.stdout.elapsed( + run: typing.Any +) -> str +``` + + + + + + +Get the elapsed time of a run. + +**Parameters:** + + +any object with a `start_time` and `end_time` attribute. + + +**Returns:** `str` + +A string with the elapsed time in seconds or milliseconds if time is less than a +second. + + + + + + + + +```python +langchain_core.tracers.stdout.try_json_stringify( + obj: typing.Any, + fallback: str +) -> str +``` + + + + + + +Try to stringify an object to JSON. + +**Parameters:** + + +Object to stringify. + + + +Fallback string to return if the object cannot be stringified. + + +**Returns:** `str` + +A JSON string if the object can be stringified, otherwise the fallback string. + + + + + + + + +```python +langchain_core.tracers.stdout.MILLISECONDS_IN_SECOND = 1000 +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils.mdx new file mode 100644 index 0000000..98a3642 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils.mdx @@ -0,0 +1,105 @@ +--- +layout: overview +slug: langchain-core/langchain_core/utils +title: langchain_core.utils +--- + +Utility functions for LangChain. + +These functions do not depend on any other LangChain module. + +## Submodules + +- **[`langchain_core.utils._merge`](/langchain-core/langchain_core/utils/_merge)** +- **[`langchain_core.utils.aiter`](/langchain-core/langchain_core/utils/aiter)** +- **[`langchain_core.utils.env`](/langchain-core/langchain_core/utils/env)** +- **[`langchain_core.utils.formatting`](/langchain-core/langchain_core/utils/formatting)** +- **[`langchain_core.utils.function_calling`](/langchain-core/langchain_core/utils/function_calling)** +- **[`langchain_core.utils.html`](/langchain-core/langchain_core/utils/html)** +- **[`langchain_core.utils.image`](/langchain-core/langchain_core/utils/image)** +- **[`langchain_core.utils.input`](/langchain-core/langchain_core/utils/input)** +- **[`langchain_core.utils.interactive_env`](/langchain-core/langchain_core/utils/interactive_env)** +- **[`langchain_core.utils.iter`](/langchain-core/langchain_core/utils/iter)** +- **[`langchain_core.utils.json`](/langchain-core/langchain_core/utils/json)** +- **[`langchain_core.utils.json_schema`](/langchain-core/langchain_core/utils/json_schema)** +- **[`langchain_core.utils.mustache`](/langchain-core/langchain_core/utils/mustache)** +- **[`langchain_core.utils.pydantic`](/langchain-core/langchain_core/utils/pydantic)** +- **[`langchain_core.utils.strings`](/langchain-core/langchain_core/utils/strings)** +- **[`langchain_core.utils.usage`](/langchain-core/langchain_core/utils/usage)** +- **[`langchain_core.utils.utils`](/langchain-core/langchain_core/utils/utils)** +- **[`langchain_core.utils.uuid`](/langchain-core/langchain_core/utils/uuid)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`__dir__`](#langchain_core-utils-__dir__) | - | +| [`__getattr__`](#langchain_core-utils-__getattr__) | - | + +### Data + +[`__all__`](#langchain_core-utils-__all__) + +[`_dynamic_imports`](#langchain_core-utils-_dynamic_imports) + +### API + + + + + +```python +langchain_core.utils.__dir__() -> list[str] +``` + + + + + + + + + + + + + +```python +langchain_core.utils.__getattr__( + attr_name: str +) -> object +``` + + + + + + + + + + + + + +```python +langchain_core.utils.__all__ = ('StrictFormatter', 'abatch_iterate', 'batch_iterate', 'build_extra_kwargs', 'ch... +``` + + + + + + + + + +```python +langchain_core.utils._dynamic_imports = {'image': '__module__', 'abatch_iterate': 'aiter', 'get_from_dict_or_env': 'env'... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/_merge.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/_merge.mdx new file mode 100644 index 0000000..2244f97 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/_merge.mdx @@ -0,0 +1,140 @@ +--- +layout: overview +slug: langchain-core/langchain_core/utils/_merge +title: langchain_core.utils._merge +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`merge_dicts`](#langchain_core-utils-_merge-merge_dicts) | Merge dictionaries. | +| [`merge_lists`](#langchain_core-utils-_merge-merge_lists) | Add many lists, handling `None`. | +| [`merge_obj`](#langchain_core-utils-_merge-merge_obj) | Merge two objects. | + +### API + + + + + +```python +langchain_core.utils._merge.merge_dicts( + left: dict[str, typing.Any], + others: dict[str, typing.Any] = () +) -> dict[str, typing.Any] +``` + + + + + + +Merge dictionaries. + +Merge many dicts, handling specific scenarios where a key exists in both +dictionaries but has a value of `None` in `'left'`. In such cases, the method uses +the value from `'right'` for that key in the merged dictionary. + +**Parameters:** + + +The first dictionary to merge. + + + +The other dictionaries to merge. + + +**Returns:** `dict[str, Any]` + +The merged dictionary. + +**Raises:** + +- `TypeError`: If the key exists in both dictionaries but has a different type. +- `TypeError`: If the value has an unsupported type. + + + + + + + + +```python +langchain_core.utils._merge.merge_lists( + left: list | None, + others: list | None = () +) -> list | None +``` + + + + + + +Add many lists, handling `None`. + +**Parameters:** + + +The first list to merge. + + + +The other lists to merge. + + +**Returns:** `list | None` + +The merged list. + + + + + + + + +```python +langchain_core.utils._merge.merge_obj( + left: typing.Any, + right: typing.Any +) -> typing.Any +``` + + + + + + +Merge two objects. + +It handles specific scenarios where a key exists in both dictionaries but has a +value of `None` in `'left'`. In such cases, the method uses the value from `'right'` +for that key in the merged dictionary. + +**Parameters:** + + +The first object to merge. + + + +The other object to merge. + + +**Returns:** `Any` + +The merged object. + +**Raises:** + +- `TypeError`: If the key exists in both dictionaries but has a different type. +- `ValueError`: If the two objects cannot be merged. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/aiter.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/aiter.mdx new file mode 100644 index 0000000..1919329 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/aiter.mdx @@ -0,0 +1,541 @@ +--- +layout: overview +slug: langchain-core/langchain_core/utils/aiter +title: langchain_core.utils.aiter +--- + +Asynchronous iterator utilities. + +Adapted from +https://github.com/maxfischer2781/asyncstdlib/blob/master/asyncstdlib/itertools.py +MIT License. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`NoLock`](#langchain_core-utils-aiter-NoLock) | Dummy lock that provides the proper interface but no protection. | +| [`Tee`](#langchain_core-utils-aiter-Tee) | Create `n` separate asynchronous iterators over `iterable`. | +| [`aclosing`](#langchain_core-utils-aiter-aclosing) | Async context manager to wrap an `AsyncGenerator` that has a `aclose()` method. | + +### Functions + +| Name | Description | +|------|-------------| +| [`abatch_iterate`](#langchain_core-utils-aiter-abatch_iterate) | Utility batching function for async iterables. | +| [`py_anext`](#langchain_core-utils-aiter-py_anext) | Pure-Python implementation of `anext()` for testing purposes. | +| [`tee_peer`](#langchain_core-utils-aiter-tee_peer) | An individual iterator of a `tee`. | + +### Data + +[`T`](#langchain_core-utils-aiter-T) + +[`_no_default`](#langchain_core-utils-aiter-_no_default) + +[`atee`](#langchain_core-utils-aiter-atee) + +### API + + + + + +```python +class langchain_core.utils.aiter.NoLock() +``` + + + + + + +Dummy lock that provides the proper interface but no protection. + + + + + + +```python +langchain_core.utils.aiter.NoLock.__aenter__() -> None +``` + + + + + + +async + +Do nothing. + + + + + + + +```python +langchain_core.utils.aiter.NoLock.__aexit__( + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: types.TracebackType | None +) -> bool +``` + + + + + + +async + +Return False, exception not suppressed. + + + + + + + + + +```python +class langchain_core.utils.aiter.Tee( + iterable: collections.abc.AsyncIterator[langchain_core.utils.aiter.T], + n: int = 2, + lock: contextlib.AbstractAsyncContextManager[typing.Any] | None = None +) +``` + + + + + + +**Bases:** `Generic[T]` + +Create `n` separate asynchronous iterators over `iterable`. + +This splits a single `iterable` into multiple iterators, each providing +the same items in the same order. + +All child iterators may advance separately but share the same items from `iterable` +-- when the most advanced iterator retrieves an item, it is buffered until the least +advanced iterator has yielded it as well. + +A `tee` works lazily and can handle an infinite `iterable`, provided +that all iterators advance. + + + +```python +async def derivative(sensor_data): + previous, current = a.tee(sensor_data, n=2) + await a.anext(previous) # advance one iterator + return a.map(operator.sub, previous, current) +``` + + + +Unlike `itertools.tee`, `.tee` returns a custom type instead of a `tuple`. Like a +tuple, it can be indexed, iterated and unpacked to get the child iterators. In +addition, its `.tee.aclose` method immediately closes all children, and it can be +used in an `async with` context for the same effect. + +If `iterable` is an iterator and read elsewhere, `tee` will *not* provide these +items. Also, `tee` must internally buffer each item until the last iterator has +yielded it; if the most and least advanced iterator differ by most data, using a +`list` is more efficient (but not lazy). + +If the underlying iterable is concurrency safe (`anext` may be awaited concurrently) +the resulting iterators are concurrency safe as well. Otherwise, the iterators are +safe if there is only ever one single "most advanced" iterator. + +To enforce sequential use of `anext`, provide a `lock` + +- e.g. an `asyncio.Lock` instance in an `asyncio` application - and access is + automatically synchronised. + + + + + + + + + + + + + + +```python +langchain_core.utils.aiter.Tee.__aenter__() -> langchain_core.utils.aiter.Tee[langchain_core.utils.aiter.T] +``` + + + + + + +async + +Return the tee instance. + + + + + + + +```python +langchain_core.utils.aiter.Tee.__aexit__( + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: types.TracebackType | None +) -> bool +``` + + + + + + +async + +Close all child iterators. + +**Returns:** `bool` + +`False`, exceptions not suppressed. + + + + + + + +```python +langchain_core.utils.aiter.Tee.__getitem__( + item: int | slice +) -> collections.abc.AsyncIterator[langchain_core.utils.aiter.T] | tuple[collections.abc.AsyncIterator[langchain_core.utils.aiter.T], ...] +``` + + + + + + +Return the child iterator(s) for the given index or slice. + + + + + + + +```python +langchain_core.utils.aiter.Tee.__iter__() -> collections.abc.Iterator[collections.abc.AsyncIterator[langchain_core.utils.aiter.T]] +``` + + + + + + +Iterate over the child iterators. + + + + + + + +```python +langchain_core.utils.aiter.Tee.__len__() -> int +``` + + + + + + +Return the number of child iterators. + + + + + + + +```python +langchain_core.utils.aiter.Tee.aclose() -> None +``` + + + + + + +async + +Async close all child iterators. + + + + + + + + + +```python +class langchain_core.utils.aiter.aclosing( + thing: collections.abc.AsyncGenerator[typing.Any, typing.Any] | collections.abc.AsyncIterator[typing.Any] +) +``` + + + + + + +**Bases:** `AbstractAsyncContextManager` + +Async context manager to wrap an `AsyncGenerator` that has a `aclose()` method. + +Code like this: + + + +```python +async with aclosing(.fetch()) as agen: + +``` + + + +...is equivalent to this: + + + +```python +agen = .fetch() +try: + +finally: + await agen.aclose() + +``` + + + + + + + + +```python +langchain_core.utils.aiter.aclosing.__aenter__() -> collections.abc.AsyncGenerator[typing.Any, typing.Any] | collections.abc.AsyncIterator[typing.Any] +``` + + + + + + +async + + + + + + + +```python +langchain_core.utils.aiter.aclosing.__aexit__( + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None +) -> None +``` + + + + + + +async + + + + + + + + + +```python +langchain_core.utils.aiter.abatch_iterate( + size: int, + iterable: collections.abc.AsyncIterable[langchain_core.utils.aiter.T] +) -> collections.abc.AsyncIterator[list[langchain_core.utils.aiter.T]] +``` + + + + + + +async + +Utility batching function for async iterables. + +**Parameters:** + + +The size of the batch. + + + +The async iterable to batch. + + + + + + + + + +```python +langchain_core.utils.aiter.py_anext( + iterator: collections.abc.AsyncIterator[langchain_core.utils.aiter.T], + default: langchain_core.utils.aiter.T | typing.Any = _no_default +) -> collections.abc.Awaitable[langchain_core.utils.aiter.T | typing.Any | None] +``` + + + + + + +Pure-Python implementation of `anext()` for testing purposes. + +Closely matches the builtin `anext()` C implementation. + +Can be used to compare the built-in implementation of the inner coroutines machinery +to C-implementation of `__anext__()` and `send()` or `throw()` on the returned +generator. + +**Parameters:** + + +The async iterator to advance. + + + +The value to return if the iterator is exhausted. + +If not provided, a `StopAsyncIteration` exception is raised. + + +**Returns:** `Awaitable[T | Any | None]` + +The next value from the iterator, or the default value if the iterator is +exhausted. + +**Raises:** + +- `TypeError`: If the iterator is not an async iterator. + + + + + + + + +```python +langchain_core.utils.aiter.tee_peer( + iterator: collections.abc.AsyncIterator[langchain_core.utils.aiter.T], + buffer: collections.deque[langchain_core.utils.aiter.T], + peers: list[collections.deque[langchain_core.utils.aiter.T]], + lock: contextlib.AbstractAsyncContextManager[typing.Any] +) -> collections.abc.AsyncGenerator[langchain_core.utils.aiter.T, None] +``` + + + + + + +async + +An individual iterator of a `tee`. + +This function is a generator that yields items from the shared iterator +`iterator`. It buffers items until the least advanced iterator has yielded them as +well. + +The buffer is shared with all other peers. + +**Parameters:** + + +The shared iterator. + + + +The buffer for this peer. + + + +The buffers of all peers. + + + +The lock to synchronise access to the shared buffers. + + + + + + + + + +```python +langchain_core.utils.aiter.T = TypeVar('T') +``` + + + + + + + + + +```python +langchain_core.utils.aiter._no_default = object() +``` + + + + + + + + + +```python +langchain_core.utils.aiter.atee = Tee +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/env.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/env.mdx new file mode 100644 index 0000000..4e9a61e --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/env.mdx @@ -0,0 +1,145 @@ +--- +layout: overview +slug: langchain-core/langchain_core/utils/env +title: langchain_core.utils.env +--- + +Utilities for environment variables. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`env_var_is_set`](#langchain_core-utils-env-env_var_is_set) | Check if an environment variable is set. | +| [`get_from_dict_or_env`](#langchain_core-utils-env-get_from_dict_or_env) | Get a value from a dictionary or an environment variable. | +| [`get_from_env`](#langchain_core-utils-env-get_from_env) | Get a value from a dictionary or an environment variable. | + +### API + + + + + +```python +langchain_core.utils.env.env_var_is_set( + env_var: str +) -> bool +``` + + + + + + +Check if an environment variable is set. + +**Parameters:** + + +The name of the environment variable. + + +**Returns:** `bool` + +`True` if the environment variable is set, `False` otherwise. + + + + + + + + +```python +langchain_core.utils.env.get_from_dict_or_env( + data: dict[str, typing.Any], + key: str | list[str], + env_key: str, + default: str | None = None +) -> str +``` + + + + + + +Get a value from a dictionary or an environment variable. + +**Parameters:** + + +The dictionary to look up the key in. + + + +The key to look up in the dictionary. + +This can be a list of keys to try in order. + + + +The environment variable to look up if the key is not +in the dictionary. + + + +The default value to return if the key is not in the dictionary +or the environment. + + +**Returns:** `str` + +The dict value or the environment variable value. + + + + + + + + +```python +langchain_core.utils.env.get_from_env( + key: str, + env_key: str, + default: str | None = None +) -> str +``` + + + + + + +Get a value from a dictionary or an environment variable. + +**Parameters:** + + +The key to look up in the dictionary. + + + +The environment variable to look up if the key is not +in the dictionary. + + + +The default value to return if the key is not in the dictionary +or the environment. + + +**Returns:** `str` + +The value of the key. + +**Raises:** + +- `ValueError`: If the key is not in the dictionary and no default value is +provided or if the environment variable is not set. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/formatting.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/formatting.mdx new file mode 100644 index 0000000..68e6bcb --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/formatting.mdx @@ -0,0 +1,145 @@ +--- +layout: overview +slug: langchain-core/langchain_core/utils/formatting +title: langchain_core.utils.formatting +--- + +Utilities for formatting strings. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`StrictFormatter`](#langchain_core-utils-formatting-StrictFormatter) | A string formatter that enforces keyword-only argument substitution. | + +### Data + +[`formatter`](#langchain_core-utils-formatting-formatter) + +### API + + + + + +```python +class langchain_core.utils.formatting.StrictFormatter() +``` + + + + + + +**Bases:** `Formatter` + +A string formatter that enforces keyword-only argument substitution. + +This formatter extends Python's built-in `string.Formatter` to provide stricter +validation for prompt template formatting. It ensures that all variable +substitutions use keyword arguments rather than positional arguments, which improves +clarity and reduces errors when formatting prompt templates. + + + + + + +```python +langchain_core.utils.formatting.StrictFormatter.validate_input_variables( + format_string: str, + input_variables: list[str] +) -> None +``` + + + + + + +Validate that input variables match the placeholders in a format string. + +Checks that the provided input variables can be used to format the given string +without missing or extra keys. This is useful for validating prompt templates +before runtime. + +**Parameters:** + + +A string containing replacement fields to validate +against (e.g., `'Hello, {name}!'`). + + + +List of variable names expected to fill the +replacement fields. + + +**Raises:** + +- `KeyError`: If the format string contains placeholders not present +in input_variables. + + + + + + + +```python +langchain_core.utils.formatting.StrictFormatter.vformat( + format_string: str, + args: collections.abc.Sequence, + kwargs: collections.abc.Mapping[str, typing.Any] +) -> str +``` + + + + + + +Format a string using only keyword arguments. + +Overrides the base `vformat` to reject positional arguments, ensuring all +substitutions are explicit and named. + +**Parameters:** + + +A string containing replacement fields (e.g., `'{name}'`). + + + +Positional arguments (must be empty). + + + +Keyword arguments for substitution into the format string. + + +**Returns:** `str` + +The formatted string with all replacement fields substituted. + +**Raises:** + +- `ValueError`: If any positional arguments are provided. + + + + + + + + + +```python +langchain_core.utils.formatting.formatter = StrictFormatter() +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/function_calling.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/function_calling.mdx new file mode 100644 index 0000000..40951b5 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/function_calling.mdx @@ -0,0 +1,789 @@ +--- +layout: overview +slug: langchain-core/langchain_core/utils/function_calling +title: langchain_core.utils.function_calling +--- + +Methods for creating function specs in the style of OpenAI Functions. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`FunctionDescription`](#langchain_core-utils-function_calling-FunctionDescription) | Representation of a callable function to send to an LLM. | +| [`ToolDescription`](#langchain_core-utils-function_calling-ToolDescription) | Representation of a callable function to the OpenAI API. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_convert_any_typed_dicts_to_pydantic`](#langchain_core-utils-function_calling-_convert_any_typed_dicts_to_pydantic) | - | +| [`_convert_json_schema_to_openai_function`](#langchain_core-utils-function_calling-_convert_json_schema_to_openai_function) | Converts a Pydantic model to a function description for the OpenAI API. | +| [`_convert_pydantic_to_openai_function`](#langchain_core-utils-function_calling-_convert_pydantic_to_openai_function) | Converts a Pydantic model to a function description for the OpenAI API. | +| [`_convert_python_function_to_openai_function`](#langchain_core-utils-function_calling-_convert_python_function_to_openai_function) | Convert a Python function to an OpenAI function-calling API compatible dict. | +| [`_convert_typed_dict_to_openai_function`](#langchain_core-utils-function_calling-_convert_typed_dict_to_openai_function) | - | +| [`_format_tool_to_openai_function`](#langchain_core-utils-function_calling-_format_tool_to_openai_function) | Format tool into the OpenAI function API. | +| [`_get_python_function_name`](#langchain_core-utils-function_calling-_get_python_function_name) | Get the name of a Python function. | +| [`_parse_google_docstring`](#langchain_core-utils-function_calling-_parse_google_docstring) | Parse the function and argument descriptions from the docstring of a function. | +| [`_py_38_safe_origin`](#langchain_core-utils-function_calling-_py_38_safe_origin) | - | +| [`_recursive_set_additional_properties_false`](#langchain_core-utils-function_calling-_recursive_set_additional_properties_false) | - | +| [`_rm_titles`](#langchain_core-utils-function_calling-_rm_titles) | Recursively removes `'title'` fields from a JSON schema dictionary. | +| [`convert_to_json_schema`](#langchain_core-utils-function_calling-convert_to_json_schema) | Convert a schema representation to a JSON schema. | +| [`convert_to_openai_function`](#langchain_core-utils-function_calling-convert_to_openai_function) | Convert a raw function/class to an OpenAI function. | +| [`convert_to_openai_tool`](#langchain_core-utils-function_calling-convert_to_openai_tool) | Convert a tool-like object to an OpenAI tool schema. | +| [`tool_example_to_messages`](#langchain_core-utils-function_calling-tool_example_to_messages) | Convert an example into a list of messages that can be fed into an LLM. | + +### Data + +[`PYTHON_TO_JSON_TYPES`](#langchain_core-utils-function_calling-PYTHON_TO_JSON_TYPES) + +[`_MAX_TYPED_DICT_RECURSION`](#langchain_core-utils-function_calling-_MAX_TYPED_DICT_RECURSION) + +[`_MIN_DOCSTRING_BLOCKS`](#langchain_core-utils-function_calling-_MIN_DOCSTRING_BLOCKS) + +[`_ORIGIN_MAP`](#langchain_core-utils-function_calling-_ORIGIN_MAP) + +[`_WellKnownOpenAITools`](#langchain_core-utils-function_calling-_WellKnownOpenAITools) + +[`logger`](#langchain_core-utils-function_calling-logger) + +### API + + + + + +```python +class langchain_core.utils.function_calling.FunctionDescription +``` + + + + + + +**Bases:** `typing.TypedDict` + +Representation of a callable function to send to an LLM. + + +A description of the function. + + + +The name of the function. + + + +The parameters of the function. + + + + + + + + +```python +class langchain_core.utils.function_calling.ToolDescription +``` + + + + + + +**Bases:** `typing.TypedDict` + +Representation of a callable function to the OpenAI API. + + +The function description. + + + +The type of the tool. + + + + + + + + +```python +langchain_core.utils.function_calling._convert_any_typed_dicts_to_pydantic( + type_: type, + visited: dict[type, type], + depth: int = 0 +) -> type +``` + + + + + + + + + + + + + +```python +langchain_core.utils.function_calling._convert_json_schema_to_openai_function( + schema: dict, + name: str | None = None, + description: str | None = None, + rm_titles: bool = True +) -> langchain_core.utils.function_calling.FunctionDescription +``` + + + + + + +Converts a Pydantic model to a function description for the OpenAI API. + +**Parameters:** + + +The JSON schema to convert. + + + +The name of the function. + +If not provided, the title of the schema will be used. + + + +The description of the function. + +If not provided, the description of the schema will be used. + + + +Whether to remove titles from the schema. + + +**Returns:** `FunctionDescription` + +The function description. + + + + + + + + +```python +langchain_core.utils.function_calling._convert_pydantic_to_openai_function( + model: type, + name: str | None = None, + description: str | None = None, + rm_titles: bool = True +) -> langchain_core.utils.function_calling.FunctionDescription +``` + + + + + + +Converts a Pydantic model to a function description for the OpenAI API. + +**Parameters:** + + +The Pydantic model to convert. + + + +The name of the function. + +If not provided, the title of the schema will be used. + + + +The description of the function. + +If not provided, the description of the schema will be used. + + + +Whether to remove titles from the schema. + + +**Returns:** `FunctionDescription` + +The function description. + +**Raises:** + +- `TypeError`: If the model is not a Pydantic model. +- `TypeError`: If the model contains types that cannot be converted to JSON schema. + + + + + + + + +```python +langchain_core.utils.function_calling._convert_python_function_to_openai_function( + function: collections.abc.Callable +) -> langchain_core.utils.function_calling.FunctionDescription +``` + + + + + + +Convert a Python function to an OpenAI function-calling API compatible dict. + +Assumes the Python function has type hints and a docstring with a description. If +the docstring has Google Python style argument descriptions, these will be included +as well. + +**Parameters:** + + +The Python function to convert. + + +**Returns:** `FunctionDescription` + +The OpenAI function description. + + + + + + + + +```python +langchain_core.utils.function_calling._convert_typed_dict_to_openai_function( + typed_dict: type +) -> langchain_core.utils.function_calling.FunctionDescription +``` + + + + + + + + + + + + + +```python +langchain_core.utils.function_calling._format_tool_to_openai_function( + tool: langchain_core.tools.BaseTool +) -> langchain_core.utils.function_calling.FunctionDescription +``` + + + + + + +Format tool into the OpenAI function API. + +**Parameters:** + + +The tool to format. + + +**Returns:** `FunctionDescription` + +The function description. + +**Raises:** + +- `ValueError`: If the tool call schema is not supported. + + + + + + + + +```python +langchain_core.utils.function_calling._get_python_function_name( + function: collections.abc.Callable +) -> str +``` + + + + + + +Get the name of a Python function. + + + + + + + + +```python +langchain_core.utils.function_calling._parse_google_docstring( + docstring: str | None, + args: list[str], + error_on_invalid_docstring: bool = False +) -> tuple[str, dict] +``` + + + + + + +Parse the function and argument descriptions from the docstring of a function. + +Assumes the function docstring follows Google Python style guide. + +**Parameters:** + + +The docstring to parse. + + + +The list of argument names to extract descriptions for. + + + +Whether to raise an error if the docstring is +invalid. + + +**Returns:** `tuple[str, dict]` + +A tuple of the function description and a dictionary of argument descriptions. + + + + + + + + +```python +langchain_core.utils.function_calling._py_38_safe_origin( + origin: type +) -> type +``` + + + + + + + + + + + + + +```python +langchain_core.utils.function_calling._recursive_set_additional_properties_false( + schema: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + + + + + +```python +langchain_core.utils.function_calling._rm_titles( + kv: dict, + prev_key: str = '' +) -> dict +``` + + + + + + +Recursively removes `'title'` fields from a JSON schema dictionary. + +Remove `'title'` fields from the input JSON schema dictionary, +except when a `'title'` appears within a property definition under `'properties'`. + +**Parameters:** + + +The input JSON schema as a dictionary. + + + +The key from the parent dictionary, used to identify context. + + +**Returns:** `dict` + +A new dictionary with appropriate `'title'` fields removed. + + + + + + + + +```python +langchain_core.utils.function_calling.convert_to_json_schema( + schema: dict[str, typing.Any] | type[pydantic.BaseModel] | collections.abc.Callable | langchain_core.tools.BaseTool, + strict: bool | None = None +) -> dict[str, typing.Any] +``` + + + + + + +Convert a schema representation to a JSON schema. + +**Parameters:** + + +The schema to convert. + + + +If `True`, model output is guaranteed to exactly match the JSON Schema +provided in the function definition. + +If `None`, `strict` argument will not be included in function definition. + + +**Returns:** `dict[str, Any]` + +A JSON schema representation of the input schema. + +**Raises:** + +- `ValueError`: If the input is not a valid OpenAI-format tool. + + + + + + + + +```python +langchain_core.utils.function_calling.convert_to_openai_function( + function: collections.abc.Mapping[str, typing.Any] | type | collections.abc.Callable | langchain_core.tools.BaseTool, + strict: bool | None = None +) -> dict[str, typing.Any] +``` + + + + + + +Convert a raw function/class to an OpenAI function. + +!!! warning "Behavior changed in `langchain-core` 0.3.16" + + `description` and `parameters` keys are now optional. Only `name` is + required and guaranteed to be part of the output. + +**Parameters:** + + +A dictionary, Pydantic `BaseModel` class, `TypedDict` class, a +LangChain `Tool` object, or a Python function. + +If a dictionary is passed in, it is assumed to already be a valid OpenAI +function, a JSON schema with top-level `title` key specified, an Anthropic +format tool, or an Amazon Bedrock Converse format tool. + + + +If `True`, model output is guaranteed to exactly match the JSON Schema +provided in the function definition. + +If `None`, `strict` argument will not be included in function definition. + + +**Returns:** `dict[str, Any]` + +A dict version of the passed in function which is compatible with the OpenAI +function-calling API. + +**Raises:** + +- `ValueError`: If function is not in a supported format. + + + + + + + + +```python +langchain_core.utils.function_calling.convert_to_openai_tool( + tool: collections.abc.Mapping[str, typing.Any] | type[pydantic.BaseModel] | collections.abc.Callable | langchain_core.tools.BaseTool, + strict: bool | None = None +) -> dict[str, typing.Any] +``` + + + + + + +Convert a tool-like object to an OpenAI tool schema. + +[OpenAI tool schema reference](https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools) + +!!! warning "Behavior changed in `langchain-core` 0.3.16" + + `description` and `parameters` keys are now optional. Only `name` is + required and guaranteed to be part of the output. + +!!! warning "Behavior changed in `langchain-core` 0.3.44" + + Return OpenAI Responses API-style tools unchanged. This includes + any dict with `"type"` in `"file_search"`, `"function"`, + `"computer_use_preview"`, `"web_search_preview"`. + +!!! warning "Behavior changed in `langchain-core` 0.3.63" + + Added support for OpenAI's image generation built-in tool. + +**Parameters:** + + +Either a dictionary, a `pydantic.BaseModel` class, Python function, or +`BaseTool`. + +If a dictionary is passed in, it is assumed to already be a valid OpenAI +function, a JSON schema with top-level `title` key specified, an Anthropic +format tool, or an Amazon Bedrock Converse format tool. + + + +If `True`, model output is guaranteed to exactly match the JSON Schema +provided in the function definition. + +If `None`, `strict` argument will not be included in tool definition. + + +**Returns:** `dict[str, Any]` + +A dict version of the passed in tool which is compatible with the OpenAI +tool-calling API. + + + + + + + + +```python +langchain_core.utils.function_calling.tool_example_to_messages( + input: str, + tool_calls: list[pydantic.BaseModel], + tool_outputs: list[str] | None = None, + ai_response: str | None = None +) -> list[langchain_core.messages.BaseMessage] +``` + + + + + + +Convert an example into a list of messages that can be fed into an LLM. + +This code is an adapter that converts a single example to a list of messages +that can be fed into a chat model. + +The list of messages per example by default corresponds to: + +1. `HumanMessage`: contains the content from which content should be extracted. +2. `AIMessage`: contains the extracted information from the model +3. `ToolMessage`: contains confirmation to the model that the model requested a + tool correctly. + +If `ai_response` is specified, there will be a final `AIMessage` with that +response. + +The `ToolMessage` is required because some chat models are hyper-optimized for +agents rather than for an extraction use case. + +**Parameters:** + + +The user input + + + +Tool calls represented as Pydantic BaseModels + + + +Tool call outputs. + +Does not need to be provided. + +If not provided, a placeholder value will be inserted. + + + +If provided, content for a final `AIMessage`. + + +**Returns:** `list[BaseMessage]` + +A list of messages + +**Examples:** + + + +```python +from typing import Optional +from pydantic import BaseModel, Field +from langchain_openai import ChatOpenAI + + +class Person(BaseModel): + '''Information about a person.''' + + name: str | None = Field(..., description="The name of the person") + hair_color: str | None = Field( + ..., description="The color of the person's hair if known" + ) + height_in_meters: str | None = Field(..., description="Height in METERS") + + +examples = [ + ( + "The ocean is vast and blue. It's more than 20,000 feet deep.", + Person(name=None, height_in_meters=None, hair_color=None), + ), + ( + "Fiona traveled far from France to Spain.", + Person(name="Fiona", height_in_meters=None, hair_color=None), + ), +] + + +messages = [] + +for txt, tool_call in examples: + messages.extend(tool_example_to_messages(txt, [tool_call])) +``` + + + + + + + + + + +```python +langchain_core.utils.function_calling.PYTHON_TO_JSON_TYPES = {'str': 'string', 'int': 'integer', 'float': 'number', 'bool': 'boolean'} +``` + + + + + + + + + +```python +langchain_core.utils.function_calling._MAX_TYPED_DICT_RECURSION = 25 +``` + + + + + + + + + +```python +langchain_core.utils.function_calling._MIN_DOCSTRING_BLOCKS = 2 +``` + + + + + + + + + +```python +langchain_core.utils.function_calling._ORIGIN_MAP: dict[type, Any] = {dict: dict, list: list, tuple: tuple, set: set, collections.abc.Iterable: typin... +``` + + + + + + + + + +```python +langchain_core.utils.function_calling._WellKnownOpenAITools = ('function', 'file_search', 'computer_use_preview', 'code_interpreter', 'mcp', '... +``` + + + + + + + + + +```python +langchain_core.utils.function_calling.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/html.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/html.mdx new file mode 100644 index 0000000..55ccaf4 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/html.mdx @@ -0,0 +1,201 @@ +--- +layout: overview +slug: langchain-core/langchain_core/utils/html +title: langchain_core.utils.html +--- + +Utilities for working with HTML. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`extract_sub_links`](#langchain_core-utils-html-extract_sub_links) | Extract all links from a raw HTML string and convert into absolute paths. | +| [`find_all_links`](#langchain_core-utils-html-find_all_links) | Extract all links from a raw HTML string. | + +### Data + +[`DEFAULT_LINK_REGEX`](#langchain_core-utils-html-DEFAULT_LINK_REGEX) + +[`PREFIXES_TO_IGNORE`](#langchain_core-utils-html-PREFIXES_TO_IGNORE) + +[`PREFIXES_TO_IGNORE_REGEX`](#langchain_core-utils-html-PREFIXES_TO_IGNORE_REGEX) + +[`SUFFIXES_TO_IGNORE`](#langchain_core-utils-html-SUFFIXES_TO_IGNORE) + +[`SUFFIXES_TO_IGNORE_REGEX`](#langchain_core-utils-html-SUFFIXES_TO_IGNORE_REGEX) + +[`logger`](#langchain_core-utils-html-logger) + +### API + + + + + +```python +langchain_core.utils.html.extract_sub_links( + raw_html: str, + url: str, + base_url: str | None = None, + pattern: str | re.Pattern | None = None, + prevent_outside: bool = True, + exclude_prefixes: collections.abc.Sequence[str] = (), + continue_on_failure: bool = False +) -> list[str] +``` + + + + + + +Extract all links from a raw HTML string and convert into absolute paths. + +**Parameters:** + + +Original HTML. + + + +The url of the HTML. + + + +the base URL to check for outside links against. + + + +Regex to use for extracting links from raw HTML. + + + +If `True`, ignore external links which are not children +of the base URL. + + + +Exclude any URLs that start with one of these prefixes. + + + +If `True`, continue if parsing a specific link raises an +exception. Otherwise, raise the exception. + + +**Returns:** `list[str]` + +A list of absolute paths to sub links. + + + + + + + + +```python +langchain_core.utils.html.find_all_links( + raw_html: str, + pattern: str | re.Pattern | None = None +) -> list[str] +``` + + + + + + +Extract all links from a raw HTML string. + +**Parameters:** + + +original HTML. + + + +Regex to use for extracting links from raw HTML. + + +**Returns:** `list[str]` + +A list of all links found in the HTML. + + + + + + + + +```python +langchain_core.utils.html.DEFAULT_LINK_REGEX = f'href=[\"']{PREFIXES_TO_IGNORE_REGEX}((?:{SUFFIXES_TO_IGNORE_REGEX}.)*?)[\#'\"]... +``` + + + + + + + + + +```python +langchain_core.utils.html.PREFIXES_TO_IGNORE = ('javascript:', 'mailto:', '#') +``` + + + + + + + + + +```python +langchain_core.utils.html.PREFIXES_TO_IGNORE_REGEX = '(?!' + '|'.join([(re.escape(s)) for s in PREFIXES_TO_IGNORE]) + ')' +``` + + + + + + + + + +```python +langchain_core.utils.html.SUFFIXES_TO_IGNORE = ('.css', '.js', '.ico', '.png', '.jpg', '.jpeg', '.gif', '.svg', '.csv', '.bz2',... +``` + + + + + + + + + +```python +langchain_core.utils.html.SUFFIXES_TO_IGNORE_REGEX = '(?!' + '|'.join([(re.escape(s) + '[\\#\'\\"]') for s in SUFFIXES_TO_IGNORE]) + ... +``` + + + + + + + + + +```python +langchain_core.utils.html.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/image.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/image.mdx new file mode 100644 index 0000000..0667be8 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/image.mdx @@ -0,0 +1,35 @@ +--- +layout: overview +slug: langchain-core/langchain_core/utils/image +title: langchain_core.utils.image +--- + +Utilities for image processing. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`__getattr__`](#langchain_core-utils-image-__getattr__) | - | + +### API + + + + + +```python +langchain_core.utils.image.__getattr__( + name: str +) -> typing.Any +``` + + + + + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/input.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/input.mdx new file mode 100644 index 0000000..c931e86 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/input.mdx @@ -0,0 +1,185 @@ +--- +layout: overview +slug: langchain-core/langchain_core/utils/input +title: langchain_core.utils.input +--- + +Handle chained inputs. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`get_bolded_text`](#langchain_core-utils-input-get_bolded_text) | Get bolded text. | +| [`get_color_mapping`](#langchain_core-utils-input-get_color_mapping) | Get mapping for items to a support color. | +| [`get_colored_text`](#langchain_core-utils-input-get_colored_text) | Get colored text. | +| [`print_text`](#langchain_core-utils-input-print_text) | Print text with highlighting and no end characters. | + +### Data + +[`_TEXT_COLOR_MAPPING`](#langchain_core-utils-input-_TEXT_COLOR_MAPPING) + +### API + + + + + +```python +langchain_core.utils.input.get_bolded_text( + text: str +) -> str +``` + + + + + + +Get bolded text. + +**Parameters:** + + +The text to bold. + + +**Returns:** `str` + +The bolded text. + + + + + + + + +```python +langchain_core.utils.input.get_color_mapping( + items: list[str], + excluded_colors: list | None = None +) -> dict[str, str] +``` + + + + + + +Get mapping for items to a support color. + +**Parameters:** + + +The items to map to colors. + + + +The colors to exclude. + + +**Returns:** `dict[str, str]` + +The mapping of items to colors. + +**Raises:** + +- `ValueError`: If no colors are available after applying exclusions. + + + + + + + + +```python +langchain_core.utils.input.get_colored_text( + text: str, + color: str +) -> str +``` + + + + + + +Get colored text. + +**Parameters:** + + +The text to color. + + + +The color to use. + + +**Returns:** `str` + +The colored text. + + + + + + + + +```python +langchain_core.utils.input.print_text( + text: str, + color: str | None = None, + end: str = '', + file: typing.TextIO | None = None +) -> None +``` + + + + + + +Print text with highlighting and no end characters. + +If a color is provided, the text will be printed in that color. + +If a file is provided, the text will be written to that file. + +**Parameters:** + + +The text to print. + + + +The color to use. + + + +The end character to use. + + + +The file to write to. + + + + + + + + + +```python +langchain_core.utils.input._TEXT_COLOR_MAPPING = {'blue': '36;1', 'yellow': '33;1', 'pink': '38;5;200', 'green': '32;1', 'red': '... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/interactive_env.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/interactive_env.mdx new file mode 100644 index 0000000..a17f673 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/interactive_env.mdx @@ -0,0 +1,39 @@ +--- +layout: overview +slug: langchain-core/langchain_core/utils/interactive_env +title: langchain_core.utils.interactive_env +--- + +Utilities for working with interactive environments. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`is_interactive_env`](#langchain_core-utils-interactive_env-is_interactive_env) | Determine if running within IPython or Jupyter. | + +### API + + + + + +```python +langchain_core.utils.interactive_env.is_interactive_env() -> bool +``` + + + + + + +Determine if running within IPython or Jupyter. + +**Returns:** `bool` + +`True` if running in an interactive environment, `False` otherwise. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/iter.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/iter.mdx new file mode 100644 index 0000000..5612668 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/iter.mdx @@ -0,0 +1,370 @@ +--- +layout: overview +slug: langchain-core/langchain_core/utils/iter +title: langchain_core.utils.iter +--- + +Utilities for working with iterators. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`NoLock`](#langchain_core-utils-iter-NoLock) | Dummy lock that provides the proper interface but no protection. | +| [`Tee`](#langchain_core-utils-iter-Tee) | Create `n` separate asynchronous iterators over `iterable`. | + +### Functions + +| Name | Description | +|------|-------------| +| [`batch_iterate`](#langchain_core-utils-iter-batch_iterate) | Utility batching function. | +| [`tee_peer`](#langchain_core-utils-iter-tee_peer) | An individual iterator of a `.tee`. | + +### Data + +[`T`](#langchain_core-utils-iter-T) + +[`safetee`](#langchain_core-utils-iter-safetee) + +### API + + + + + +```python +class langchain_core.utils.iter.NoLock() +``` + + + + + + +Dummy lock that provides the proper interface but no protection. + + + + + + +```python +langchain_core.utils.iter.NoLock.__enter__() -> None +``` + + + + + + +Do nothing. + + + + + + + +```python +langchain_core.utils.iter.NoLock.__exit__( + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: types.TracebackType | None +) -> typing.Literal[False] +``` + + + + + + +Return False (exception not suppressed). + + + + + + + + + +```python +class langchain_core.utils.iter.Tee( + iterable: collections.abc.Iterator[langchain_core.utils.iter.T], + n: int = 2, + lock: contextlib.AbstractContextManager[typing.Any] | None = None +) +``` + + + + + + +**Bases:** `Generic[T]` + +Create `n` separate asynchronous iterators over `iterable`. + +This splits a single `iterable` into multiple iterators, each providing the same +items in the same order. + +All child iterators may advance separately but share the same items from `iterable` +-- when the most advanced iterator retrieves an item, it is buffered until the least +advanced iterator has yielded it as well. A `tee` works lazily and can handle an +infinite `iterable`, provided that all iterators advance. + + + +```python +async def derivative(sensor_data): + previous, current = a.tee(sensor_data, n=2) + await a.anext(previous) # advance one iterator + return a.map(operator.sub, previous, current) +``` + + + +Unlike `itertools.tee`, `.tee` returns a custom type instead of a `tuple`. Like a +tuple, it can be indexed, iterated and unpacked to get the child iterators. In +addition, its `.tee.aclose` method immediately closes all children, and it can be +used in an `async with` context for the same effect. + +If `iterable` is an iterator and read elsewhere, `tee` will *not* provide these +items. Also, `tee` must internally buffer each item until the last iterator has +yielded it; if the most and least advanced iterator differ by most data, using a +`list` is more efficient (but not lazy). + +If the underlying iterable is concurrency safe (`anext` may be awaited concurrently) +the resulting iterators are concurrency safe as well. Otherwise, the iterators are +safe if there is only ever one single "most advanced" iterator. To enforce +sequential use of `anext`, provide a `lock` + +- e.g., an `asyncio.Lock` instance in an `asyncio` application - and access is + automatically synchronised. + + + + + + + + + + + + + + +```python +langchain_core.utils.iter.Tee.__enter__() -> langchain_core.utils.iter.Tee[langchain_core.utils.iter.T] +``` + + + + + + +Return `Tee` instance. + + + + + + + +```python +langchain_core.utils.iter.Tee.__exit__( + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: types.TracebackType | None +) -> typing.Literal[False] +``` + + + + + + +Close all child iterators. + +**Returns:** `Literal[False]` + +`False` (exception not suppressed). + + + + + + + +```python +langchain_core.utils.iter.Tee.__getitem__( + item: int | slice +) -> collections.abc.Iterator[langchain_core.utils.iter.T] | tuple[collections.abc.Iterator[langchain_core.utils.iter.T], ...] +``` + + + + + + +Return the child iterator(s) at the given index or slice. + + + + + + + +```python +langchain_core.utils.iter.Tee.__iter__() -> collections.abc.Iterator[collections.abc.Iterator[langchain_core.utils.iter.T]] +``` + + + + + + +Return an iterator over the child iterators. + + + + + + + +```python +langchain_core.utils.iter.Tee.__len__() -> int +``` + + + + + + +Return the number of child iterators. + + + + + + + +```python +langchain_core.utils.iter.Tee.close() -> None +``` + + + + + + +Close all child iterators. + + + + + + + + + +```python +langchain_core.utils.iter.batch_iterate( + size: int | None, + iterable: collections.abc.Iterable[langchain_core.utils.iter.T] +) -> collections.abc.Iterator[list[langchain_core.utils.iter.T]] +``` + + + + + + +Utility batching function. + +**Parameters:** + + +The size of the batch. + +If `None`, returns a single batch. + + + +The iterable to batch. + + + + + + + + + +```python +langchain_core.utils.iter.tee_peer( + iterator: collections.abc.Iterator[langchain_core.utils.iter.T], + buffer: collections.deque[langchain_core.utils.iter.T], + peers: list[collections.deque[langchain_core.utils.iter.T]], + lock: contextlib.AbstractContextManager[typing.Any] +) -> collections.abc.Generator[langchain_core.utils.iter.T, None, None] +``` + + + + + + +An individual iterator of a `.tee`. + +This function is a generator that yields items from the shared iterator `iterator`. +It buffers items until the least advanced iterator has yielded them as well. The +buffer is shared with all other peers. + +**Parameters:** + + +The shared iterator. + + + +The buffer for this peer. + + + +The buffers of all peers. + + + +The lock to synchronise access to the shared buffers. + + + + + + + + + +```python +langchain_core.utils.iter.T = TypeVar('T') +``` + + + + + + + + + +```python +langchain_core.utils.iter.safetee = Tee +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/json.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/json.mdx new file mode 100644 index 0000000..b8ac615 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/json.mdx @@ -0,0 +1,260 @@ +--- +layout: overview +slug: langchain-core/langchain_core/utils/json +title: langchain_core.utils.json +--- + +Utilities for JSON. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_custom_parser`](#langchain_core-utils-json-_custom_parser) | Custom parser for multiline strings. | +| [`_parse_json`](#langchain_core-utils-json-_parse_json) | Parse a JSON string, handling special characters and whitespace. | +| [`_replace_new_line`](#langchain_core-utils-json-_replace_new_line) | Replace newline characters in a regex match with escaped sequences. | +| [`parse_and_check_json_markdown`](#langchain_core-utils-json-parse_and_check_json_markdown) | Parse and check a JSON string from a Markdown string. | +| [`parse_json_markdown`](#langchain_core-utils-json-parse_json_markdown) | Parse a JSON string from a Markdown string. | +| [`parse_partial_json`](#langchain_core-utils-json-parse_partial_json) | Parse a JSON string that may be missing closing braces. | + +### Data + +[`_json_markdown_re`](#langchain_core-utils-json-_json_markdown_re) + +[`_json_strip_chars`](#langchain_core-utils-json-_json_strip_chars) + +### API + + + + + +```python +langchain_core.utils.json._custom_parser( + multiline_string: str | bytes | bytearray +) -> str +``` + + + + + + +Custom parser for multiline strings. + +The LLM response for `action_input` may be a multiline string containing unescaped +newlines, tabs or quotes. This function replaces those characters with their escaped +counterparts. (newlines in JSON must be double-escaped: `\\n`). + +**Returns:** `str` + +The modified string with escaped newlines, tabs and quotes. + + + + + + + + +```python +langchain_core.utils.json._parse_json( + json_str: str, + parser: collections.abc.Callable[[str], typing.Any] = parse_partial_json +) -> typing.Any +``` + + + + + + +Parse a JSON string, handling special characters and whitespace. + +Strips whitespace, newlines, and backticks from the start and end of the string, +then processes special characters before parsing. + +**Parameters:** + + +The JSON string to parse. + + + +Optional custom parser function. + + +**Returns:** `Any` + +Parsed JSON object. + + + + + + + + +```python +langchain_core.utils.json._replace_new_line( + match: re.Match[str] +) -> str +``` + + + + + + +Replace newline characters in a regex match with escaped sequences. + +**Parameters:** + + +Regex match object containing the string to process. + + +**Returns:** `str` + +String with newlines, carriage returns, tabs, and quotes properly escaped. + + + + + + + + +```python +langchain_core.utils.json.parse_and_check_json_markdown( + text: str, + expected_keys: list[str] +) -> dict +``` + + + + + + +Parse and check a JSON string from a Markdown string. + +Checks that it contains the expected keys. + +**Parameters:** + + +The Markdown string. + + + +The expected keys in the JSON string. + + +**Returns:** `dict` + +The parsed JSON object as a Python dictionary. + +**Raises:** + +- `OutputParserException`: If the JSON string is invalid or does not contain +the expected keys. + + + + + + + + +```python +langchain_core.utils.json.parse_json_markdown( + json_string: str, + parser: collections.abc.Callable[[str], typing.Any] = parse_partial_json +) -> typing.Any +``` + + + + + + +Parse a JSON string from a Markdown string. + +**Parameters:** + + +The Markdown string. + + + +The parser to use. + + +**Returns:** `Any` + +The parsed JSON object as a Python dictionary. + + + + + + + + +```python +langchain_core.utils.json.parse_partial_json( + s: str, + strict: bool = False +) -> typing.Any +``` + + + + + + +Parse a JSON string that may be missing closing braces. + +**Parameters:** + + +The JSON string to parse. + + + +Whether to use strict parsing. + + +**Returns:** `Any` + +The parsed JSON object as a Python dictionary. + + + + + + + + +```python +langchain_core.utils.json._json_markdown_re = re.compile('```(json)?(.*)', re.DOTALL) +``` + + + + + + + + + +```python +langchain_core.utils.json._json_strip_chars = ' \n\r\t`' +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/json_schema.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/json_schema.mdx new file mode 100644 index 0000000..c740f59 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/json_schema.mdx @@ -0,0 +1,283 @@ +--- +layout: overview +slug: langchain-core/langchain_core/utils/json_schema +title: langchain_core.utils.json_schema +--- + +Utilities for JSON Schema. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_dereference_refs_helper`](#langchain_core-utils-json_schema-_dereference_refs_helper) | Dereference JSON Schema $ref objects, handling both pure and mixed references. | +| [`_process_dict_properties`](#langchain_core-utils-json_schema-_process_dict_properties) | Process dictionary properties, recursing into nested structures. | +| [`_retrieve_ref`](#langchain_core-utils-json_schema-_retrieve_ref) | Retrieve a referenced object from a JSON schema using a path. | +| [`dereference_refs`](#langchain_core-utils-json_schema-dereference_refs) | Resolve and inline JSON Schema `$ref` references in a schema object. | + +### API + + + + + +```python +langchain_core.utils.json_schema._dereference_refs_helper( + obj: typing.Any, + full_schema: dict[str, typing.Any], + processed_refs: set[str] | None, + skip_keys: collections.abc.Sequence[str], + shallow_refs: bool +) -> typing.Any +``` + + + + + + +Dereference JSON Schema $ref objects, handling both pure and mixed references. + +This function processes JSON Schema objects containing $ref properties by resolving +the references and merging any additional properties. It handles: + +- Pure `$ref` objects: `{"$ref": "#/path/to/definition"}` +- Mixed `$ref` objects: `{"$ref": "#/path", "title": "Custom Title", ...}` +- Circular references by breaking cycles and preserving non-ref properties + +**Parameters:** + + +The object to process (can be dict, list, or primitive) + + + +The complete schema containing all definitions + + + +Set tracking currently processing refs (for cycle detection) + + + +Keys under which to skip recursion + + + +If `True`, only break cycles; if `False`, deep-inline all refs + + +**Returns:** `Any` + +The object with `$ref` properties resolved and merged with other properties. + + + + + + + + +```python +langchain_core.utils.json_schema._process_dict_properties( + properties: dict[str, typing.Any], + full_schema: dict[str, typing.Any], + processed_refs: set[str], + skip_keys: collections.abc.Sequence[str], + shallow_refs: bool +) -> dict[str, typing.Any] +``` + + + + + + +Process dictionary properties, recursing into nested structures. + + + + + + + + +```python +langchain_core.utils.json_schema._retrieve_ref( + path: str, + schema: dict +) -> list | dict +``` + + + + + + +Retrieve a referenced object from a JSON schema using a path. + +Resolves JSON schema references (e.g., `'#/definitions/MyType'`) by traversing the +schema structure. + +**Parameters:** + + +Reference path starting with `'#'` (e.g., `'#/definitions/MyType'`). + + + +The JSON schema dictionary to search in. + + +**Returns:** `list | dict` + +A deep copy of the referenced object (dict or list). + +**Raises:** + +- `ValueError`: If the path does not start with `'#'`. +- `KeyError`: If the reference path is not found in the schema. + + + + + + + + +```python +langchain_core.utils.json_schema.dereference_refs( + schema_obj: dict, + full_schema: dict | None = None, + skip_keys: collections.abc.Sequence[str] | None = None +) -> dict +``` + + + + + + +Resolve and inline JSON Schema `$ref` references in a schema object. + +This function processes a JSON Schema and resolves all `$ref` references by +replacing them with the actual referenced content. + +Handles both simple references and complex cases like circular references and mixed +`$ref` objects that contain additional properties alongside the `$ref`. + +!!! note + + - Circular references are handled gracefully by breaking cycles + - Mixed `$ref` objects (with both `$ref` and other properties) are supported + - Additional properties in mixed `$refs` override resolved properties + - The `$defs` section is preserved in the output by default + +**Parameters:** + + +The JSON Schema object or fragment to process. + +This can be a complete schema or just a portion of one. + + + +The complete schema containing all definitions that `$refs` might +point to. + +If not provided, defaults to `schema_obj` (useful when the schema is +self-contained). + + + +Controls recursion behavior and reference resolution depth. + +- If `None` (Default): Only recurse under `'$defs'` and use shallow + reference resolution (break cycles but don't deep-inline nested refs) +- If provided (even as `[]`): Recurse under all keys and use deep reference + resolution (fully inline all nested references) + + +**Returns:** `dict` + +A new dictionary with all $ref references resolved and inlined. + +The original `schema_obj` is not modified. + +**Examples:** + + + +```python +Basic reference resolution: +``` + + + + + +```python +>>> schema = { +... "type": "object", +... "properties": {"name": {"$ref": "#/$defs/string_type"}}, +... "$defs": {"string_type": {"type": "string"}}, +... } +>>> result = dereference_refs(schema) +>>> result["properties"]["name"] # {"type": "string"} +``` + + + + + +```python +Mixed `$ref` with additional properties: +``` + + + + + +```python +>>> schema = { +... "properties": { +... "name": {"$ref": "#/$defs/base", "description": "User name"} +... }, +... "$defs": {"base": {"type": "string", "minLength": 1}}, +... } +>>> result = dereference_refs(schema) +>>> result["properties"]["name"] +# {"type": "string", "minLength": 1, "description": "User name"} +``` + + + + + +```python +Handling circular references: +``` + + + + + +```python +>>> schema = { +... "properties": {"user": {"$ref": "#/$defs/User"}}, +... "$defs": { +... "User": { +... "type": "object", +... "properties": {"friend": {"$ref": "#/$defs/User"}}, +... } +... }, +... } +>>> result = dereference_refs(schema) # Won't cause infinite recursion +``` + + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/mustache.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/mustache.mdx new file mode 100644 index 0000000..738c3ab --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/mustache.mdx @@ -0,0 +1,532 @@ +--- +layout: overview +slug: langchain-core/langchain_core/utils/mustache +title: langchain_core.utils.mustache +--- + +Adapted from https://github.com/noahmorrison/chevron. + +MIT License. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ChevronError`](#langchain_core-utils-mustache-ChevronError) | Custom exception for Chevron errors. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_key`](#langchain_core-utils-mustache-_get_key) | Retrieve a value from the current scope using a dot-separated key path. | +| [`_get_partial`](#langchain_core-utils-mustache-_get_partial) | Load a partial. | +| [`_html_escape`](#langchain_core-utils-mustache-_html_escape) | Return the HTML-escaped string with these characters escaped: `" & < >`. | +| [`grab_literal`](#langchain_core-utils-mustache-grab_literal) | Parse a literal from the template. | +| [`l_sa_check`](#langchain_core-utils-mustache-l_sa_check) | Do a preliminary check to see if a tag could be a standalone. | +| [`parse_tag`](#langchain_core-utils-mustache-parse_tag) | Parse a tag from a template. | +| [`r_sa_check`](#langchain_core-utils-mustache-r_sa_check) | Do a final check to see if a tag could be a standalone. | +| [`render`](#langchain_core-utils-mustache-render) | Render a mustache template. | +| [`tokenize`](#langchain_core-utils-mustache-tokenize) | Tokenize a mustache template. | + +### Data + +[`EMPTY_DICT`](#langchain_core-utils-mustache-EMPTY_DICT) + +[`Scopes`](#langchain_core-utils-mustache-Scopes) + +[`_CURRENT_LINE`](#langchain_core-utils-mustache-_CURRENT_LINE) + +[`_LAST_TAG_LINE`](#langchain_core-utils-mustache-_LAST_TAG_LINE) + +[`g_token_cache`](#langchain_core-utils-mustache-g_token_cache) + +[`logger`](#langchain_core-utils-mustache-logger) + +### API + + + + + +```python +class langchain_core.utils.mustache.ChevronError() +``` + + + + + + +**Bases:** `SyntaxError` + +Custom exception for Chevron errors. + + + + + + + + +```python +langchain_core.utils.mustache._get_key( + key: str, + scopes: langchain_core.utils.mustache.Scopes, + warn: bool, + keep: bool, + def_ldel: str, + def_rdel: str +) -> typing.Any +``` + + + + + + +Retrieve a value from the current scope using a dot-separated key path. + +Traverses through nested dictionaries and lists using dot notation. + +Supports special key `'.'` to return the current scope. + +**Parameters:** + + +Dot-separated key path (e.g., `'user.name'` or `'.'` for current scope). + + + +List of scope dictionaries to search through. + + + +Whether to log a warning when a key is not found. + + + +Whether to return the original template tag when key is not found. + + + +Left delimiter for template (used when keep is `True`). + + + +Right delimiter for template (used when keep is `True`). + + +**Returns:** `Any` + +The value found at the key path. + +If not found, returns the original template tag when keep is `True`, +otherwise returns an empty string. + + + + + + + + +```python +langchain_core.utils.mustache._get_partial( + name: str, + partials_dict: collections.abc.Mapping[str, str] +) -> str +``` + + + + + + +Load a partial. + +**Returns:** `str` + +The partial. + + + + + + + + +```python +langchain_core.utils.mustache._html_escape( + string: str +) -> str +``` + + + + + + +Return the HTML-escaped string with these characters escaped: `" & < >`. + + + + + + + + +```python +langchain_core.utils.mustache.grab_literal( + template: str, + l_del: str +) -> tuple[str, str] +``` + + + + + + +Parse a literal from the template. + +**Parameters:** + + +The template to parse. + + + +The left delimiter. + + +**Returns:** `tuple[str, str]` + +The literal and the template. + + + + + + + + +```python +langchain_core.utils.mustache.l_sa_check( + template: str, + literal: str, + is_standalone: bool +) -> bool +``` + + + + + + +Do a preliminary check to see if a tag could be a standalone. + +**Parameters:** + + +The template. (Not used.) + + + +The literal. + + + +Whether the tag is standalone. + + +**Returns:** `bool` + +Whether the tag could be a standalone. + + + + + + + + +```python +langchain_core.utils.mustache.parse_tag( + template: str, + l_del: str, + r_del: str +) -> tuple[tuple[str, str], str] +``` + + + + + + +Parse a tag from a template. + +**Parameters:** + + +The template. + + + +The left delimiter. + + + +The right delimiter. + + +**Returns:** `tuple[tuple[str, str], str]` + +The tag and the template. + +**Raises:** + +- `ChevronError`: If the tag is unclosed. +- `ChevronError`: If the set delimiter tag is unclosed. + + + + + + + + +```python +langchain_core.utils.mustache.r_sa_check( + template: str, + tag_type: str, + is_standalone: bool +) -> bool +``` + + + + + + +Do a final check to see if a tag could be a standalone. + +**Parameters:** + + +The template. + + + +The type of the tag. + + + +Whether the tag is standalone. + + +**Returns:** `bool` + +Whether the tag could be a standalone. + + + + + + + + +```python +langchain_core.utils.mustache.render( + template: str | list[tuple[str, str]] = '', + data: collections.abc.Mapping[str, typing.Any] = EMPTY_DICT, + partials_dict: collections.abc.Mapping[str, str] = EMPTY_DICT, + padding: str = '', + def_ldel: str = '{{', + def_rdel: str = '}}', + scopes: langchain_core.utils.mustache.Scopes | None = None, + warn: bool = False, + keep: bool = False +) -> str +``` + + + + + + +Render a mustache template. + +Renders a mustache template with a data scope and inline partial capability. + +**Parameters:** + + +A file-like object or a string containing the template. + + + +A python dictionary with your data scope. + + + +A python dictionary which will be search for partials +before the filesystem is. + +`{'include': 'foo'}` is the same as a file called include.mustache +(defaults to `{}`). + + + +This is for padding partials, and shouldn't be used +(but can be if you really want to). + + + +The default left delimiter + +(`'{{'` by default, as in spec compliant mustache). + + + +The default right delimiter + +(`'}}'` by default, as in spec compliant mustache). + + + +The list of scopes that `get_key` will look through. + + + +Log a warning when a template substitution isn't found in the data + + + +Keep unreplaced tags when a substitution isn't found in the data. + + +**Returns:** `str` + +A string containing the rendered template. + + + + + + + + +```python +langchain_core.utils.mustache.tokenize( + template: str, + def_ldel: str = '{{', + def_rdel: str = '}}' +) -> collections.abc.Iterator[tuple[str, str]] +``` + + + + + + +Tokenize a mustache template. + +Tokenizes a mustache template in a generator fashion, using file-like objects. It +also accepts a string containing the template. + +**Parameters:** + + +a file-like object, or a string of a mustache template + + + +The default left delimiter +(`'{{'` by default, as in spec compliant mustache) + + + +The default right delimiter +(`'}}'` by default, as in spec compliant mustache) + + +**Raises:** + +- `ChevronError`: If there is a syntax error in the template. + + + + + + + + +```python +langchain_core.utils.mustache.EMPTY_DICT: MappingProxyType[str, str] = MappingProxyType({}) +``` + + + + + + + + + +```python +langchain_core.utils.mustache.Scopes: TypeAlias = list[Literal[False, 0] | Mapping[str, Any]] +``` + + + + + + + + + +```python +langchain_core.utils.mustache._CURRENT_LINE = 1 +``` + + + + + + + + + +```python +langchain_core.utils.mustache._LAST_TAG_LINE = None +``` + + + + + + + + + +```python +langchain_core.utils.mustache.g_token_cache: dict[str, list[tuple[str, str]]] = {} +``` + + + + + + + + + +```python +langchain_core.utils.mustache.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/pydantic.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/pydantic.mdx new file mode 100644 index 0000000..4228119 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/pydantic.mdx @@ -0,0 +1,679 @@ +--- +layout: overview +slug: langchain-core/langchain_core/utils/pydantic +title: langchain_core.utils.pydantic +--- + +Utilities for pydantic. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`_IgnoreUnserializable`](#langchain_core-utils-pydantic-_IgnoreUnserializable) | A JSON schema generator that ignores unknown types. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_create_model_cached`](#langchain_core-utils-pydantic-_create_model_cached) | - | +| [`_create_root_model`](#langchain_core-utils-pydantic-_create_root_model) | Create a base class. | +| [`_create_root_model_cached`](#langchain_core-utils-pydantic-_create_root_model_cached) | - | +| [`_create_subset_model`](#langchain_core-utils-pydantic-_create_subset_model) | Create subset model using the same pydantic version as the input model. | +| [`_create_subset_model_v1`](#langchain_core-utils-pydantic-_create_subset_model_v1) | Create a Pydantic model with only a subset of model's fields. | +| [`_create_subset_model_v2`](#langchain_core-utils-pydantic-_create_subset_model_v2) | Create a Pydantic model with a subset of the model fields. | +| [`_remap_field_definitions`](#langchain_core-utils-pydantic-_remap_field_definitions) | This remaps fields to avoid colliding with internal pydantic fields. | +| [`create_model`](#langchain_core-utils-pydantic-create_model) | Create a Pydantic model with the given field definitions. | +| [`create_model_v2`](#langchain_core-utils-pydantic-create_model_v2) | Create a Pydantic model with the given field definitions. | +| [`get_fields`](#langchain_core-utils-pydantic-get_fields) | Return the field names of a Pydantic model. | +| [`get_pydantic_major_version`](#langchain_core-utils-pydantic-get_pydantic_major_version) | DEPRECATED - Get the major version of Pydantic. | +| [`is_basemodel_instance`](#langchain_core-utils-pydantic-is_basemodel_instance) | Check if the given class is an instance of Pydantic `BaseModel`. | +| [`is_basemodel_subclass`](#langchain_core-utils-pydantic-is_basemodel_subclass) | Check if the given class is a subclass of Pydantic `BaseModel`. | +| [`is_pydantic_v1_subclass`](#langchain_core-utils-pydantic-is_pydantic_v1_subclass) | Check if the given class is Pydantic v1-like. | +| [`is_pydantic_v2_subclass`](#langchain_core-utils-pydantic-is_pydantic_v2_subclass) | Check if the given class is Pydantic v2-like. | +| [`pre_init`](#langchain_core-utils-pydantic-pre_init) | Decorator to run a function before model initialization. | + +### Data + +[`IS_PYDANTIC_V1`](#langchain_core-utils-pydantic-IS_PYDANTIC_V1) + +[`IS_PYDANTIC_V2`](#langchain_core-utils-pydantic-IS_PYDANTIC_V2) + +[`NO_DEFAULT`](#langchain_core-utils-pydantic-NO_DEFAULT) + +[`PYDANTIC_MAJOR_VERSION`](#langchain_core-utils-pydantic-PYDANTIC_MAJOR_VERSION) + +[`PYDANTIC_MINOR_VERSION`](#langchain_core-utils-pydantic-PYDANTIC_MINOR_VERSION) + +[`PYDANTIC_VERSION`](#langchain_core-utils-pydantic-PYDANTIC_VERSION) + +[`PydanticBaseModel`](#langchain_core-utils-pydantic-PydanticBaseModel) + +[`TBaseModel`](#langchain_core-utils-pydantic-TBaseModel) + +[`TypeBaseModel`](#langchain_core-utils-pydantic-TypeBaseModel) + +[`_RESERVED_NAMES`](#langchain_core-utils-pydantic-_RESERVED_NAMES) + +[`_SchemaConfig`](#langchain_core-utils-pydantic-_SchemaConfig) + +### API + + + + + +```python +class langchain_core.utils.pydantic._IgnoreUnserializable() +``` + + + + + + +**Bases:** `GenerateJsonSchema` + +A JSON schema generator that ignores unknown types. + +https://docs.pydantic.dev/latest/concepts/json_schema/#customizing-the-json-schema-generation-process + + + + + + +```python +langchain_core.utils.pydantic._IgnoreUnserializable.handle_invalid_for_json_schema( + schema: pydantic_core.core_schema.CoreSchema, + error_info: str +) -> pydantic.json_schema.JsonSchemaValue +``` + + + + + + + + + + + + + + +```python +langchain_core.utils.pydantic._create_model_cached( + model_name: str, + field_definitions: typing.Any = {} +) -> type[pydantic.BaseModel] +``` + + + + + + + + + + + + + +```python +langchain_core.utils.pydantic._create_root_model( + name: str, + type_: typing.Any, + module_name: str | None = None, + default_: object = NO_DEFAULT +) -> type[pydantic.BaseModel] +``` + + + + + + +Create a base class. + + + + + + + + +```python +langchain_core.utils.pydantic._create_root_model_cached( + model_name: str, + type_: typing.Any, + module_name: str | None = None, + default_: object = NO_DEFAULT +) -> type[pydantic.BaseModel] +``` + + + + + + + + + + + + + +```python +langchain_core.utils.pydantic._create_subset_model( + name: str, + model: langchain_core.utils.pydantic.TypeBaseModel, + field_names: list[str], + descriptions: dict | None = None, + fn_description: str | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + +Create subset model using the same pydantic version as the input model. + +**Returns:** `type[BaseModel]` + +The created subset model. + + + + + + + + +```python +langchain_core.utils.pydantic._create_subset_model_v1( + name: str, + model: type[pydantic.v1.BaseModel], + field_names: list, + descriptions: dict | None = None, + fn_description: str | None = None +) -> type[pydantic.v1.BaseModel] +``` + + + + + + +Create a Pydantic model with only a subset of model's fields. + + + + + + + + +```python +langchain_core.utils.pydantic._create_subset_model_v2( + name: str, + model: type[pydantic.BaseModel], + field_names: list[str], + descriptions: dict | None = None, + fn_description: str | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + +Create a Pydantic model with a subset of the model fields. + + + + + + + + +```python +langchain_core.utils.pydantic._remap_field_definitions( + field_definitions: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + +This remaps fields to avoid colliding with internal pydantic fields. + + + + + + + + +```python +langchain_core.utils.pydantic.create_model( + model_name: str, + module_name: str | None = None, + field_definitions: typing.Any = {} +) -> type[pydantic.BaseModel] +``` + + + + + + +Create a Pydantic model with the given field definitions. + +Please use `create_model_v2` instead of this function. + +**Parameters:** + + +The name of the model. + + + +The name of the module where the model is defined. + +This is used by Pydantic to resolve any forward references. + + + +The field definitions for the model. + + +**Returns:** `type[BaseModel]` + +The created model. + + + + + + + + +```python +langchain_core.utils.pydantic.create_model_v2( + model_name: str, + module_name: str | None = None, + field_definitions: dict[str, typing.Any] | None = None, + root: typing.Any | None = None +) -> type[pydantic.BaseModel] +``` + + + + + + +Create a Pydantic model with the given field definitions. + +!!! warning + + Do not use outside of langchain packages. This API is subject to change at any + time. + +**Parameters:** + + +The name of the model. + + + +The name of the module where the model is defined. + +This is used by Pydantic to resolve any forward references. + + + +The field definitions for the model. + + + +Type for a root model (`RootModel`) + + +**Returns:** `type[BaseModel]` + +The created model. + + + + + + + + +```python +langchain_core.utils.pydantic.get_fields( + model: type[pydantic.BaseModel | pydantic.v1.BaseModel] | pydantic.BaseModel | pydantic.v1.BaseModel +) -> dict[str, pydantic.fields.FieldInfo] | dict[str, pydantic.v1.fields.ModelField] +``` + + + + + + +Return the field names of a Pydantic model. + +**Parameters:** + + +The Pydantic model or instance. + + +**Raises:** + +- `TypeError`: If the model is not a Pydantic model. + + + + + + + + +```python +langchain_core.utils.pydantic.get_pydantic_major_version() -> int +``` + + + + + + +DEPRECATED - Get the major version of Pydantic. + +Use `PYDANTIC_VERSION.major` instead. + +**Returns:** `int` + +The major version of Pydantic. + + + + + + + + +```python +langchain_core.utils.pydantic.is_basemodel_instance( + obj: typing.Any +) -> bool +``` + + + + + + +Check if the given class is an instance of Pydantic `BaseModel`. + +Check if the given class is an instance of any of the following: + +* `pydantic.BaseModel` in Pydantic 2.x +* `pydantic.v1.BaseModel` in Pydantic 2.x + +**Returns:** `bool` + +`True` if the given class is an instance of Pydantic `BaseModel`. + + + + + + + + +```python +langchain_core.utils.pydantic.is_basemodel_subclass( + cls: type +) -> bool +``` + + + + + + +Check if the given class is a subclass of Pydantic `BaseModel`. + +Check if the given class is a subclass of any of the following: + +* `pydantic.BaseModel` in Pydantic 2.x +* `pydantic.v1.BaseModel` in Pydantic 2.x + +**Returns:** `bool` + +`True` if the given class is a subclass of Pydantic `BaseModel`. + + + + + + + + +```python +langchain_core.utils.pydantic.is_pydantic_v1_subclass( + cls: type +) -> bool +``` + + + + + + +Check if the given class is Pydantic v1-like. + +**Returns:** `bool` + +`True` if the given class is a subclass of Pydantic `BaseModel` 1.x. + + + + + + + + +```python +langchain_core.utils.pydantic.is_pydantic_v2_subclass( + cls: type +) -> bool +``` + + + + + + +Check if the given class is Pydantic v2-like. + +**Returns:** `bool` + +`True` if the given class is a subclass of Pydantic `BaseModel` 2.x. + + + + + + + + +```python +langchain_core.utils.pydantic.pre_init( + func: collections.abc.Callable +) -> typing.Any +``` + + + + + + +Decorator to run a function before model initialization. + +**Parameters:** + + +The function to run before model initialization. + + +**Returns:** `Any` + +The decorated function. + + + + + + + + +```python +langchain_core.utils.pydantic.IS_PYDANTIC_V1 = False +``` + + + + + + + + + +```python +langchain_core.utils.pydantic.IS_PYDANTIC_V2 = True +``` + + + + + + + + + +```python +langchain_core.utils.pydantic.NO_DEFAULT = object() +``` + + + + + + + + + +```python +langchain_core.utils.pydantic.PYDANTIC_MAJOR_VERSION = PYDANTIC_VERSION.major +``` + + + + + + + + + +```python +langchain_core.utils.pydantic.PYDANTIC_MINOR_VERSION = PYDANTIC_VERSION.minor +``` + + + + + + + + + +```python +langchain_core.utils.pydantic.PYDANTIC_VERSION = version.parse(pydantic.__version__) +``` + + + + + + + + + +```python +langchain_core.utils.pydantic.PydanticBaseModel = BaseModel +``` + + + + + + + + + +```python +langchain_core.utils.pydantic.TBaseModel = TypeVar('TBaseModel', bound=PydanticBaseModel) +``` + + + + + + + + + +```python +langchain_core.utils.pydantic.TypeBaseModel = type[BaseModel] +``` + + + + + + + + + +```python +langchain_core.utils.pydantic._RESERVED_NAMES = {key for key in (dir(BaseModel)) if not key.startswith('_')} +``` + + + + + + + + + +```python +langchain_core.utils.pydantic._SchemaConfig = ConfigDict(arbitrary_types_allowed=True, frozen=True, protected_namespaces=()) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/strings.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/strings.mdx new file mode 100644 index 0000000..e4b8b35 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/strings.mdx @@ -0,0 +1,149 @@ +--- +layout: overview +slug: langchain-core/langchain_core/utils/strings +title: langchain_core.utils.strings +--- + +String utilities. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`comma_list`](#langchain_core-utils-strings-comma_list) | Convert an iterable to a comma-separated string. | +| [`sanitize_for_postgres`](#langchain_core-utils-strings-sanitize_for_postgres) | Sanitize text by removing NUL bytes that are incompatible with PostgreSQL. | +| [`stringify_dict`](#langchain_core-utils-strings-stringify_dict) | Stringify a dictionary. | +| [`stringify_value`](#langchain_core-utils-strings-stringify_value) | Stringify a value. | + +### API + + + + + +```python +langchain_core.utils.strings.comma_list( + items: collections.abc.Iterable[typing.Any] +) -> str +``` + + + + + + +Convert an iterable to a comma-separated string. + +**Parameters:** + + +The iterable to convert. + + +**Returns:** `str` + +The comma-separated string. + + + + + + + + +```python +langchain_core.utils.strings.sanitize_for_postgres( + text: str, + replacement: str = '' +) -> str +``` + + + + + + +Sanitize text by removing NUL bytes that are incompatible with PostgreSQL. + +PostgreSQL text fields cannot contain `NUL (0x00)` bytes, which can cause +`psycopg.DataError` when inserting documents. This function removes or replaces +such characters to ensure compatibility. + +**Parameters:** + + +The text to sanitize. + + + +String to replace `NUL` bytes with. + + +**Returns:** `str` + +The sanitized text with `NUL` bytes removed or replaced. + + + + + + + + +```python +langchain_core.utils.strings.stringify_dict( + data: dict +) -> str +``` + + + + + + +Stringify a dictionary. + +**Parameters:** + + +The dictionary to stringify. + + +**Returns:** `str` + +The stringified dictionary. + + + + + + + + +```python +langchain_core.utils.strings.stringify_value( + val: typing.Any +) -> str +``` + + + + + + +Stringify a value. + +**Parameters:** + + +The value to stringify. + + +**Returns:** `str` + +The stringified value. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/usage.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/usage.mdx new file mode 100644 index 0000000..fb97fef --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/usage.mdx @@ -0,0 +1,81 @@ +--- +layout: overview +slug: langchain-core/langchain_core/utils/usage +title: langchain_core.utils.usage +--- + +Usage utilities. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_dict_int_op`](#langchain_core-utils-usage-_dict_int_op) | Apply an integer operation to corresponding values in two dictionaries. | + +### API + + + + + +```python +langchain_core.utils.usage._dict_int_op( + left: dict, + right: dict, + op: collections.abc.Callable[[int, int], int], + default: int = 0, + depth: int = 0, + max_depth: int = 100 +) -> dict +``` + + + + + + +Apply an integer operation to corresponding values in two dictionaries. + +Recursively combines two dictionaries by applying the given operation to integer +values at matching keys. + +Supports nested dictionaries. + +**Parameters:** + + +First dictionary to combine. + + + +Second dictionary to combine. + + + +Binary operation function to apply to integer values. + + + +Default value to use when a key is missing from a dictionary. + + + +Current recursion depth (used internally). + + + +Maximum recursion depth (to prevent infinite loops). + + +**Returns:** `dict` + +A new dictionary with combined values. + +**Raises:** + +- `ValueError`: If `max_depth` is exceeded or if value types are not supported. + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/utils.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/utils.mdx new file mode 100644 index 0000000..201d183 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/utils.mdx @@ -0,0 +1,567 @@ +--- +layout: overview +slug: langchain-core/langchain_core/utils/utils +title: langchain_core.utils.utils +--- + +Generic utility functions. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`_NoDefaultType`](#langchain_core-utils-utils-_NoDefaultType) | Type to indicate no default value is provided. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_build_model_kwargs`](#langchain_core-utils-utils-_build_model_kwargs) | Build `model_kwargs` param from Pydantic constructor values. | +| [`build_extra_kwargs`](#langchain_core-utils-utils-build_extra_kwargs) | Build extra kwargs from values and extra_kwargs. | +| [`check_package_version`](#langchain_core-utils-utils-check_package_version) | Check the version of a package. | +| [`convert_to_secret_str`](#langchain_core-utils-utils-convert_to_secret_str) | Convert a string to a `SecretStr` if needed. | +| [`ensure_id`](#langchain_core-utils-utils-ensure_id) | Ensure the ID is a valid string, generating a new UUID if not provided. | +| [`from_env`](#langchain_core-utils-utils-from_env) | Create a factory method that gets a value from an environment variable. | +| [`get_pydantic_field_names`](#langchain_core-utils-utils-get_pydantic_field_names) | Get field names, including aliases, for a pydantic class. | +| [`guard_import`](#langchain_core-utils-utils-guard_import) | Dynamically import a module. | +| [`mock_now`](#langchain_core-utils-utils-mock_now) | Context manager for mocking out datetime.now() in unit tests. | +| [`raise_for_status_with_text`](#langchain_core-utils-utils-raise_for_status_with_text) | Raise an error with the response text. | +| [`secret_from_env`](#langchain_core-utils-utils-secret_from_env) | Secret from env. | +| [`xor_args`](#langchain_core-utils-utils-xor_args) | Validate specified keyword args are mutually exclusive. | + +### Data + +[`LC_AUTO_PREFIX`](#langchain_core-utils-utils-LC_AUTO_PREFIX) + +[`LC_ID_PREFIX`](#langchain_core-utils-utils-LC_ID_PREFIX) + +[`_NoDefault`](#langchain_core-utils-utils-_NoDefault) + +### API + + + + + +```python +class langchain_core.utils.utils._NoDefaultType() +``` + + + + + + +Type to indicate no default value is provided. + + + + + + + + +```python +langchain_core.utils.utils._build_model_kwargs( + values: dict[str, typing.Any], + all_required_field_names: set[str] +) -> dict[str, typing.Any] +``` + + + + + + +Build `model_kwargs` param from Pydantic constructor values. + +**Parameters:** + + +All init args passed in by user. + + + +All required field names for the pydantic class. + + +**Returns:** `dict[str, Any]` + +Extra kwargs. + +**Raises:** + +- `ValueError`: If a field is specified in both `values` and `extra_kwargs`. +- `ValueError`: If a field is specified in `model_kwargs`. + + + + + + + + +```python +langchain_core.utils.utils.build_extra_kwargs( + extra_kwargs: dict[str, typing.Any], + values: dict[str, typing.Any], + all_required_field_names: set[str] +) -> dict[str, typing.Any] +``` + + + + + + +Build extra kwargs from values and extra_kwargs. + +!!! danger "DON'T USE" + + Kept for backwards-compatibility but should never have been public. Use the + internal `_build_model_kwargs` function instead. + +**Parameters:** + + +Extra kwargs passed in by user. + + + +Values passed in by user. + + + +All required field names for the pydantic class. + + +**Returns:** `dict[str, Any]` + +Extra kwargs. + +**Raises:** + +- `ValueError`: If a field is specified in both `values` and `extra_kwargs`. +- `ValueError`: If a field is specified in `model_kwargs`. + + + + + + + + +```python +langchain_core.utils.utils.check_package_version( + package: str, + lt_version: str | None = None, + lte_version: str | None = None, + gt_version: str | None = None, + gte_version: str | None = None +) -> None +``` + + + + + + +Check the version of a package. + +**Parameters:** + + +The name of the package. + + + +The version must be less than this. + + + +The version must be less than or equal to this. + + + +The version must be greater than this. + + + +The version must be greater than or equal to this. + + +**Raises:** + +- `ValueError`: If the package version does not meet the requirements. + + + + + + + + +```python +langchain_core.utils.utils.convert_to_secret_str( + value: pydantic.SecretStr | str +) -> pydantic.SecretStr +``` + + + + + + +Convert a string to a `SecretStr` if needed. + +**Parameters:** + + +The value to convert. + + +**Returns:** `SecretStr` + +The `SecretStr` value. + + + + + + + + +```python +langchain_core.utils.utils.ensure_id( + id_val: str | None +) -> str +``` + + + + + + +Ensure the ID is a valid string, generating a new UUID if not provided. + +Auto-generated UUIDs are prefixed by `'lc_'` to indicate they are +LangChain-generated IDs. + +**Parameters:** + + +Optional string ID value to validate. + + +**Returns:** `str` + +A string ID, either the validated provided value or a newly generated UUID4. + + + + + + + + +```python +langchain_core.utils.utils.from_env( + key: str | collections.abc.Sequence[str], + default: str | langchain_core.utils.utils._NoDefaultType | None = _NoDefault, + error_message: str | None = None +) -> collections.abc.Callable[[], str] | collections.abc.Callable[[], str | None] +``` + + + + + + +Create a factory method that gets a value from an environment variable. + +**Parameters:** + + +The environment variable to look up. + +If a list of keys is provided, the first key found in the environment will +be used. If no key is found, the default value will be used if set, +otherwise an error will be raised. + + + +The default value to return if the environment variable is not set. + + + +The error message which will be raised if the key is not found +and no default value is provided. + +This will be raised as a ValueError. + + +**Returns:** `Callable[[], str] | Callable[[], str | None]` + +Factory method that will look up the value from the environment. + + + + + + + + +```python +langchain_core.utils.utils.get_pydantic_field_names( + pydantic_cls: typing.Any +) -> set[str] +``` + + + + + + +Get field names, including aliases, for a pydantic class. + +**Parameters:** + + +Pydantic class. + + +**Returns:** `set[str]` + +Field names. + + + + + + + + +```python +langchain_core.utils.utils.guard_import( + module_name: str, + pip_name: str | None = None, + package: str | None = None +) -> typing.Any +``` + + + + + + +Dynamically import a module. + +Raise an exception if the module is not installed. + +**Parameters:** + + +The name of the module to import. + + + +The name of the module to install with pip. + + + +The package to import the module from. + + +**Returns:** `Any` + +The imported module. + +**Raises:** + +- `ImportError`: If the module is not installed. + + + + + + + + +```python +langchain_core.utils.utils.mock_now( + dt_value: datetime.datetime +) -> collections.abc.Iterator[type] +``` + + + + + + +Context manager for mocking out datetime.now() in unit tests. + +**Parameters:** + + +The datetime value to use for datetime.now(). + + + + + + + + + +```python +langchain_core.utils.utils.raise_for_status_with_text( + response: requests.Response +) -> None +``` + + + + + + +Raise an error with the response text. + +**Parameters:** + + +The response to check for errors. + + +**Raises:** + +- `ValueError`: If the response has an error status code. + + + + + + + + +```python +langchain_core.utils.utils.secret_from_env( + key: str | collections.abc.Sequence[str], + default: str | langchain_core.utils.utils._NoDefaultType | None = _NoDefault, + error_message: str | None = None +) -> collections.abc.Callable[[], pydantic.SecretStr | None] | collections.abc.Callable[[], pydantic.SecretStr] +``` + + + + + + +Secret from env. + +**Parameters:** + + +The environment variable to look up. + + + +The default value to return if the environment variable is not set. + + + +The error message which will be raised if the key is not found +and no default value is provided. + +This will be raised as a `ValueError`. + + +**Returns:** `Callable[[], SecretStr | None] | Callable[[], SecretStr]` + +Factory method that will look up the secret from the environment. + + + + + + + + +```python +langchain_core.utils.utils.xor_args( + arg_groups: tuple[str, ...] = () +) -> collections.abc.Callable +``` + + + + + + +Validate specified keyword args are mutually exclusive. + +**Parameters:** + + +Groups of mutually exclusive keyword args. + + +**Returns:** `Callable` + +Decorator that validates the specified keyword args are mutually exclusive. + + + + + + + + +```python +langchain_core.utils.utils.LC_AUTO_PREFIX = 'lc_' +``` + + + + + + +LangChain auto-generated ID prefix for messages and content blocks. + + + + + + + +```python +langchain_core.utils.utils.LC_ID_PREFIX = 'lc_run-' +``` + + + + + + +Internal tracing/callback system identifier. + +Used for: + +- Tracing. Every LangChain operation (LLM call, chain execution, tool use, etc.) + gets a unique run_id (UUID) +- Enables tracking parent-child relationships between operations + + + + + + + +```python +langchain_core.utils.utils._NoDefault = _NoDefaultType() +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/uuid.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/uuid.mdx new file mode 100644 index 0000000..0d1970f --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/utils/uuid.mdx @@ -0,0 +1,103 @@ +--- +layout: overview +slug: langchain-core/langchain_core/utils/uuid +title: langchain_core.utils.uuid +--- + +UUID utility functions. + +This module exports a uuid7 function to generate monotonic, time-ordered UUIDs +for tracing and similar operations. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_to_timestamp_and_nanos`](#langchain_core-utils-uuid-_to_timestamp_and_nanos) | Split a nanosecond timestamp into seconds and remaining nanoseconds. | +| [`uuid7`](#langchain_core-utils-uuid-uuid7) | Generate a UUID from a Unix timestamp in nanoseconds and random bits. | + +### Data + +[`_NANOS_PER_SECOND`](#langchain_core-utils-uuid-_NANOS_PER_SECOND) + +[`__all__`](#langchain_core-utils-uuid-__all__) + +### API + + + + + +```python +langchain_core.utils.uuid._to_timestamp_and_nanos( + nanoseconds: int +) -> tuple[int, int] +``` + + + + + + +Split a nanosecond timestamp into seconds and remaining nanoseconds. + + + + + + + + +```python +langchain_core.utils.uuid.uuid7( + nanoseconds: int | None = None +) -> uuid.UUID +``` + + + + + + +Generate a UUID from a Unix timestamp in nanoseconds and random bits. + +UUIDv7 objects feature monotonicity within a millisecond. + +**Parameters:** + + +Optional ns timestamp. If not provided, uses current time. + + +**Returns:** `UUID` + +A UUIDv7 object. + + + + + + + + +```python +langchain_core.utils.uuid._NANOS_PER_SECOND: Final = 1000000000 +``` + + + + + + + + + +```python +langchain_core.utils.uuid.__all__ = ['uuid7'] +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/vectorstores.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/vectorstores.mdx new file mode 100644 index 0000000..7e08797 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/vectorstores.mdx @@ -0,0 +1,113 @@ +--- +layout: overview +slug: langchain-core/langchain_core/vectorstores +title: langchain_core.vectorstores +--- + +Vector stores. + +## Submodules + +- **[`langchain_core.vectorstores.base`](/langchain-core/langchain_core/vectorstores/base)** +- **[`langchain_core.vectorstores.in_memory`](/langchain-core/langchain_core/vectorstores/in_memory)** +- **[`langchain_core.vectorstores.utils`](/langchain-core/langchain_core/vectorstores/utils)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`__dir__`](#langchain_core-vectorstores-__dir__) | Return a list of available attributes for this module. | +| [`__getattr__`](#langchain_core-vectorstores-__getattr__) | Dynamically import and return an attribute from a submodule. | + +### Data + +[`__all__`](#langchain_core-vectorstores-__all__) + +[`_dynamic_imports`](#langchain_core-vectorstores-_dynamic_imports) + +### API + + + + + +```python +langchain_core.vectorstores.__dir__() -> list[str] +``` + + + + + + +Return a list of available attributes for this module. + +**Returns:** `list[str]` + +List of attribute names that can be imported from this module. + + + + + + + + +```python +langchain_core.vectorstores.__getattr__( + attr_name: str +) -> object +``` + + + + + + +Dynamically import and return an attribute from a submodule. + +This function enables lazy loading of vectorstore classes from submodules, reducing +initial import time and circular dependency issues. + +**Parameters:** + + +Name of the attribute to import. + + +**Returns:** `object` + +The imported attribute object. + +**Raises:** + +- `AttributeError`: If the attribute is not found in `_dynamic_imports`. + + + + + + + + +```python +langchain_core.vectorstores.__all__ = ('VST', 'InMemoryVectorStore', 'VectorStore', 'VectorStoreRetriever') +``` + + + + + + + + + +```python +langchain_core.vectorstores._dynamic_imports = {'VectorStore': 'base', 'VST': 'base', 'VectorStoreRetriever': 'base', 'InMemory... +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/vectorstores/base.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/vectorstores/base.mdx new file mode 100644 index 0000000..4211152 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/vectorstores/base.mdx @@ -0,0 +1,1712 @@ +--- +layout: overview +slug: langchain-core/langchain_core/vectorstores/base +title: langchain_core.vectorstores.base +--- + +A vector store stores embedded data and performs vector search. + +One of the most common ways to store and search over unstructured data is to +embed it and store the resulting embedding vectors, and then query the store +and retrieve the data that are 'most similar' to the embedded query. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`VectorStore`](#langchain_core-vectorstores-base-VectorStore) | Interface for vector store. | +| [`VectorStoreRetriever`](#langchain_core-vectorstores-base-VectorStoreRetriever) | Base Retriever class for VectorStore. | + +### Data + +[`VST`](#langchain_core-vectorstores-base-VST) + +[`logger`](#langchain_core-vectorstores-base-logger) + +### API + + + + + +```python +class langchain_core.vectorstores.base.VectorStore() +``` + + + + + + +Abstract + +Interface for vector store. + + + +Access the query embedding object if available. + + + + + +```python +langchain_core.vectorstores.base.VectorStore._asimilarity_search_with_relevance_scores( + query: str, + k: int = 4, + kwargs: typing.Any = {} +) -> list[tuple[langchain_core.documents.Document, float]] +``` + + + + + + +async + +Default similarity search with relevance scores. + +Modify if necessary in subclass. +Return docs and relevance scores in the range `[0, 1]`. + +`0` is dissimilar, `1` is most similar. + +**Parameters:** + + +Input text. + + + +Number of `Document` objects to return. + + + +Kwargs to be passed to similarity search. + +Should include `score_threshold`, an optional floating point value +between `0` to `1` to filter the resulting set of retrieved docs. + + +**Returns:** `list[tuple[Document, float]]` + +List of tuples of `(doc, similarity_score)` + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore._cosine_relevance_score_fn( + distance: float +) -> float +``` + + + + + + +staticmethod + +Normalize the distance to a score on a scale [0, 1]. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore._euclidean_relevance_score_fn( + distance: float +) -> float +``` + + + + + + +staticmethod + +Return a similarity score on a scale [0, 1]. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore._get_retriever_tags() -> list[str] +``` + + + + + + +Get tags for retriever. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore._max_inner_product_relevance_score_fn( + distance: float +) -> float +``` + + + + + + +staticmethod + +Normalize the distance to a score on a scale [0, 1]. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore._select_relevance_score_fn() -> collections.abc.Callable[[float], float] +``` + + + + + + +The 'correct' relevance function. + +May differ depending on a few things, including: + +- The distance / similarity metric used by the VectorStore +- The scale of your embeddings (OpenAI's are unit normed. Many others are not!) +- Embedding dimensionality +- etc. + +Vectorstores should define their own selection-based method of relevance. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore._similarity_search_with_relevance_scores( + query: str, + k: int = 4, + kwargs: typing.Any = {} +) -> list[tuple[langchain_core.documents.Document, float]] +``` + + + + + + +Default similarity search with relevance scores. + +Modify if necessary in subclass. +Return docs and relevance scores in the range `[0, 1]`. + +`0` is dissimilar, `1` is most similar. + +**Parameters:** + + +Input text. + + + +Number of `Document` objects to return. + + + +Kwargs to be passed to similarity search. + +Should include `score_threshold`, an optional floating point value +between `0` to `1` to filter the resulting set of retrieved docs. + + +**Returns:** `list[tuple[Document, float]]` + +List of tuples of `(doc, similarity_score)` + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.aadd_documents( + documents: list[langchain_core.documents.Document], + kwargs: typing.Any = {} +) -> list[str] +``` + + + + + + +async + +Async run more documents through the embeddings and add to the `VectorStore`. + +**Parameters:** + + +Documents to add to the `VectorStore`. + + + +Additional keyword arguments. + + +**Returns:** `list[str]` + +List of IDs of the added texts. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.aadd_texts( + texts: collections.abc.Iterable[str], + metadatas: list[dict] | None = None, + ids: list[str] | None = None, + kwargs: typing.Any = {} +) -> list[str] +``` + + + + + + +async + +Async run more texts through the embeddings and add to the `VectorStore`. + +**Parameters:** + + +Iterable of strings to add to the `VectorStore`. + + + +Optional list of metadatas associated with the texts. + + + +Optional list + + + +`VectorStore` specific parameters. + + +**Returns:** `list[str]` + +List of IDs from adding the texts into the `VectorStore`. + +**Raises:** + +- `ValueError`: If the number of metadatas does not match the number of texts. +- `ValueError`: If the number of IDs does not match the number of texts. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.add_documents( + documents: list[langchain_core.documents.Document], + kwargs: typing.Any = {} +) -> list[str] +``` + + + + + + +Add or update documents in the `VectorStore`. + +**Parameters:** + + +Documents to add to the `VectorStore`. + + + +Additional keyword arguments. + +If kwargs contains IDs and documents contain ids, the IDs in the kwargs +will receive precedence. + + +**Returns:** `list[str]` + +List of IDs of the added texts. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.add_texts( + texts: collections.abc.Iterable[str], + metadatas: list[dict] | None = None, + ids: list[str] | None = None, + kwargs: typing.Any = {} +) -> list[str] +``` + + + + + + +Run more texts through the embeddings and add to the `VectorStore`. + +**Parameters:** + + +Iterable of strings to add to the `VectorStore`. + + + +Optional list of metadatas associated with the texts. + + + +Optional list of IDs associated with the texts. + + + +`VectorStore` specific parameters. + +One of the kwargs should be `ids` which is a list of ids +associated with the texts. + + +**Returns:** `list[str]` + +List of IDs from adding the texts into the `VectorStore`. + +**Raises:** + +- `ValueError`: If the number of metadatas does not match the number of texts. +- `ValueError`: If the number of IDs does not match the number of texts. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.adelete( + ids: list[str] | None = None, + kwargs: typing.Any = {} +) -> bool | None +``` + + + + + + +async + +Async delete by vector ID or other criteria. + +**Parameters:** + + +List of IDs to delete. If `None`, delete all. + + + +Other keyword arguments that subclasses might use. + + +**Returns:** `bool | None` + +`True` if deletion is successful, `False` otherwise, `None` if not +implemented. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.afrom_documents( + documents: list[langchain_core.documents.Document], + embedding: langchain_core.embeddings.Embeddings, + kwargs: typing.Any = {} +) -> typing_extensions.Self +``` + + + + + + +async classmethod + +Async return `VectorStore` initialized from documents and embeddings. + +**Parameters:** + + +List of `Document` objects to add to the `VectorStore`. + + + +Embedding function to use. + + + +Additional keyword arguments. + + +**Returns:** `Self` + +`VectorStore` initialized from documents and embeddings. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.afrom_texts( + texts: list[str], + embedding: langchain_core.embeddings.Embeddings, + metadatas: list[dict] | None = None, + ids: list[str] | None = None, + kwargs: typing.Any = {} +) -> typing_extensions.Self +``` + + + + + + +async classmethod + +Async return `VectorStore` initialized from texts and embeddings. + +**Parameters:** + + +Texts to add to the `VectorStore`. + + + +Embedding function to use. + + + +Optional list of metadatas associated with the texts. + + + +Optional list of IDs associated with the texts. + + + +Additional keyword arguments. + + +**Returns:** `Self` + +`VectorStore` initialized from texts and embeddings. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.aget_by_ids( + ids: collections.abc.Sequence[str] +) -> list[langchain_core.documents.Document] +``` + + + + + + +async + +Async get documents by their IDs. + +The returned documents are expected to have the ID field set to the ID of the +document in the vector store. + +Fewer documents may be returned than requested if some IDs are not found or +if there are duplicated IDs. + +Users should not assume that the order of the returned documents matches +the order of the input IDs. Instead, users should rely on the ID field of the +returned documents. + +This method should **NOT** raise exceptions if no documents are found for +some IDs. + +**Parameters:** + + +List of IDs to retrieve. + + +**Returns:** `list[Document]` + +List of `Document` objects. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.amax_marginal_relevance_search( + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + +async + +Async return docs selected using the maximal marginal relevance. + +Maximal marginal relevance optimizes for similarity to query AND diversity +among selected documents. + +**Parameters:** + + +Text to look up documents similar to. + + + +Number of `Document` objects to return. + + + +Number of `Document` objects to fetch to pass to MMR algorithm. + + + +Number between `0` and `1` that determines the degree of +diversity among the results with `0` corresponding to maximum diversity +and `1` to minimum diversity. + + + +Arguments to pass to the search method. + + +**Returns:** `list[Document]` + +List of `Document` objects selected by maximal marginal relevance. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.amax_marginal_relevance_search_by_vector( + embedding: list[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + +async + +Async return docs selected using the maximal marginal relevance. + +Maximal marginal relevance optimizes for similarity to query AND diversity +among selected documents. + +**Parameters:** + + +Embedding to look up documents similar to. + + + +Number of `Document` objects to return. + + + +Number of `Document` objects to fetch to pass to MMR algorithm. + + + +Number between `0` and `1` that determines the degree of +diversity among the results with `0` corresponding to maximum diversity +and `1` to minimum diversity. + + + +Arguments to pass to the search method. + + +**Returns:** `list[Document]` + +List of `Document` objects selected by maximal marginal relevance. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.as_retriever( + kwargs: typing.Any = {} +) -> langchain_core.vectorstores.base.VectorStoreRetriever +``` + + + + + + +Return `VectorStoreRetriever` initialized from this `VectorStore`. + +Examples: + + +```python +# Retrieve more documents with higher diversity +# Useful if your dataset has many similar documents +docsearch.as_retriever( + search_type="mmr", search_kwargs={"k": 6, "lambda_mult": 0.25} +) + +# Fetch more documents for the MMR algorithm to consider +# But only return the top 5 +docsearch.as_retriever(search_type="mmr", search_kwargs={"k": 5, "fetch_k": 50}) + +# Only retrieve documents that have a relevance score +# Above a certain threshold +docsearch.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={"score_threshold": 0.8}, +) + +# Only get the single most similar document from the dataset +docsearch.as_retriever(search_kwargs={"k": 1}) + +# Use a filter to only retrieve documents from a specific paper +docsearch.as_retriever( + search_kwargs={"filter": {"paper_title": "GPT-4 Technical Report"}} +) +``` + + + +**Parameters:** + + +Keyword arguments to pass to the search function. + +Can include: + +* `search_type`: Defines the type of search that the Retriever should + perform. Can be `'similarity'` (default), `'mmr'`, or + `'similarity_score_threshold'`. +* `search_kwargs`: Keyword arguments to pass to the search function. + + Can include things like: + + * `k`: Amount of documents to return (Default: `4`) + * `score_threshold`: Minimum relevance threshold + for `similarity_score_threshold` + * `fetch_k`: Amount of documents to pass to MMR algorithm + (Default: `20`) + * `lambda_mult`: Diversity of results returned by MMR; + `1` for minimum diversity and 0 for maximum. (Default: `0.5`) + * `filter`: Filter by document metadata + + +**Returns:** `VectorStoreRetriever` + +Retriever class for `VectorStore`. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.asearch( + query: str, + search_type: str, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + +async + +Async return docs most similar to query using a specified search type. + +**Parameters:** + + +Input text. + + + +Type of search to perform. + +Can be `'similarity'`, `'mmr'`, or `'similarity_score_threshold'`. + + + +Arguments to pass to the search method. + + +**Returns:** `list[Document]` + +List of `Document` objects most similar to the query. + +**Raises:** + +- `ValueError`: If `search_type` is not one of `'similarity'`, +`'mmr'`, or `'similarity_score_threshold'`. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.asimilarity_search( + query: str, + k: int = 4, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + +async + +Async return docs most similar to query. + +**Parameters:** + + +Input text. + + + +Number of `Document` objects to return. + + + +Arguments to pass to the search method. + + +**Returns:** `list[Document]` + +List of `Document` objects most similar to the query. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.asimilarity_search_by_vector( + embedding: list[float], + k: int = 4, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + +async + +Async return docs most similar to embedding vector. + +**Parameters:** + + +Embedding to look up documents similar to. + + + +Number of `Document` objects to return. + + + +Arguments to pass to the search method. + + +**Returns:** `list[Document]` + +List of `Document` objects most similar to the query vector. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.asimilarity_search_with_relevance_scores( + query: str, + k: int = 4, + kwargs: typing.Any = {} +) -> list[tuple[langchain_core.documents.Document, float]] +``` + + + + + + +async + +Async return docs and relevance scores in the range `[0, 1]`. + +`0` is dissimilar, `1` is most similar. + +**Parameters:** + + +Input text. + + + +Number of `Document` objects to return. + + + +Kwargs to be passed to similarity search. + +Should include `score_threshold`, an optional floating point value +between `0` to `1` to filter the resulting set of retrieved docs. + + +**Returns:** `list[tuple[Document, float]]` + +List of tuples of `(doc, similarity_score)` + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.asimilarity_search_with_score( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> list[tuple[langchain_core.documents.Document, float]] +``` + + + + + + +async + +Async run similarity search with distance. + +**Parameters:** + + +Arguments to pass to the search method. + + + +Arguments to pass to the search method. + + +**Returns:** `list[tuple[Document, float]]` + +List of tuples of `(doc, similarity_score)`. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.delete( + ids: list[str] | None = None, + kwargs: typing.Any = {} +) -> bool | None +``` + + + + + + +Delete by vector ID or other criteria. + +**Parameters:** + + +List of IDs to delete. If `None`, delete all. + + + +Other keyword arguments that subclasses might use. + + +**Returns:** `bool | None` + +`True` if deletion is successful, `False` otherwise, `None` if not +implemented. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.from_documents( + documents: list[langchain_core.documents.Document], + embedding: langchain_core.embeddings.Embeddings, + kwargs: typing.Any = {} +) -> typing_extensions.Self +``` + + + + + + +classmethod + +Return `VectorStore` initialized from documents and embeddings. + +**Parameters:** + + +List of `Document` objects to add to the `VectorStore`. + + + +Embedding function to use. + + + +Additional keyword arguments. + + +**Returns:** `Self` + +`VectorStore` initialized from documents and embeddings. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.from_texts( + texts: list[str], + embedding: langchain_core.embeddings.Embeddings, + metadatas: list[dict] | None = None, + ids: list[str] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.vectorstores.base.VST +``` + + + + + + +classmethod abstract + +Return `VectorStore` initialized from texts and embeddings. + +**Parameters:** + + +Texts to add to the `VectorStore`. + + + +Embedding function to use. + + + +Optional list of metadatas associated with the texts. + + + +Optional list of IDs associated with the texts. + + + +Additional keyword arguments. + + +**Returns:** `VST` + +`VectorStore` initialized from texts and embeddings. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.get_by_ids( + ids: collections.abc.Sequence[str] +) -> list[langchain_core.documents.Document] +``` + + + + + + +Get documents by their IDs. + +The returned documents are expected to have the ID field set to the ID of the +document in the vector store. + +Fewer documents may be returned than requested if some IDs are not found or +if there are duplicated IDs. + +Users should not assume that the order of the returned documents matches +the order of the input IDs. Instead, users should rely on the ID field of the +returned documents. + +This method should **NOT** raise exceptions if no documents are found for +some IDs. + +**Parameters:** + + +List of IDs to retrieve. + + +**Returns:** `list[Document]` + +List of `Document` objects. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.max_marginal_relevance_search( + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + +Return docs selected using the maximal marginal relevance. + +Maximal marginal relevance optimizes for similarity to query AND diversity +among selected documents. + +**Parameters:** + + +Text to look up documents similar to. + + + +Number of `Document` objects to return. + + + +Number of `Document` objects to fetch to pass to MMR algorithm. + + + +Number between `0` and `1` that determines the degree of +diversity among the results with `0` corresponding to maximum diversity +and `1` to minimum diversity. + + + +Arguments to pass to the search method. + + +**Returns:** `list[Document]` + +List of `Document` objects selected by maximal marginal relevance. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.max_marginal_relevance_search_by_vector( + embedding: list[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + +Return docs selected using the maximal marginal relevance. + +Maximal marginal relevance optimizes for similarity to query AND diversity +among selected documents. + +**Parameters:** + + +Embedding to look up documents similar to. + + + +Number of `Document` objects to return. + + + +Number of `Document` objects to fetch to pass to MMR algorithm. + + + +Number between `0` and `1` that determines the degree of +diversity among the results with `0` corresponding to maximum diversity +and `1` to minimum diversity. + + + +Arguments to pass to the search method. + + +**Returns:** `list[Document]` + +List of `Document` objects selected by maximal marginal relevance. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.search( + query: str, + search_type: str, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + +Return docs most similar to query using a specified search type. + +**Parameters:** + + +Input text. + + + +Type of search to perform. + +Can be `'similarity'`, `'mmr'`, or `'similarity_score_threshold'`. + + + +Arguments to pass to the search method. + + +**Returns:** `list[Document]` + +List of `Document` objects most similar to the query. + +**Raises:** + +- `ValueError`: If `search_type` is not one of `'similarity'`, +`'mmr'`, or `'similarity_score_threshold'`. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.similarity_search( + query: str, + k: int = 4, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + +abstract + +Return docs most similar to query. + +**Parameters:** + + +Input text. + + + +Number of `Document` objects to return. + + + +Arguments to pass to the search method. + + +**Returns:** `list[Document]` + +List of `Document` objects most similar to the query. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.similarity_search_by_vector( + embedding: list[float], + k: int = 4, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + +Return docs most similar to embedding vector. + +**Parameters:** + + +Embedding to look up documents similar to. + + + +Number of `Document` objects to return. + + + +Arguments to pass to the search method. + + +**Returns:** `list[Document]` + +List of `Document` objects most similar to the query vector. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.similarity_search_with_relevance_scores( + query: str, + k: int = 4, + kwargs: typing.Any = {} +) -> list[tuple[langchain_core.documents.Document, float]] +``` + + + + + + +Return docs and relevance scores in the range `[0, 1]`. + +`0` is dissimilar, `1` is most similar. + +**Parameters:** + + +Input text. + + + +Number of `Document` objects to return. + + + +Kwargs to be passed to similarity search. + +Should include `score_threshold`, an optional floating point value +between `0` to `1` to filter the resulting set of retrieved docs. + + +**Returns:** `list[tuple[Document, float]]` + +List of tuples of `(doc, similarity_score)`. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStore.similarity_search_with_score( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> list[tuple[langchain_core.documents.Document, float]] +``` + + + + + + +Run similarity search with distance. + +**Parameters:** + + +Arguments to pass to the search method. + + + +Arguments to pass to the search method. + + +**Returns:** `list[tuple[Document, float]]` + +List of tuples of `(doc, similarity_score)`. + + + + + + + + + +```python +class langchain_core.vectorstores.base.VectorStoreRetriever() +``` + + + + + + +**Bases:** [BaseRetriever](/langchain-core/langchain_core/retrievers#langchain_core-retrievers-BaseRetriever) + +Base Retriever class for VectorStore. + + + + + + + + + +Keyword arguments to pass to the search function. + + + +Type of search to perform. + + + +VectorStore to use for retrieval. + + + + + +```python +langchain_core.vectorstores.base.VectorStoreRetriever._aget_relevant_documents( + query: str, + run_manager: langchain_core.callbacks.manager.AsyncCallbackManagerForRetrieverRun, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + +async + + + + + + + +```python +langchain_core.vectorstores.base.VectorStoreRetriever._get_ls_params( + kwargs: typing.Any = {} +) -> langchain_core.retrievers.LangSmithRetrieverParams +``` + + + + + + +Get standard params for tracing. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStoreRetriever._get_relevant_documents( + query: str, + run_manager: langchain_core.callbacks.manager.CallbackManagerForRetrieverRun, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + + + + + + + +```python +langchain_core.vectorstores.base.VectorStoreRetriever.aadd_documents( + documents: list[langchain_core.documents.Document], + kwargs: typing.Any = {} +) -> list[str] +``` + + + + + + +async + +Async add documents to the `VectorStore`. + +**Parameters:** + + +Documents to add to the `VectorStore`. + + + +Other keyword arguments that subclasses might use. + + +**Returns:** `list[str]` + +List of IDs of the added texts. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStoreRetriever.add_documents( + documents: list[langchain_core.documents.Document], + kwargs: typing.Any = {} +) -> list[str] +``` + + + + + + +Add documents to the `VectorStore`. + +**Parameters:** + + +Documents to add to the `VectorStore`. + + + +Other keyword arguments that subclasses might use. + + +**Returns:** `list[str]` + +List of IDs of the added texts. + + + + + + + +```python +langchain_core.vectorstores.base.VectorStoreRetriever.validate_search_type( + values: dict +) -> typing.Any +``` + + + + + + +classmethod + +Validate search type. + +**Parameters:** + + +Values to validate. + + +**Returns:** `Any` + +Validated values. + +**Raises:** + +- `ValueError`: If `search_type` is not one of the allowed search types. +- `ValueError`: If `score_threshold` is not specified with a float value(`0~1`) + + + + + + + + + +```python +langchain_core.vectorstores.base.VST = TypeVar('VST', bound='VectorStore') +``` + + + + + + + + + +```python +langchain_core.vectorstores.base.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/vectorstores/in_memory.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/vectorstores/in_memory.mdx new file mode 100644 index 0000000..6d81777 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/vectorstores/in_memory.mdx @@ -0,0 +1,560 @@ +--- +layout: overview +slug: langchain-core/langchain_core/vectorstores/in_memory +title: langchain_core.vectorstores.in_memory +--- + +In-memory vector store. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`InMemoryVectorStore`](#langchain_core-vectorstores-in_memory-InMemoryVectorStore) | In-memory vector store implementation. | + +### Data + +[`_HAS_NUMPY`](#langchain_core-vectorstores-in_memory-_HAS_NUMPY) + +### API + + + + + +```python +class langchain_core.vectorstores.in_memory.InMemoryVectorStore( + embedding: langchain_core.embeddings.Embeddings +) +``` + + + + + + +**Bases:** [VectorStore](/langchain-core/langchain_core/vectorstores/base#langchain_core-vectorstores-base-VectorStore) + +In-memory vector store implementation. + +Uses a dictionary, and computes cosine similarity for search using numpy. + +Key init args — indexing params: + + * embedding_function: Embeddings + Embedding function to use. + + + + + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore._similarity_search_with_score_by_vector( + embedding: list[float], + k: int = 4, + filter: collections.abc.Callable[[Document], bool] | None = None +) -> list[tuple[langchain_core.documents.Document, float, list[float]]] +``` + + + + + + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore.aadd_documents( + documents: list[langchain_core.documents.Document], + ids: list[str] | None = None, + kwargs: typing.Any = {} +) -> list[str] +``` + + + + + + +async + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore.add_documents( + documents: list[langchain_core.documents.Document], + ids: list[str] | None = None, + kwargs: typing.Any = {} +) -> list[str] +``` + + + + + + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore.adelete( + ids: collections.abc.Sequence[str] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + +async + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore.afrom_texts( + texts: list[str], + embedding: langchain_core.embeddings.Embeddings, + metadatas: list[dict] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.vectorstores.in_memory.InMemoryVectorStore +``` + + + + + + +async classmethod + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore.aget_by_ids( + ids: collections.abc.Sequence[str] +) -> list[langchain_core.documents.Document] +``` + + + + + + +async + +Async get documents by their ids. + +**Parameters:** + + +The IDs of the documents to get. + + +**Returns:** `list[Document]` + +A list of `Document` objects. + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore.amax_marginal_relevance_search( + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + +async + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore.asimilarity_search( + query: str, + k: int = 4, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + +async + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore.asimilarity_search_by_vector( + embedding: list[float], + k: int = 4, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + +async + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore.asimilarity_search_with_score( + query: str, + k: int = 4, + kwargs: typing.Any = {} +) -> list[tuple[langchain_core.documents.Document, float]] +``` + + + + + + +async + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore.delete( + ids: collections.abc.Sequence[str] | None = None, + kwargs: typing.Any = {} +) -> None +``` + + + + + + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore.dump( + path: str +) -> None +``` + + + + + + +Dump the vector store to a file. + +**Parameters:** + + +The path to dump the vector store to. + + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore.from_texts( + texts: list[str], + embedding: langchain_core.embeddings.Embeddings, + metadatas: list[dict] | None = None, + kwargs: typing.Any = {} +) -> langchain_core.vectorstores.in_memory.InMemoryVectorStore +``` + + + + + + +classmethod + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore.get_by_ids( + ids: collections.abc.Sequence[str] +) -> list[langchain_core.documents.Document] +``` + + + + + + +Get documents by their ids. + +**Parameters:** + + +The IDs of the documents to get. + + +**Returns:** `list[Document]` + +A list of `Document` objects. + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore.load( + path: str, + embedding: langchain_core.embeddings.Embeddings, + kwargs: typing.Any = {} +) -> langchain_core.vectorstores.in_memory.InMemoryVectorStore +``` + + + + + + +classmethod + +Load a vector store from a file. + +**Parameters:** + + +The path to load the vector store from. + + + +The embedding to use. + + + +Additional arguments to pass to the constructor. + + +**Returns:** `InMemoryVectorStore` + +A `VectorStore` object. + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore.max_marginal_relevance_search( + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore.max_marginal_relevance_search_by_vector( + embedding: list[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: collections.abc.Callable[[Document], bool] | None = None, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore.similarity_search( + query: str, + k: int = 4, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore.similarity_search_by_vector( + embedding: list[float], + k: int = 4, + kwargs: typing.Any = {} +) -> list[langchain_core.documents.Document] +``` + + + + + + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore.similarity_search_with_score( + query: str, + k: int = 4, + kwargs: typing.Any = {} +) -> list[tuple[langchain_core.documents.Document, float]] +``` + + + + + + + + + + + + +```python +langchain_core.vectorstores.in_memory.InMemoryVectorStore.similarity_search_with_score_by_vector( + embedding: list[float], + k: int = 4, + filter: collections.abc.Callable[[Document], bool] | None = None, + _kwargs: typing.Any = {} +) -> list[tuple[langchain_core.documents.Document, float]] +``` + + + + + + +Search for the most similar documents to the given embedding. + +**Parameters:** + + +The embedding to search for. + + + +The number of documents to return. + + + +A function to filter the documents. + + +**Returns:** `list[tuple[Document, float]]` + +A list of tuples of `Document` objects and their similarity scores. + + + + + + + + + +```python +langchain_core.vectorstores.in_memory._HAS_NUMPY = True +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/vectorstores/utils.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/vectorstores/utils.mdx new file mode 100644 index 0000000..354373b --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/vectorstores/utils.mdx @@ -0,0 +1,171 @@ +--- +layout: overview +slug: langchain-core/langchain_core/vectorstores/utils +title: langchain_core.vectorstores.utils +--- + +Internal utilities for the in memory implementation of `VectorStore`. + +!!! warning + + These are part of a private API, and users should not use them directly as they can + change without notice. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_cosine_similarity`](#langchain_core-vectorstores-utils-_cosine_similarity) | Row-wise cosine similarity between two equal-width matrices. | +| [`maximal_marginal_relevance`](#langchain_core-vectorstores-utils-maximal_marginal_relevance) | Calculate maximal marginal relevance. | + +### Data + +[`Matrix`](#langchain_core-vectorstores-utils-Matrix) + +[`_HAS_NUMPY`](#langchain_core-vectorstores-utils-_HAS_NUMPY) + +[`_HAS_SIMSIMD`](#langchain_core-vectorstores-utils-_HAS_SIMSIMD) + +[`logger`](#langchain_core-vectorstores-utils-logger) + +### API + + + + + +```python +langchain_core.vectorstores.utils._cosine_similarity( + x: langchain_core.vectorstores.utils.Matrix, + y: langchain_core.vectorstores.utils.Matrix +) -> numpy.ndarray +``` + + + + + + +Row-wise cosine similarity between two equal-width matrices. + +**Parameters:** + + +A matrix of shape `(n, m)`. + + + +A matrix of shape `(k, m)`. + + +**Returns:** `np.ndarray` + +A matrix of shape `(n, k)` where each element `(i, j)` is the cosine similarity +between the `i`th row of `x` and the `j`th row of `y`. + +**Raises:** + +- `ValueError`: If the number of columns in `x` and `y` are not the same. +- `ImportError`: If numpy is not installed. + + + + + + + + +```python +langchain_core.vectorstores.utils.maximal_marginal_relevance( + query_embedding: numpy.ndarray, + embedding_list: list, + lambda_mult: float = 0.5, + k: int = 4 +) -> list[int] +``` + + + + + + +Calculate maximal marginal relevance. + +**Parameters:** + + +The query embedding. + + + +A list of embeddings. + + + +The lambda parameter for MMR. + + + +The number of embeddings to return. + + +**Returns:** `list[int]` + +A list of indices of the embeddings to return. + +**Raises:** + +- `ImportError`: If numpy is not installed. + + + + + + + + +```python +langchain_core.vectorstores.utils.Matrix = list[list[float]] | list[np.ndarray] | np.ndarray +``` + + + + + + + + + +```python +langchain_core.vectorstores.utils._HAS_NUMPY = True +``` + + + + + + + + + +```python +langchain_core.vectorstores.utils._HAS_SIMSIMD = True +``` + + + + + + + + + +```python +langchain_core.vectorstores.utils.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/version.mdx b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/version.mdx new file mode 100644 index 0000000..b823fb1 --- /dev/null +++ b/fern/library-docs/langchain-core-docs/langchain-core/langchain_core/version.mdx @@ -0,0 +1,27 @@ +--- +layout: overview +slug: langchain-core/langchain_core/version +title: langchain_core.version +--- + +langchain-core version information and utilities. + +## Module Contents + +### Data + +[`VERSION`](#langchain_core-version-VERSION) + +### API + + + + + +```python +langchain_core.version.VERSION = '1.2.15' +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/_navigation.yml b/fern/library-docs/nemo-rl-docs/_navigation.yml new file mode 100644 index 0000000..75d75a9 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/_navigation.yml @@ -0,0 +1,1033 @@ +# AUTO-GENERATED by `fern docs md generate` — DO NOT EDIT +- type: section + title: algorithms + slug: nemo-rl/nemo_rl/algorithms + children: + - type: section + title: advantage_estimator + slug: nemo-rl/nemo_rl/algorithms/advantage_estimator + children: + - type: page + title: advantage_estimator + slug: nemo-rl/nemo_rl/algorithms/advantage_estimator + pageId: nemo-rl/nemo_rl/algorithms/advantage_estimator.mdx + - type: section + title: async_utils + slug: nemo-rl/nemo_rl/algorithms/async_utils + children: + - type: page + title: async_utils + slug: nemo-rl/nemo_rl/algorithms/async_utils + pageId: nemo-rl/nemo_rl/algorithms/async_utils.mdx + - type: section + title: distillation + slug: nemo-rl/nemo_rl/algorithms/distillation + children: + - type: page + title: distillation + slug: nemo-rl/nemo_rl/algorithms/distillation + pageId: nemo-rl/nemo_rl/algorithms/distillation.mdx + - type: section + title: dpo + slug: nemo-rl/nemo_rl/algorithms/dpo + children: + - type: page + title: dpo + slug: nemo-rl/nemo_rl/algorithms/dpo + pageId: nemo-rl/nemo_rl/algorithms/dpo.mdx + - type: section + title: grpo + slug: nemo-rl/nemo_rl/algorithms/grpo + children: + - type: page + title: grpo + slug: nemo-rl/nemo_rl/algorithms/grpo + pageId: nemo-rl/nemo_rl/algorithms/grpo.mdx + - type: section + title: interfaces + slug: nemo-rl/nemo_rl/algorithms/interfaces + children: + - type: page + title: interfaces + slug: nemo-rl/nemo_rl/algorithms/interfaces + pageId: nemo-rl/nemo_rl/algorithms/interfaces.mdx + - type: section + title: loss_functions + slug: nemo-rl/nemo_rl/algorithms/loss_functions + children: + - type: page + title: loss_functions + slug: nemo-rl/nemo_rl/algorithms/loss_functions + pageId: nemo-rl/nemo_rl/algorithms/loss_functions.mdx + - type: section + title: reward_functions + slug: nemo-rl/nemo_rl/algorithms/reward_functions + children: + - type: page + title: reward_functions + slug: nemo-rl/nemo_rl/algorithms/reward_functions + pageId: nemo-rl/nemo_rl/algorithms/reward_functions.mdx + - type: section + title: rm + slug: nemo-rl/nemo_rl/algorithms/rm + children: + - type: page + title: rm + slug: nemo-rl/nemo_rl/algorithms/rm + pageId: nemo-rl/nemo_rl/algorithms/rm.mdx + - type: section + title: sft + slug: nemo-rl/nemo_rl/algorithms/sft + children: + - type: page + title: sft + slug: nemo-rl/nemo_rl/algorithms/sft + pageId: nemo-rl/nemo_rl/algorithms/sft.mdx + - type: section + title: utils + slug: nemo-rl/nemo_rl/algorithms/utils + children: + - type: page + title: utils + slug: nemo-rl/nemo_rl/algorithms/utils + pageId: nemo-rl/nemo_rl/algorithms/utils.mdx +- type: section + title: data + slug: nemo-rl/nemo_rl/data + children: + - type: section + title: chat_templates + slug: nemo-rl/nemo_rl/data/chat_templates + children: + - type: page + title: chat_templates + slug: nemo-rl/nemo_rl/data/chat_templates + pageId: nemo-rl/nemo_rl/data/chat_templates.mdx + - type: section + title: collate_fn + slug: nemo-rl/nemo_rl/data/collate_fn + children: + - type: page + title: collate_fn + slug: nemo-rl/nemo_rl/data/collate_fn + pageId: nemo-rl/nemo_rl/data/collate_fn.mdx + - type: section + title: datasets + slug: nemo-rl/nemo_rl/data/datasets + children: + - type: section + title: eval_datasets + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets + children: + - type: section + title: aime + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/aime + children: + - type: page + title: aime + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/aime + pageId: nemo-rl/nemo_rl/data/datasets/eval_datasets/aime.mdx + - type: section + title: gpqa + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa + children: + - type: page + title: gpqa + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa + pageId: nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa.mdx + - type: section + title: local_math_dataset + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset + children: + - type: page + title: local_math_dataset + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset + pageId: nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset.mdx + - type: section + title: math + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/math + children: + - type: page + title: math + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/math + pageId: nemo-rl/nemo_rl/data/datasets/eval_datasets/math.mdx + - type: section + title: mmlu + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu + children: + - type: page + title: mmlu + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu + pageId: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu.mdx + - type: section + title: mmlu_pro + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro + children: + - type: page + title: mmlu_pro + slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro + pageId: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro.mdx + - type: section + title: preference_datasets + slug: nemo-rl/nemo_rl/data/datasets/preference_datasets + children: + - type: section + title: binary_preference_dataset + slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset + children: + - type: page + title: binary_preference_dataset + slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset + pageId: nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.mdx + - type: section + title: helpsteer3 + slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3 + children: + - type: page + title: helpsteer3 + slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3 + pageId: nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3.mdx + - type: section + title: preference_dataset + slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset + children: + - type: page + title: preference_dataset + slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset + pageId: nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset.mdx + - type: section + title: tulu3 + slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3 + children: + - type: page + title: tulu3 + slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3 + pageId: nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3.mdx + - type: section + title: processed_dataset + slug: nemo-rl/nemo_rl/data/datasets/processed_dataset + children: + - type: page + title: processed_dataset + slug: nemo-rl/nemo_rl/data/datasets/processed_dataset + pageId: nemo-rl/nemo_rl/data/datasets/processed_dataset.mdx + - type: section + title: raw_dataset + slug: nemo-rl/nemo_rl/data/datasets/raw_dataset + children: + - type: page + title: raw_dataset + slug: nemo-rl/nemo_rl/data/datasets/raw_dataset + pageId: nemo-rl/nemo_rl/data/datasets/raw_dataset.mdx + - type: section + title: response_datasets + slug: nemo-rl/nemo_rl/data/datasets/response_datasets + children: + - type: section + title: aime24 + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/aime24 + children: + - type: page + title: aime24 + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/aime24 + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/aime24.mdx + - type: section + title: clevr + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/clevr + children: + - type: page + title: clevr + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/clevr + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/clevr.mdx + - type: section + title: dapo_math + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math + children: + - type: page + title: dapo_math + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math.mdx + - type: section + title: deepscaler + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler + children: + - type: page + title: deepscaler + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler.mdx + - type: section + title: geometry3k + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k + children: + - type: page + title: geometry3k + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k.mdx + - type: section + title: helpsteer3 + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3 + children: + - type: page + title: helpsteer3 + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3 + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3.mdx + - type: section + title: nemogym_dataset + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset + children: + - type: page + title: nemogym_dataset + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset.mdx + - type: section + title: oai_format_dataset + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset + children: + - type: page + title: oai_format_dataset + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset.mdx + - type: section + title: oasst + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/oasst + children: + - type: page + title: oasst + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/oasst + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/oasst.mdx + - type: section + title: openmathinstruct2 + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2 + children: + - type: page + title: openmathinstruct2 + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2 + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2.mdx + - type: section + title: refcoco + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco + children: + - type: page + title: refcoco + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco.mdx + - type: section + title: response_dataset + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset + children: + - type: page + title: response_dataset + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset.mdx + - type: section + title: squad + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/squad + children: + - type: page + title: squad + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/squad + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/squad.mdx + - type: section + title: tulu3 + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3 + children: + - type: page + title: tulu3 + slug: nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3 + pageId: nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3.mdx + - type: section + title: utils + slug: nemo-rl/nemo_rl/data/datasets/utils + children: + - type: page + title: utils + slug: nemo-rl/nemo_rl/data/datasets/utils + pageId: nemo-rl/nemo_rl/data/datasets/utils.mdx + - type: section + title: interfaces + slug: nemo-rl/nemo_rl/data/interfaces + children: + - type: page + title: interfaces + slug: nemo-rl/nemo_rl/data/interfaces + pageId: nemo-rl/nemo_rl/data/interfaces.mdx + - type: section + title: llm_message_utils + slug: nemo-rl/nemo_rl/data/llm_message_utils + children: + - type: page + title: llm_message_utils + slug: nemo-rl/nemo_rl/data/llm_message_utils + pageId: nemo-rl/nemo_rl/data/llm_message_utils.mdx + - type: section + title: multimodal_utils + slug: nemo-rl/nemo_rl/data/multimodal_utils + children: + - type: page + title: multimodal_utils + slug: nemo-rl/nemo_rl/data/multimodal_utils + pageId: nemo-rl/nemo_rl/data/multimodal_utils.mdx + - type: section + title: packing + slug: nemo-rl/nemo_rl/data/packing + children: + - type: section + title: algorithms + slug: nemo-rl/nemo_rl/data/packing/algorithms + children: + - type: page + title: algorithms + slug: nemo-rl/nemo_rl/data/packing/algorithms + pageId: nemo-rl/nemo_rl/data/packing/algorithms.mdx + - type: section + title: metrics + slug: nemo-rl/nemo_rl/data/packing/metrics + children: + - type: page + title: metrics + slug: nemo-rl/nemo_rl/data/packing/metrics + pageId: nemo-rl/nemo_rl/data/packing/metrics.mdx + - type: section + title: processors + slug: nemo-rl/nemo_rl/data/processors + children: + - type: page + title: processors + slug: nemo-rl/nemo_rl/data/processors + pageId: nemo-rl/nemo_rl/data/processors.mdx + - type: section + title: utils + slug: nemo-rl/nemo_rl/data/utils + children: + - type: page + title: utils + slug: nemo-rl/nemo_rl/data/utils + pageId: nemo-rl/nemo_rl/data/utils.mdx +- type: section + title: distributed + slug: nemo-rl/nemo_rl/distributed + children: + - type: section + title: batched_data_dict + slug: nemo-rl/nemo_rl/distributed/batched_data_dict + children: + - type: page + title: batched_data_dict + slug: nemo-rl/nemo_rl/distributed/batched_data_dict + pageId: nemo-rl/nemo_rl/distributed/batched_data_dict.mdx + - type: section + title: collectives + slug: nemo-rl/nemo_rl/distributed/collectives + children: + - type: page + title: collectives + slug: nemo-rl/nemo_rl/distributed/collectives + pageId: nemo-rl/nemo_rl/distributed/collectives.mdx + - type: section + title: model_utils + slug: nemo-rl/nemo_rl/distributed/model_utils + children: + - type: page + title: model_utils + slug: nemo-rl/nemo_rl/distributed/model_utils + pageId: nemo-rl/nemo_rl/distributed/model_utils.mdx + - type: section + title: named_sharding + slug: nemo-rl/nemo_rl/distributed/named_sharding + children: + - type: page + title: named_sharding + slug: nemo-rl/nemo_rl/distributed/named_sharding + pageId: nemo-rl/nemo_rl/distributed/named_sharding.mdx + - type: section + title: ray_actor_environment_registry + slug: nemo-rl/nemo_rl/distributed/ray_actor_environment_registry + children: + - type: page + title: ray_actor_environment_registry + slug: nemo-rl/nemo_rl/distributed/ray_actor_environment_registry + pageId: nemo-rl/nemo_rl/distributed/ray_actor_environment_registry.mdx + - type: section + title: stateless_process_group + slug: nemo-rl/nemo_rl/distributed/stateless_process_group + children: + - type: page + title: stateless_process_group + slug: nemo-rl/nemo_rl/distributed/stateless_process_group + pageId: nemo-rl/nemo_rl/distributed/stateless_process_group.mdx + - type: section + title: virtual_cluster + slug: nemo-rl/nemo_rl/distributed/virtual_cluster + children: + - type: page + title: virtual_cluster + slug: nemo-rl/nemo_rl/distributed/virtual_cluster + pageId: nemo-rl/nemo_rl/distributed/virtual_cluster.mdx + - type: section + title: worker_group_utils + slug: nemo-rl/nemo_rl/distributed/worker_group_utils + children: + - type: page + title: worker_group_utils + slug: nemo-rl/nemo_rl/distributed/worker_group_utils + pageId: nemo-rl/nemo_rl/distributed/worker_group_utils.mdx + - type: section + title: worker_groups + slug: nemo-rl/nemo_rl/distributed/worker_groups + children: + - type: page + title: worker_groups + slug: nemo-rl/nemo_rl/distributed/worker_groups + pageId: nemo-rl/nemo_rl/distributed/worker_groups.mdx +- type: section + title: environments + slug: nemo-rl/nemo_rl/environments + children: + - type: section + title: code_environment + slug: nemo-rl/nemo_rl/environments/code_environment + children: + - type: page + title: code_environment + slug: nemo-rl/nemo_rl/environments/code_environment + pageId: nemo-rl/nemo_rl/environments/code_environment.mdx + - type: section + title: code_jaccard_environment + slug: nemo-rl/nemo_rl/environments/code_jaccard_environment + children: + - type: page + title: code_jaccard_environment + slug: nemo-rl/nemo_rl/environments/code_jaccard_environment + pageId: nemo-rl/nemo_rl/environments/code_jaccard_environment.mdx + - type: section + title: dapo_math_verifier + slug: nemo-rl/nemo_rl/environments/dapo_math_verifier + children: + - type: page + title: dapo_math_verifier + slug: nemo-rl/nemo_rl/environments/dapo_math_verifier + pageId: nemo-rl/nemo_rl/environments/dapo_math_verifier.mdx + - type: section + title: interfaces + slug: nemo-rl/nemo_rl/environments/interfaces + children: + - type: page + title: interfaces + slug: nemo-rl/nemo_rl/environments/interfaces + pageId: nemo-rl/nemo_rl/environments/interfaces.mdx + - type: section + title: math_environment + slug: nemo-rl/nemo_rl/environments/math_environment + children: + - type: page + title: math_environment + slug: nemo-rl/nemo_rl/environments/math_environment + pageId: nemo-rl/nemo_rl/environments/math_environment.mdx + - type: section + title: metrics + slug: nemo-rl/nemo_rl/environments/metrics + children: + - type: page + title: metrics + slug: nemo-rl/nemo_rl/environments/metrics + pageId: nemo-rl/nemo_rl/environments/metrics.mdx + - type: section + title: nemo_gym + slug: nemo-rl/nemo_rl/environments/nemo_gym + children: + - type: page + title: nemo_gym + slug: nemo-rl/nemo_rl/environments/nemo_gym + pageId: nemo-rl/nemo_rl/environments/nemo_gym.mdx + - type: section + title: reward_model_environment + slug: nemo-rl/nemo_rl/environments/reward_model_environment + children: + - type: page + title: reward_model_environment + slug: nemo-rl/nemo_rl/environments/reward_model_environment + pageId: nemo-rl/nemo_rl/environments/reward_model_environment.mdx + - type: section + title: rewards + slug: nemo-rl/nemo_rl/environments/rewards + children: + - type: page + title: rewards + slug: nemo-rl/nemo_rl/environments/rewards + pageId: nemo-rl/nemo_rl/environments/rewards.mdx + - type: section + title: utils + slug: nemo-rl/nemo_rl/environments/utils + children: + - type: page + title: utils + slug: nemo-rl/nemo_rl/environments/utils + pageId: nemo-rl/nemo_rl/environments/utils.mdx + - type: section + title: vlm_environment + slug: nemo-rl/nemo_rl/environments/vlm_environment + children: + - type: page + title: vlm_environment + slug: nemo-rl/nemo_rl/environments/vlm_environment + pageId: nemo-rl/nemo_rl/environments/vlm_environment.mdx +- type: section + title: evals + slug: nemo-rl/nemo_rl/evals + children: + - type: section + title: answer_parsing + slug: nemo-rl/nemo_rl/evals/answer_parsing + children: + - type: page + title: answer_parsing + slug: nemo-rl/nemo_rl/evals/answer_parsing + pageId: nemo-rl/nemo_rl/evals/answer_parsing.mdx + - type: section + title: eval + slug: nemo-rl/nemo_rl/evals/eval + children: + - type: page + title: eval + slug: nemo-rl/nemo_rl/evals/eval + pageId: nemo-rl/nemo_rl/evals/eval.mdx +- type: section + title: experience + slug: nemo-rl/nemo_rl/experience + children: + - type: section + title: rollouts + slug: nemo-rl/nemo_rl/experience/rollouts + children: + - type: page + title: rollouts + slug: nemo-rl/nemo_rl/experience/rollouts + pageId: nemo-rl/nemo_rl/experience/rollouts.mdx +- type: section + title: models + slug: nemo-rl/nemo_rl/models + children: + - type: section + title: automodel + slug: nemo-rl/nemo_rl/models/automodel + children: + - type: section + title: config + slug: nemo-rl/nemo_rl/models/automodel/config + children: + - type: page + title: config + slug: nemo-rl/nemo_rl/models/automodel/config + pageId: nemo-rl/nemo_rl/models/automodel/config.mdx + - type: section + title: data + slug: nemo-rl/nemo_rl/models/automodel/data + children: + - type: page + title: data + slug: nemo-rl/nemo_rl/models/automodel/data + pageId: nemo-rl/nemo_rl/models/automodel/data.mdx + - type: section + title: setup + slug: nemo-rl/nemo_rl/models/automodel/setup + children: + - type: page + title: setup + slug: nemo-rl/nemo_rl/models/automodel/setup + pageId: nemo-rl/nemo_rl/models/automodel/setup.mdx + - type: section + title: train + slug: nemo-rl/nemo_rl/models/automodel/train + children: + - type: page + title: train + slug: nemo-rl/nemo_rl/models/automodel/train + pageId: nemo-rl/nemo_rl/models/automodel/train.mdx + - type: section + title: dtensor + slug: nemo-rl/nemo_rl/models/dtensor + children: + - type: section + title: parallelize + slug: nemo-rl/nemo_rl/models/dtensor/parallelize + children: + - type: page + title: parallelize + slug: nemo-rl/nemo_rl/models/dtensor/parallelize + pageId: nemo-rl/nemo_rl/models/dtensor/parallelize.mdx + - type: section + title: generation + slug: nemo-rl/nemo_rl/models/generation + children: + - type: section + title: interfaces + slug: nemo-rl/nemo_rl/models/generation/interfaces + children: + - type: page + title: interfaces + slug: nemo-rl/nemo_rl/models/generation/interfaces + pageId: nemo-rl/nemo_rl/models/generation/interfaces.mdx + - type: section + title: sglang + slug: nemo-rl/nemo_rl/models/generation/sglang + children: + - type: section + title: config + slug: nemo-rl/nemo_rl/models/generation/sglang/config + children: + - type: page + title: config + slug: nemo-rl/nemo_rl/models/generation/sglang/config + pageId: nemo-rl/nemo_rl/models/generation/sglang/config.mdx + - type: section + title: sglang_copied_utils + slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils + children: + - type: page + title: sglang_copied_utils + slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils + pageId: nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils.mdx + - type: section + title: sglang_generation + slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_generation + children: + - type: page + title: sglang_generation + slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_generation + pageId: nemo-rl/nemo_rl/models/generation/sglang/sglang_generation.mdx + - type: section + title: sglang_worker + slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_worker + children: + - type: page + title: sglang_worker + slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_worker + pageId: nemo-rl/nemo_rl/models/generation/sglang/sglang_worker.mdx + - type: section + title: utils + slug: nemo-rl/nemo_rl/models/generation/sglang/utils + children: + - type: page + title: utils + slug: nemo-rl/nemo_rl/models/generation/sglang/utils + pageId: nemo-rl/nemo_rl/models/generation/sglang/utils.mdx + - type: section + title: vllm + slug: nemo-rl/nemo_rl/models/generation/vllm + children: + - type: section + title: config + slug: nemo-rl/nemo_rl/models/generation/vllm/config + children: + - type: page + title: config + slug: nemo-rl/nemo_rl/models/generation/vllm/config + pageId: nemo-rl/nemo_rl/models/generation/vllm/config.mdx + - type: section + title: utils + slug: nemo-rl/nemo_rl/models/generation/vllm/utils + children: + - type: page + title: utils + slug: nemo-rl/nemo_rl/models/generation/vllm/utils + pageId: nemo-rl/nemo_rl/models/generation/vllm/utils.mdx + - type: section + title: vllm_backend + slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_backend + children: + - type: page + title: vllm_backend + slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_backend + pageId: nemo-rl/nemo_rl/models/generation/vllm/vllm_backend.mdx + - type: section + title: vllm_generation + slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_generation + children: + - type: page + title: vllm_generation + slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_generation + pageId: nemo-rl/nemo_rl/models/generation/vllm/vllm_generation.mdx + - type: section + title: vllm_worker + slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker + children: + - type: page + title: vllm_worker + slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker + pageId: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker.mdx + - type: section + title: vllm_worker_async + slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async + children: + - type: page + title: vllm_worker_async + slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async + pageId: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async.mdx + - type: section + title: huggingface + slug: nemo-rl/nemo_rl/models/huggingface + children: + - type: section + title: common + slug: nemo-rl/nemo_rl/models/huggingface/common + children: + - type: page + title: common + slug: nemo-rl/nemo_rl/models/huggingface/common + pageId: nemo-rl/nemo_rl/models/huggingface/common.mdx + - type: section + title: megatron + slug: nemo-rl/nemo_rl/models/megatron + children: + - type: section + title: common + slug: nemo-rl/nemo_rl/models/megatron/common + children: + - type: page + title: common + slug: nemo-rl/nemo_rl/models/megatron/common + pageId: nemo-rl/nemo_rl/models/megatron/common.mdx + - type: section + title: community_import + slug: nemo-rl/nemo_rl/models/megatron/community_import + children: + - type: page + title: community_import + slug: nemo-rl/nemo_rl/models/megatron/community_import + pageId: nemo-rl/nemo_rl/models/megatron/community_import.mdx + - type: section + title: config + slug: nemo-rl/nemo_rl/models/megatron/config + children: + - type: page + title: config + slug: nemo-rl/nemo_rl/models/megatron/config + pageId: nemo-rl/nemo_rl/models/megatron/config.mdx + - type: section + title: data + slug: nemo-rl/nemo_rl/models/megatron/data + children: + - type: page + title: data + slug: nemo-rl/nemo_rl/models/megatron/data + pageId: nemo-rl/nemo_rl/models/megatron/data.mdx + - type: section + title: pipeline_parallel + slug: nemo-rl/nemo_rl/models/megatron/pipeline_parallel + children: + - type: page + title: pipeline_parallel + slug: nemo-rl/nemo_rl/models/megatron/pipeline_parallel + pageId: nemo-rl/nemo_rl/models/megatron/pipeline_parallel.mdx + - type: section + title: setup + slug: nemo-rl/nemo_rl/models/megatron/setup + children: + - type: page + title: setup + slug: nemo-rl/nemo_rl/models/megatron/setup + pageId: nemo-rl/nemo_rl/models/megatron/setup.mdx + - type: section + title: train + slug: nemo-rl/nemo_rl/models/megatron/train + children: + - type: page + title: train + slug: nemo-rl/nemo_rl/models/megatron/train + pageId: nemo-rl/nemo_rl/models/megatron/train.mdx + - type: section + title: policy + slug: nemo-rl/nemo_rl/models/policy + children: + - type: section + title: interfaces + slug: nemo-rl/nemo_rl/models/policy/interfaces + children: + - type: page + title: interfaces + slug: nemo-rl/nemo_rl/models/policy/interfaces + pageId: nemo-rl/nemo_rl/models/policy/interfaces.mdx + - type: section + title: lm_policy + slug: nemo-rl/nemo_rl/models/policy/lm_policy + children: + - type: page + title: lm_policy + slug: nemo-rl/nemo_rl/models/policy/lm_policy + pageId: nemo-rl/nemo_rl/models/policy/lm_policy.mdx + - type: section + title: utils + slug: nemo-rl/nemo_rl/models/policy/utils + children: + - type: page + title: utils + slug: nemo-rl/nemo_rl/models/policy/utils + pageId: nemo-rl/nemo_rl/models/policy/utils.mdx + - type: section + title: workers + slug: nemo-rl/nemo_rl/models/policy/workers + children: + - type: section + title: base_policy_worker + slug: nemo-rl/nemo_rl/models/policy/workers/base_policy_worker + children: + - type: page + title: base_policy_worker + slug: nemo-rl/nemo_rl/models/policy/workers/base_policy_worker + pageId: nemo-rl/nemo_rl/models/policy/workers/base_policy_worker.mdx + - type: section + title: dtensor_policy_worker + slug: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker + children: + - type: page + title: dtensor_policy_worker + slug: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker + pageId: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker.mdx + - type: section + title: dtensor_policy_worker_v2 + slug: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2 + children: + - type: page + title: dtensor_policy_worker_v2 + slug: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2 + pageId: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.mdx + - type: section + title: megatron_policy_worker + slug: nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker + children: + - type: page + title: megatron_policy_worker + slug: nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker + pageId: nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker.mdx + - type: section + title: patches + slug: nemo-rl/nemo_rl/models/policy/workers/patches + children: + - type: page + title: patches + slug: nemo-rl/nemo_rl/models/policy/workers/patches + pageId: nemo-rl/nemo_rl/models/policy/workers/patches.mdx +- type: section + title: package_info + slug: nemo-rl/nemo_rl/package_info + children: + - type: page + title: package_info + slug: nemo-rl/nemo_rl/package_info + pageId: nemo-rl/nemo_rl/package_info.mdx +- type: section + title: utils + slug: nemo-rl/nemo_rl/utils + children: + - type: section + title: automodel_checkpoint + slug: nemo-rl/nemo_rl/utils/automodel_checkpoint + children: + - type: page + title: automodel_checkpoint + slug: nemo-rl/nemo_rl/utils/automodel_checkpoint + pageId: nemo-rl/nemo_rl/utils/automodel_checkpoint.mdx + - type: section + title: checkpoint + slug: nemo-rl/nemo_rl/utils/checkpoint + children: + - type: page + title: checkpoint + slug: nemo-rl/nemo_rl/utils/checkpoint + pageId: nemo-rl/nemo_rl/utils/checkpoint.mdx + - type: section + title: config + slug: nemo-rl/nemo_rl/utils/config + children: + - type: page + title: config + slug: nemo-rl/nemo_rl/utils/config + pageId: nemo-rl/nemo_rl/utils/config.mdx + - type: section + title: flops_formulas + slug: nemo-rl/nemo_rl/utils/flops_formulas + children: + - type: page + title: flops_formulas + slug: nemo-rl/nemo_rl/utils/flops_formulas + pageId: nemo-rl/nemo_rl/utils/flops_formulas.mdx + - type: section + title: flops_tracker + slug: nemo-rl/nemo_rl/utils/flops_tracker + children: + - type: page + title: flops_tracker + slug: nemo-rl/nemo_rl/utils/flops_tracker + pageId: nemo-rl/nemo_rl/utils/flops_tracker.mdx + - type: section + title: logger + slug: nemo-rl/nemo_rl/utils/logger + children: + - type: page + title: logger + slug: nemo-rl/nemo_rl/utils/logger + pageId: nemo-rl/nemo_rl/utils/logger.mdx + - type: section + title: memory_tracker + slug: nemo-rl/nemo_rl/utils/memory_tracker + children: + - type: page + title: memory_tracker + slug: nemo-rl/nemo_rl/utils/memory_tracker + pageId: nemo-rl/nemo_rl/utils/memory_tracker.mdx + - type: section + title: native_checkpoint + slug: nemo-rl/nemo_rl/utils/native_checkpoint + children: + - type: page + title: native_checkpoint + slug: nemo-rl/nemo_rl/utils/native_checkpoint + pageId: nemo-rl/nemo_rl/utils/native_checkpoint.mdx + - type: section + title: nsys + slug: nemo-rl/nemo_rl/utils/nsys + children: + - type: page + title: nsys + slug: nemo-rl/nemo_rl/utils/nsys + pageId: nemo-rl/nemo_rl/utils/nsys.mdx + - type: section + title: nvml + slug: nemo-rl/nemo_rl/utils/nvml + children: + - type: page + title: nvml + slug: nemo-rl/nemo_rl/utils/nvml + pageId: nemo-rl/nemo_rl/utils/nvml.mdx + - type: section + title: packed_tensor + slug: nemo-rl/nemo_rl/utils/packed_tensor + children: + - type: page + title: packed_tensor + slug: nemo-rl/nemo_rl/utils/packed_tensor + pageId: nemo-rl/nemo_rl/utils/packed_tensor.mdx + - type: section + title: prefetch_venvs + slug: nemo-rl/nemo_rl/utils/prefetch_venvs + children: + - type: page + title: prefetch_venvs + slug: nemo-rl/nemo_rl/utils/prefetch_venvs + pageId: nemo-rl/nemo_rl/utils/prefetch_venvs.mdx + - type: section + title: timer + slug: nemo-rl/nemo_rl/utils/timer + children: + - type: page + title: timer + slug: nemo-rl/nemo_rl/utils/timer + pageId: nemo-rl/nemo_rl/utils/timer.mdx + - type: section + title: venvs + slug: nemo-rl/nemo_rl/utils/venvs + children: + - type: page + title: venvs + slug: nemo-rl/nemo_rl/utils/venvs + pageId: nemo-rl/nemo_rl/utils/venvs.mdx diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl.mdx new file mode 100644 index 0000000..002c19d --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl.mdx @@ -0,0 +1,149 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl +title: nemo_rl +--- + +## Subpackages + +- **[`nemo_rl.algorithms`](/nemo-rl/nemo_rl/algorithms)** +- **[`nemo_rl.data`](/nemo-rl/nemo_rl/data)** +- **[`nemo_rl.distributed`](/nemo-rl/nemo_rl/distributed)** +- **[`nemo_rl.environments`](/nemo-rl/nemo_rl/environments)** +- **[`nemo_rl.evals`](/nemo-rl/nemo_rl/evals)** +- **[`nemo_rl.experience`](/nemo-rl/nemo_rl/experience)** +- **[`nemo_rl.models`](/nemo-rl/nemo_rl/models)** +- **[`nemo_rl.utils`](/nemo-rl/nemo_rl/utils)** + +## Submodules + +- **[`nemo_rl.package_info`](/nemo-rl/nemo_rl/package_info)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_check_container_fingerprint`](#nemo_rl-_check_container_fingerprint) | Check if container dependencies match the current code (container-only). | +| [`_is_build_isolation`](#nemo_rl-_is_build_isolation) | Detect if we're running in a uv build isolation environment. | +| [`_patch_nsight_file`](#nemo_rl-_patch_nsight_file) | Patch the nsight.py file to fix the context.py_executable assignment. | +| [`patch_transformers_module_dir`](#nemo_rl-patch_transformers_module_dir) | - | + +### Data + +[`megatron_path`](#nemo_rl-megatron_path) + +### API + + + + + +```python +nemo_rl._check_container_fingerprint() +``` + + + + + + +Check if container dependencies match the current code (container-only). + +This check only runs when NRL_CONTAINER=1 is set (inside containers). +It compares the container's fingerprint (computed at build time) with +the current code's fingerprint to detect dependency drift. + +This check is also skipped entirely if NRL_FORCE_REBUILD_VENVS=true is set, +since environment rebuilding will ensure dependencies are consistent regardless +of a mismatch. + +If there's a mismatch, raises RuntimeError unless NRL_IGNORE_VERSION_MISMATCH is set. + + + + + + + + +```python +nemo_rl._is_build_isolation() +``` + + + + + + +Detect if we're running in a uv build isolation environment. + +When running uv lock/sync, uv creates a temporary isolated environment +in ~/.cache/uv/builds-v*/ to build packages and introspect metadata. +We skip the fingerprint check in this context since the user is updating dependencies. + +Returns True if in build isolation, False otherwise. + + + + + + + + +```python +nemo_rl._patch_nsight_file() +``` + + + + + + +Patch the nsight.py file to fix the context.py_executable assignment. + +Until this fix is upstreamed, we will maintain this patch here. This patching +logic is only applied if the user intends to use nsys profiling which they enable with +NRL_NSYS_WORKER_PATTERNS. + +If enabled, will effectively apply the following patch in an idempotent manner: + +https://github.com/ray-project/ray/compare/master...terrykong:ray:tk/nsight-py-exeutable-fix?expand=1 + +This hack works b/c the nsight plugin is not called from the main driver process, so +as soon as nemo_rl is imported, the patch is applied and the source of the nsight.py module +is up to date before the nsight.py is actually needed. + + + + + + + + +```python +nemo_rl.patch_transformers_module_dir( + env_vars: dict[str, str] +) +``` + + + + + + + + + + + + + +```python +nemo_rl.megatron_path = Path(__file__).parent.parent / '3rdparty' / 'Megatron-LM-workspace' / 'Megatron-... +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms.mdx new file mode 100644 index 0000000..7f03746 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms.mdx @@ -0,0 +1,19 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms +title: nemo_rl.algorithms +--- + +## Submodules + +- **[`nemo_rl.algorithms.advantage_estimator`](/nemo-rl/nemo_rl/algorithms/advantage_estimator)** +- **[`nemo_rl.algorithms.async_utils`](/nemo-rl/nemo_rl/algorithms/async_utils)** +- **[`nemo_rl.algorithms.distillation`](/nemo-rl/nemo_rl/algorithms/distillation)** +- **[`nemo_rl.algorithms.dpo`](/nemo-rl/nemo_rl/algorithms/dpo)** +- **[`nemo_rl.algorithms.grpo`](/nemo-rl/nemo_rl/algorithms/grpo)** +- **[`nemo_rl.algorithms.interfaces`](/nemo-rl/nemo_rl/algorithms/interfaces)** +- **[`nemo_rl.algorithms.loss_functions`](/nemo-rl/nemo_rl/algorithms/loss_functions)** +- **[`nemo_rl.algorithms.reward_functions`](/nemo-rl/nemo_rl/algorithms/reward_functions)** +- **[`nemo_rl.algorithms.rm`](/nemo-rl/nemo_rl/algorithms/rm)** +- **[`nemo_rl.algorithms.sft`](/nemo-rl/nemo_rl/algorithms/sft)** +- **[`nemo_rl.algorithms.utils`](/nemo-rl/nemo_rl/algorithms/utils)** diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/advantage_estimator.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/advantage_estimator.mdx new file mode 100644 index 0000000..0841909 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/advantage_estimator.mdx @@ -0,0 +1,196 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/advantage_estimator +title: nemo_rl.algorithms.advantage_estimator +--- + +Advantage Estimators for RL algorithms. + +This module provides different advantage estimation strategies: +- GRPOAdvantageEstimator: Standard GRPO advantage with leave-one-out baseline +- ReinforcePlusPlusAdvantageEstimator: Reinforce++ with optional baseline subtraction (minus_baseline) and KL penalty in reward +Reference papers: +- ProRLv2: https://developer.nvidia.com/blog/scaling-llm-reinforcement-learning-with-prolonged-training-using-prorl-v2/ +- Reinforce++: https://arxiv.org/abs/2501.03262 + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`GRPOAdvantageEstimator`](#nemo_rl-algorithms-advantage_estimator-GRPOAdvantageEstimator) | GRPO-style advantage estimator with leave-one-out baseline. | +| [`ReinforcePlusPlusAdvantageEstimator`](#nemo_rl-algorithms-advantage_estimator-ReinforcePlusPlusAdvantageEstimator) | Reinforce++ advantage estimator with optional baseline subtraction and KL penalty in reward. | + +### API + + + + + +```python +class nemo_rl.algorithms.advantage_estimator.GRPOAdvantageEstimator( + estimator_config: dict, + loss_config: dict +) +``` + + + + + + +GRPO-style advantage estimator with leave-one-out baseline. + +Note: GRPO computes advantages over all responses for each prompt. + + + + + + + + + + + +```python +nemo_rl.algorithms.advantage_estimator.GRPOAdvantageEstimator.compute_advantage( + prompt_ids, + rewards, + mask, + kwargs = {} +) +``` + + + + + + +Compute GRPO advantages. + +**Parameters:** + + +Tensor of shape [batch_size] identifying which prompt each sample belongs to. + + + +Tensor of shape [batch_size] containing reward for each sample. + + + +Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding. + Used only for expanding advantages to token-level shape. + + + +Additional arguments (unused). + + +**Returns:** + +Advantages tensor of shape [batch_size, seq_len]. + + + + + + + + + +```python +class nemo_rl.algorithms.advantage_estimator.ReinforcePlusPlusAdvantageEstimator( + estimator_config: dict, + loss_config: dict +) +``` + + + + + + +Reinforce++ advantage estimator with optional baseline subtraction and KL penalty in reward. + +**Parameters:** + + +If True, subtract per-prompt mean baseline from rewards. + + + +If True, add KL penalty to reward instead of loss. + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.advantage_estimator.ReinforcePlusPlusAdvantageEstimator.compute_advantage( + prompt_ids, + rewards, + mask, + logprobs_policy = None, + logprobs_reference = None, + kwargs = {} +) +``` + + + + + + +Compute Reinforce++ advantages with optional KL penalty. + +**Parameters:** + + +Tensor of shape [batch_size] identifying which prompt each sample belongs to. + + + +Tensor of shape [batch_size] containing reward for each sample. + + + +Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding. + Used for: (1) expanding advantages to token-level shape, (2) global normalization + that only considers valid tokens. + + + +Policy log probabilities of shape [batch_size, seq_len], required if use_kl_in_reward. + + + +Reference policy log probabilities of shape [batch_size, seq_len], required if use_kl_in_reward. + + + +Additional arguments (unused). + + +**Returns:** + +Advantages tensor of shape [batch_size, seq_len], globally normalized across valid tokens. + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/async_utils.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/async_utils.mdx new file mode 100644 index 0000000..f9ad506 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/async_utils.mdx @@ -0,0 +1,572 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/async_utils +title: nemo_rl.algorithms.async_utils +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AsyncTrajectoryCollector`](#nemo_rl-algorithms-async_utils-AsyncTrajectoryCollector) | Collects trajectories asynchronously and adds them to replay buffer. | +| [`ReplayBuffer`](#nemo_rl-algorithms-async_utils-ReplayBuffer) | Replay buffer storing per-prompt groups. | + +### Data + +[`TokenizerType`](#nemo_rl-algorithms-async_utils-TokenizerType) + +### API + + + + + +```python +class nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector( + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + tokenizer: nemo_rl.algorithms.async_utils.TokenizerType, + task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], + master_config: nemo_rl.algorithms.grpo.MasterConfig, + replay_buffer: typing.Any, + start_step: int = 0 +) +``` + + + + + + +Collects trajectories asynchronously and adds them to replay buffer. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._calculate_target_weights( + generation_weight_version: int +) -> list[int] +``` + + + + + + +Calculate target weight versions for given generation weight version. + +The list of versions returned enumerate the possible version a generation +server can target. These versions are looped over to see what training +step they can target. If all target versions are exhausted, this generation +server will remain idle until the next weight update. + +Example: +generation_weight_version = 10 +max_trajectory_age_steps = 4 + +**Returns:** `list[int]` + +[11, 12, 13, 14] # Meaning this generation server can create trajectories for training step 11, 12, 13, 14 + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._cleanup_finished_threads() -> None +``` + + + + + + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._collection_loop() +``` + + + + + + +Run the collection loop in background thread. + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._get_next_target_for_generation( + generation_weight_version: int +) -> typing.Optional[int] +``` + + + + + + +Get the next target weight that needs generation (if any). + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._process_batch( + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec] +) -> None +``` + + + + + + +Process a single batch and generate for one target weight. + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._run_prompt_group_worker( + repeated_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], + generation_weight_version: int, + target_weight_version: int, + prompt_idx: int +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector._should_pause_for_generation_limits() -> bool +``` + + + + + + +Check if collection should be paused due to generation limits. + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.get_dataloader_state() -> dict +``` + + + + + + +Get the current dataloader state for checkpointing. + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.get_weight_version() -> int +``` + + + + + + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.pause() -> None +``` + + + + + + +Pause trajectory collection. + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.prepare_for_refit() -> None +``` + + + + + + +Pause new generation starts and optionally wait for pending generations. + +For vLLM V1 async engine, leverages in-flight weight updates via collective_rpc, +allowing ongoing generations to continue with their current KV caches while +weights are updated. This significantly improves async performance. + +For non-async engines, waits for all pending generations to complete before refit. + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.resume() -> None +``` + + + + + + +Resume trajectory collection. + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.resume_after_refit() -> None +``` + + + + + + +Resume new generation starts after refit is complete. + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.set_weight_version( + version: int +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.start_collection( + dataloader: torchdata.stateful_dataloader.StatefulDataLoader +) -> None +``` + + + + + + +Start collecting trajectories from dataloader. + + + + + + + +```python +nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector.wait_for_pending_generations() -> None +``` + + + + + + +Wait for all in-flight generation threads to complete. + + + + + + + + + +```python +class nemo_rl.algorithms.async_utils.ReplayBuffer( + max_size: int +) +``` + + + + + + +Replay buffer storing per-prompt groups. + +A single entry corresponds to 1 prompt repeated by +grpo.num_generations_per_prompt (required to compute per-prompt advantages). + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.async_utils.ReplayBuffer.clear() -> None +``` + + + + + + +Clear the buffer. + + + + + + + +```python +nemo_rl.algorithms.async_utils.ReplayBuffer.get_debug_info() -> dict +``` + + + + + + +Get debug information about buffer state. + + + + + + + +```python +nemo_rl.algorithms.async_utils.ReplayBuffer.get_existing_target_weights() -> set[int] +``` + + + + + + +Get set of target weight versions that already have trajectories. + + + + + + + +```python +nemo_rl.algorithms.async_utils.ReplayBuffer.get_last_target_weight_already_generated() -> int +``` + + + + + + + + + + + + +```python +nemo_rl.algorithms.async_utils.ReplayBuffer.push_with_wait_signal( + trajectory: dict[str, typing.Any], + weight_version: int, + target_weight_version: int +) -> str +``` + + + + + + +Add a per-prompt trajectory group with metadata. + +**Parameters:** + + +data dict + + + +version of the model weights used for generation + + + +version of the model weights this trajectory is intended for training + + + + + + + + +```python +nemo_rl.algorithms.async_utils.ReplayBuffer.sample( + num_prompt_groups: int, + current_weight_version: int, + max_age_steps: int +) -> typing.Optional[dict[str, typing.Any]] +``` + + + + + + +Sample per-prompt trajectory groups intended for the current training step. + +Only returns trajectories with target_weight_version == current_weight_version. +If insufficient trajectories are available, returns None to stall training +until the remaining trajectories are generated. This ensures no trajectory +loses its last chance to be used for its intended training step. + +**Returns:** `Optional[dict[str, Any]]` + +Dictionary with 'trajectories' and 'avg_trajectory_age' keys, or None if insufficient data + + + + + + + +```python +nemo_rl.algorithms.async_utils.ReplayBuffer.size() -> int +``` + + + + + + +Return current buffer size. + + + + + + + + + +```python +nemo_rl.algorithms.async_utils.TokenizerType = PreTrainedTokenizerBase +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/distillation.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/distillation.mdx new file mode 100644 index 0000000..2dede47 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/distillation.mdx @@ -0,0 +1,326 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/distillation +title: nemo_rl.algorithms.distillation +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`DistillationConfig`](#nemo_rl-algorithms-distillation-DistillationConfig) | - | +| [`DistillationSaveState`](#nemo_rl-algorithms-distillation-DistillationSaveState) | - | +| [`MasterConfig`](#nemo_rl-algorithms-distillation-MasterConfig) | Main configuration structure. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_default_distillation_save_state`](#nemo_rl-algorithms-distillation-_default_distillation_save_state) | - | +| [`check_vocab_equality`](#nemo_rl-algorithms-distillation-check_vocab_equality) | Check if the vocab of the tokenizer (student) and the teacher tokenizer are equal. | +| [`distillation_train`](#nemo_rl-algorithms-distillation-distillation_train) | Run Distillation training algorithm. | +| [`setup`](#nemo_rl-algorithms-distillation-setup) | Main entry point for distillation algorithm. | +| [`validate`](#nemo_rl-algorithms-distillation-validate) | Run validation on the validation dataset. | + +### Data + +[`TokenizerType`](#nemo_rl-algorithms-distillation-TokenizerType) + +### API + + + + + +```python +class nemo_rl.algorithms.distillation.DistillationConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.distillation.DistillationSaveState +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.distillation.MasterConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Main configuration structure. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.distillation._default_distillation_save_state() -> nemo_rl.algorithms.distillation.DistillationSaveState +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.distillation.check_vocab_equality( + tokenizer: nemo_rl.algorithms.distillation.TokenizerType, + student_model_name: str, + teacher_model_name: str +) -> None +``` + + + + + + +Check if the vocab of the tokenizer (student) and the teacher tokenizer are equal. + + + + + + + + +```python +nemo_rl.algorithms.distillation.distillation_train( + student_policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, + teacher_policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, + student_generation: typing.Optional[nemo_rl.models.generation.interfaces.GenerationInterface], + dataloader: torchdata.stateful_dataloader.StatefulDataLoader, + val_dataloader: typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], + tokenizer: nemo_rl.algorithms.distillation.TokenizerType, + loss_fn: nemo_rl.algorithms.loss_functions.DistillationLossFn, + task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], + val_task_to_env: typing.Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]], + logger: nemo_rl.utils.logger.Logger, + checkpointer: nemo_rl.utils.checkpoint.CheckpointManager, + distillation_save_state: nemo_rl.algorithms.distillation.DistillationSaveState, + master_config: nemo_rl.algorithms.distillation.MasterConfig +) -> None +``` + + + + + + +Run Distillation training algorithm. + + + + + + + + +```python +nemo_rl.algorithms.distillation.setup( + master_config: nemo_rl.algorithms.distillation.MasterConfig, + tokenizer: nemo_rl.algorithms.distillation.TokenizerType, + train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset, + val_dataset: typing.Optional[nemo_rl.data.datasets.AllTaskProcessedDataset] +) -> tuple[nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, typing.Optional[nemo_rl.models.generation.interfaces.GenerationInterface], torchdata.stateful_dataloader.StatefulDataLoader, typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.DistillationLossFn, nemo_rl.utils.logger.Logger, nemo_rl.utils.checkpoint.CheckpointManager, nemo_rl.algorithms.distillation.DistillationSaveState, nemo_rl.algorithms.distillation.MasterConfig] +``` + + + + + + +Main entry point for distillation algorithm. + +**Returns:** `ColocatablePolicyInterface` + +tuple of student_policy, teacher_policy, student_generation, + + + + + + + + +```python +nemo_rl.algorithms.distillation.validate( + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + val_dataloader: typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], + tokenizer, + val_task_to_env: typing.Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]], + step: int, + master_config: nemo_rl.algorithms.distillation.MasterConfig +) -> tuple[dict[str, typing.Any], dict[str, typing.Any]] +``` + + + + + + +Run validation on the validation dataset. + + + + + + + + +```python +nemo_rl.algorithms.distillation.TokenizerType = TypeVar('TokenizerType', bound=PreTrainedTokenizerBase) +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/dpo.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/dpo.mdx new file mode 100644 index 0000000..3d57261 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/dpo.mdx @@ -0,0 +1,378 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/dpo +title: nemo_rl.algorithms.dpo +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`DPOConfig`](#nemo_rl-algorithms-dpo-DPOConfig) | - | +| [`DPOSaveState`](#nemo_rl-algorithms-dpo-DPOSaveState) | - | +| [`DPOValMetrics`](#nemo_rl-algorithms-dpo-DPOValMetrics) | - | +| [`MasterConfig`](#nemo_rl-algorithms-dpo-MasterConfig) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_default_dpo_save_state`](#nemo_rl-algorithms-dpo-_default_dpo_save_state) | - | +| [`add_ref_logprobs_to_data`](#nemo_rl-algorithms-dpo-add_ref_logprobs_to_data) | - | +| [`dpo_train`](#nemo_rl-algorithms-dpo-dpo_train) | - | +| [`setup`](#nemo_rl-algorithms-dpo-setup) | Main entry point for running DPO algorithm. | +| [`validate`](#nemo_rl-algorithms-dpo-validate) | - | +| [`validate_one_dataset`](#nemo_rl-algorithms-dpo-validate_one_dataset) | Run validation on one validation dataset. | + +### API + + + + + +```python +class nemo_rl.algorithms.dpo.DPOConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.dpo.DPOSaveState +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.dpo.DPOValMetrics +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.dpo.MasterConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.dpo._default_dpo_save_state() -> nemo_rl.algorithms.dpo.DPOSaveState +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.dpo.add_ref_logprobs_to_data( + dataloader, + policy, + master_config, + is_val = False +) +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.dpo.dpo_train( + policy, + train_dataloader, + val_dataloader, + tokenizer, + loss_fn, + master_config, + logger, + checkpointer, + dpo_save_state: nemo_rl.algorithms.dpo.DPOSaveState +) -> None +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.dpo.setup( + master_config: nemo_rl.algorithms.dpo.MasterConfig, + tokenizer: transformers.AutoTokenizer, + train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset, + val_dataset: dict[str, nemo_rl.data.datasets.AllTaskProcessedDataset] +) -> tuple[nemo_rl.models.policy.lm_policy.Policy, nemo_rl.distributed.virtual_cluster.RayVirtualCluster, torchdata.stateful_dataloader.StatefulDataLoader, dict[str, torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.DPOLossFn, nemo_rl.utils.logger.Logger, nemo_rl.utils.checkpoint.CheckpointManager, nemo_rl.algorithms.dpo.DPOSaveState, nemo_rl.algorithms.dpo.MasterConfig] +``` + + + + + + +Main entry point for running DPO algorithm. + +**Returns:** `tuple[Policy, RayVirtualCluster, StatefulDataLoader, dict[str, StatefulDataLoader], DPOLossFn, Logger, CheckpointManager, DPOSaveState, MasterConfig]` + +Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, master_config, logger + + + + + + + + +```python +nemo_rl.algorithms.dpo.validate( + policy: nemo_rl.models.policy.interfaces.PolicyInterface, + val_dataloader: dict[str, torchdata.stateful_dataloader.StatefulDataLoader], + tokenizer, + loss_fn, + step: int, + master_config: nemo_rl.algorithms.dpo.MasterConfig, + val_batches: int, + val_batch_size: int, + val_mbs: int, + logger: nemo_rl.utils.logger.Logger +) +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.dpo.validate_one_dataset( + policy: nemo_rl.models.policy.interfaces.PolicyInterface, + val_dataloader: torchdata.stateful_dataloader.StatefulDataLoader, + loss_fn, + step: int, + master_config: nemo_rl.algorithms.dpo.MasterConfig, + val_batches: int, + val_batch_size: int, + val_mbs: int, + dataset_name: str +) +``` + + + + + + +Run validation on one validation dataset. + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/grpo.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/grpo.mdx new file mode 100644 index 0000000..e75524c --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/grpo.mdx @@ -0,0 +1,916 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/grpo +title: nemo_rl.algorithms.grpo +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AdvEstimatorConfig`](#nemo_rl-algorithms-grpo-AdvEstimatorConfig) | Configuration for advantage estimator (GRPO or Reinforce++). | +| [`AsyncGRPOConfig`](#nemo_rl-algorithms-grpo-AsyncGRPOConfig) | - | +| [`GRPOConfig`](#nemo_rl-algorithms-grpo-GRPOConfig) | - | +| [`GRPOLoggerConfig`](#nemo_rl-algorithms-grpo-GRPOLoggerConfig) | - | +| [`GRPOSaveState`](#nemo_rl-algorithms-grpo-GRPOSaveState) | - | +| [`MasterConfig`](#nemo_rl-algorithms-grpo-MasterConfig) | - | +| [`RewardScalingConfig`](#nemo_rl-algorithms-grpo-RewardScalingConfig) | Configure linear reward scaling with clamping. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_create_advantage_estimator`](#nemo_rl-algorithms-grpo-_create_advantage_estimator) | Create and return an advantage estimator based on configuration. | +| [`_default_grpo_save_state`](#nemo_rl-algorithms-grpo-_default_grpo_save_state) | - | +| [`_extract_prompt_only_messages`](#nemo_rl-algorithms-grpo-_extract_prompt_only_messages) | Extract only prompt messages (user/system) from message logs. | +| [`_log_mixed_rewards_and_advantages_information`](#nemo_rl-algorithms-grpo-_log_mixed_rewards_and_advantages_information) | - | +| [`_should_log_nemo_gym_responses`](#nemo_rl-algorithms-grpo-_should_log_nemo_gym_responses) | - | +| [`_should_use_async_rollouts`](#nemo_rl-algorithms-grpo-_should_use_async_rollouts) | Determine if async rollouts should be used based on the configuration. | +| [`_should_use_nemo_gym`](#nemo_rl-algorithms-grpo-_should_use_nemo_gym) | Determine if NeMo-Gym should be used for rollouts and validation based on the configuration. | +| [`async_grpo_train`](#nemo_rl-algorithms-grpo-async_grpo_train) | Run asynchronous GRPO training with replay buffer. | +| [`compute_and_apply_seq_logprob_error_masking`](#nemo_rl-algorithms-grpo-compute_and_apply_seq_logprob_error_masking) | Compute sequence-level logprob error metrics and optionally mask high-error sequences. | +| [`dynamic_sampling`](#nemo_rl-algorithms-grpo-dynamic_sampling) | Implements the dynamic sampling algorithm to select prompts with non-zero standard deviation. | +| [`grpo_train`](#nemo_rl-algorithms-grpo-grpo_train) | Run GRPO training algorithm. | +| [`refit_policy_generation`](#nemo_rl-algorithms-grpo-refit_policy_generation) | Refit the policy generation interface with the latest policy weights. | +| [`scale_rewards`](#nemo_rl-algorithms-grpo-scale_rewards) | Linearly scales rewards from a source range to a target range. | +| [`setup`](#nemo_rl-algorithms-grpo-setup) | Main entry point for running GRPO algorithm. | +| [`validate`](#nemo_rl-algorithms-grpo-validate) | Run validation on the validation dataset. | + +### Data + +[`TokenizerType`](#nemo_rl-algorithms-grpo-TokenizerType) + +### API + + + + + +```python +class nemo_rl.algorithms.grpo.AdvEstimatorConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Configuration for advantage estimator (GRPO or Reinforce++). + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.grpo.AsyncGRPOConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.grpo.GRPOConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.grpo.GRPOLoggerConfig() +``` + + + + + + +**Bases:** [LoggerConfig](/nemo-rl/nemo_rl/utils/logger#nemo_rl-utils-logger-LoggerConfig) + + + + + + + + + +```python +class nemo_rl.algorithms.grpo.GRPOSaveState +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.grpo.MasterConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.grpo.RewardScalingConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Configure linear reward scaling with clamping. + +When `enabled` is True, each reward is clamped to the source interval +[source_min, source_max] and linearly mapped to the target interval +[target_min, target_max]. Refer to the scale_rewards function for the implementation. + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.grpo._create_advantage_estimator( + master_config: nemo_rl.algorithms.grpo.MasterConfig +) +``` + + + + + + +Create and return an advantage estimator based on configuration. + +**Parameters:** + + +The master configuration dictionary. + + +**Returns:** + +An advantage estimator instance (GRPOAdvantageEstimator or ReinforcePlusPlusAdvantageEstimator). + +**Raises:** + +- `ValueError`: If the advantage estimator name is not recognized. + + + + + + + + +```python +nemo_rl.algorithms.grpo._default_grpo_save_state() -> nemo_rl.algorithms.grpo.GRPOSaveState +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.grpo._extract_prompt_only_messages( + message_logs: list +) -> list +``` + + + + + + +Extract only prompt messages (user/system) from message logs. + +This is used to get prompt IDs for advantage estimation, excluding +any assistant responses. + +**Parameters:** + + +List of message logs, where each log is a list of messages. + + +**Returns:** `list` + +List of message logs containing only user and system messages. + + + + + + + + +```python +nemo_rl.algorithms.grpo._log_mixed_rewards_and_advantages_information( + logger: nemo_rl.utils.logger.Logger, + total_steps: int, + metrics: dict[str, typing.Any], + baseline: torch.Tensor, + advantages: torch.Tensor +) -> None +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.grpo._should_log_nemo_gym_responses( + master_config: nemo_rl.algorithms.grpo.MasterConfig +) -> bool +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.grpo._should_use_async_rollouts( + master_config: nemo_rl.algorithms.grpo.MasterConfig +) -> bool +``` + + + + + + +Determine if async rollouts should be used based on the configuration. + +Returns True if vLLM backend is used with async_engine enabled. + + + + + + + + +```python +nemo_rl.algorithms.grpo._should_use_nemo_gym( + master_config: nemo_rl.algorithms.grpo.MasterConfig +) -> bool +``` + + + + + + +Determine if NeMo-Gym should be used for rollouts and validation based on the configuration. + + + + + + + + +```python +nemo_rl.algorithms.grpo.async_grpo_train( + policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, + policy_generation: typing.Optional[nemo_rl.models.generation.interfaces.GenerationInterface], + dataloader: torchdata.stateful_dataloader.StatefulDataLoader, + val_dataloader: typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], + tokenizer: nemo_rl.algorithms.grpo.TokenizerType, + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], + val_task_to_env: typing.Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]], + logger: nemo_rl.utils.logger.Logger, + checkpointer: nemo_rl.utils.checkpoint.CheckpointManager, + grpo_save_state: nemo_rl.algorithms.grpo.GRPOSaveState, + master_config: nemo_rl.algorithms.grpo.MasterConfig, + max_trajectory_age_steps: int = 1 +) -> None +``` + + + + + + +Run asynchronous GRPO training with replay buffer. + +**Parameters:** + + +Training policy + + + +Generation interface + + + +Training data loader + + + +Validation data loader + + + +Tokenizer + + + +Loss function + + + +Training environments + + + +Validation environments + + + +Logger + + + +Checkpoint manager + + + +Training state + + + +Master configuration + + + +Maximum age (in training steps) for trajectories to be used in training + + + + + + + + + +```python +nemo_rl.algorithms.grpo.compute_and_apply_seq_logprob_error_masking( + train_data: nemo_rl.distributed.batched_data_dict.BatchedDataDict, + rewards: torch.Tensor, + seq_logprob_error_threshold: typing.Optional[float] +) -> tuple[float, int, float] +``` + + + + + + +Compute sequence-level logprob error metrics and optionally mask high-error sequences. + +This function computes the multiplicative probability error per sequence +(same calculation as token_mult_prob_error but aggregated per-sequence) and +optionally masks sequences that exceed the configured threshold. + +**Parameters:** + + +Training data dict containing token_mask, sample_mask, + prev_logprobs, and generation_logprobs. If masking is applied, + sample_mask will be updated in-place. + + + +Reward tensor for computing statistics on masked sequences. + + + +If set, mask sequences with mult_prob_error + exceeding this threshold. If None, only compute metrics. + + +**Returns:** `tuple[float, int, float]` + +Tuple of (max_seq_mult_prob_error, num_masked_seqs, masked_correct_pct) + + + + + + + + +```python +nemo_rl.algorithms.grpo.dynamic_sampling( + repeated_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], + std: torch.Tensor, + baseline: torch.Tensor, + dynamic_sampling_num_gen_batches: int, + master_config: nemo_rl.algorithms.grpo.MasterConfig, + timer: nemo_rl.utils.timer.Timer, + batch_cache: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec] +``` + + + + + + +Implements the dynamic sampling algorithm to select prompts with non-zero standard deviation. + +This function filters the current batch to retain only those prompts that have a non-zero standard deviation. +If the current batch has fewer number of prompts with non-zero standard deviation than the required batch size, defined as num_prompts_per_step * num_generations_per_prompt, +we store it in the batch_cache to be used in later iterations. +If the current batch has more number of prompts with non-zero standard deviation than the required batch size, defined as num_prompts_per_step * num_generations_per_prompt, +the batch is sliced to ensure batch size is num_prompts_per_step * num_generations_per_prompt. +is_batch_complete is set to False to indicate that the current batch is not enough to meet the required batch size. This is used as a signal in the GRPO training loop +to continue sampling or proceed to training. +This approach is based on the dynamic sampling algorithm from the DAPO paper: +https://arxiv.org/pdf/2503.14476. + +**Parameters:** + + +The current batch of data containing prompts, responses, rewards, baselines, and std. + + + +Tensor representing the standard deviation for each prompt group. + + + +Baseline values for each prompt group. + + + +Number of generation batches processed at the current step. + + + +Configuration containing GRPO and policy settings. + + + +Cache storing previously selected prompts with non-zero std. + + +**Returns:** `BatchedDataDict[DatumSpec]` + +A tuple containing: +- repeated_batch (BatchedDataDict[DatumSpec]): Updated batch with selected prompts. +- is_batch_complete (bool): Indicates if the batch has enough samples with non-zero std for training. +- batch_cache (BatchedDataDict[DatumSpec]): Updated cache for future iterations. + + + + + + + + +```python +nemo_rl.algorithms.grpo.grpo_train( + policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, + policy_generation: typing.Optional[nemo_rl.models.generation.interfaces.GenerationInterface], + dataloader: torchdata.stateful_dataloader.StatefulDataLoader, + val_dataloader: typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], + tokenizer: nemo_rl.algorithms.grpo.TokenizerType, + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], + val_task_to_env: typing.Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]], + logger: nemo_rl.utils.logger.Logger, + checkpointer: nemo_rl.utils.checkpoint.CheckpointManager, + grpo_save_state: nemo_rl.algorithms.grpo.GRPOSaveState, + master_config: nemo_rl.algorithms.grpo.MasterConfig +) -> None +``` + + + + + + +Run GRPO training algorithm. + + + + + + + + +```python +nemo_rl.algorithms.grpo.refit_policy_generation( + policy: nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + colocated_inference: bool, + _refit_buffer_size_gb: typing.Optional[int] = None, + timer: typing.Optional[nemo_rl.utils.timer.Timer] = None, + kv_scales: typing.Optional[dict[str, float]] = None +) -> None +``` + + + + + + +Refit the policy generation interface with the latest policy weights. + +**Parameters:** + + +The policy to provide weights to the inference engine. + + + +The inference engine to refit. + + + +The size of the buffer to use for refitting. +If it is None, the buffer size will be computed by the remaining memory. +This parameter is primarily used for testing. + + + +Optional Timer used to time the prepare/transfer/update phase + + + +Optional dictionary of KV cache scales for FP8 quantization. + + + + + + + + + +```python +nemo_rl.algorithms.grpo.scale_rewards( + repeated_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], + reward_scaling_cfg: nemo_rl.algorithms.grpo.RewardScalingConfig +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec] +``` + + + + + + +Linearly scales rewards from a source range to a target range. + +If `reward_scaling.enabled` is True, each reward in `repeated_batch["total_reward"]` +is clamped to the configured source interval [source_min, source_max] and then +rescaled to the target interval [target_min, target_max]. + + + + + + + + +```python +nemo_rl.algorithms.grpo.setup( + master_config: nemo_rl.algorithms.grpo.MasterConfig, + tokenizer: nemo_rl.algorithms.grpo.TokenizerType, + dataset: nemo_rl.data.datasets.AllTaskProcessedDataset, + val_dataset: typing.Optional[nemo_rl.data.datasets.AllTaskProcessedDataset], + processor: typing.Optional[transformers.AutoProcessor] = None +) -> tuple[nemo_rl.models.policy.interfaces.ColocatablePolicyInterface, typing.Optional[nemo_rl.models.generation.interfaces.GenerationInterface], tuple[nemo_rl.distributed.virtual_cluster.RayVirtualCluster, nemo_rl.distributed.virtual_cluster.RayVirtualCluster], torchdata.stateful_dataloader.StatefulDataLoader, typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.ClippedPGLossFn, nemo_rl.utils.logger.Logger, nemo_rl.utils.checkpoint.CheckpointManager, nemo_rl.algorithms.grpo.GRPOSaveState, nemo_rl.algorithms.grpo.MasterConfig] +``` + + + + + + +Main entry point for running GRPO algorithm. + +**Returns:** `tuple[ColocatablePolicyInterface, Optional[GenerationInterface], tuple[RayVirtualCluster, RayVirtualCluster], StatefulDataLoader, Optional[StatefulDataLoader], ClippedPGLossFn, Logger, CheckpointManager, GRPOSaveState, MasterConfig]` + +tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, logger, master_config, val_dataloader + + + + + + + + +```python +nemo_rl.algorithms.grpo.validate( + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + val_dataloader: typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], + tokenizer, + val_task_to_env: typing.Optional[dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]], + step: int, + master_config: nemo_rl.algorithms.grpo.MasterConfig, + logger: typing.Optional[nemo_rl.utils.logger.Logger] = None +) -> tuple[dict[str, typing.Any], dict[str, typing.Any]] +``` + + + + + + +Run validation on the validation dataset. + + + + + + + + +```python +nemo_rl.algorithms.grpo.TokenizerType = TypeVar('TokenizerType', bound=PreTrainedTokenizerBase) +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/interfaces.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/interfaces.mdx new file mode 100644 index 0000000..7976052 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/interfaces.mdx @@ -0,0 +1,123 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/interfaces +title: nemo_rl.algorithms.interfaces +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`LossFunction`](#nemo_rl-algorithms-interfaces-LossFunction) | Signature for loss functions used in reinforcement learning algorithms. | +| [`LossType`](#nemo_rl-algorithms-interfaces-LossType) | - | + +### API + + + + + +```python +class nemo_rl.algorithms.interfaces.LossFunction() +``` + + + + + + +Protocol + +Signature for loss functions used in reinforcement learning algorithms. + +Loss functions compute a scalar loss value and associated metrics from +model logprobs and other data contained in a BatchedDataDict. + + + + + + + + +```python +nemo_rl.algorithms.interfaces.LossFunction.__call__( + next_token_logits: torch.Tensor, + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict, + global_valid_seqs: torch.Tensor, + global_valid_toks: torch.Tensor +) -> tuple[torch.Tensor, dict[str, typing.Any]] +``` + + + + + + +Compute loss and metrics from logprobs and other data. + +**Parameters:** + + +Logits from the model, typically with shape [batch_size, seq_len, vocab_size]. + For each position (b, i), contains the logit distribution over the entire vocabulary + for predicting the next token (at position i+1). For example, if processing "The cat sat on", + then next_token_logits[b, 3] would contain the logits for predicting the word + that follows "on". + + + +Dictionary containing all relevant data for loss computation + such as rewards, values, actions, advantages, masks, and other + algorithm-specific information needed for the particular loss calculation. + + + +torch.Tensor +this tensor should contain the number of valid sequences in the microbatch. +It's used for global normalization for losses/metrics that are computed at the sequence level +and needs to be aggregated across all microbatches. + + + +torch.Tensor +This tensor should contain the number of valid tokens in the microbatch. +It's used for global normalization for losses/metrics that are computed at the token level +and needs to be aggregated across all microbatches. + + +**Returns:** `tuple[torch.Tensor, dict[str, Any]]` + +(loss, metrics) +- loss: A scalar tensor representing the loss value to be minimized during training +- metrics: A dictionary of metrics related to the loss computation, which may include + component losses, statistics about gradients/rewards, and other diagnostic information + + + + + + + + + +```python +class nemo_rl.algorithms.interfaces.LossType +``` + + + + + + +**Bases:** `enum.Enum` + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/loss_functions.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/loss_functions.mdx new file mode 100644 index 0000000..f8307d1 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/loss_functions.mdx @@ -0,0 +1,875 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/loss_functions +title: nemo_rl.algorithms.loss_functions +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ClippedPGLossConfig`](#nemo_rl-algorithms-loss_functions-ClippedPGLossConfig) | - | +| [`ClippedPGLossDataDict`](#nemo_rl-algorithms-loss_functions-ClippedPGLossDataDict) | Required keys for the Clipped Policy Gradient loss function. | +| [`ClippedPGLossFn`](#nemo_rl-algorithms-loss_functions-ClippedPGLossFn) | Generalized Clipped Policy Gradient loss function w/ KL regularization. | +| [`DPOLossConfig`](#nemo_rl-algorithms-loss_functions-DPOLossConfig) | - | +| [`DPOLossDataDict`](#nemo_rl-algorithms-loss_functions-DPOLossDataDict) | Required keys for the DPO loss function. | +| [`DPOLossFn`](#nemo_rl-algorithms-loss_functions-DPOLossFn) | Direct Preference Optimization (DPO) loss function. | +| [`DistillationLossConfig`](#nemo_rl-algorithms-loss_functions-DistillationLossConfig) | - | +| [`DistillationLossDataDict`](#nemo_rl-algorithms-loss_functions-DistillationLossDataDict) | - | +| [`DistillationLossFn`](#nemo_rl-algorithms-loss_functions-DistillationLossFn) | Distillation loss function. | +| [`NLLLoss`](#nemo_rl-algorithms-loss_functions-NLLLoss) | Negative Log Likelihood Loss function. | +| [`PreferenceLoss`](#nemo_rl-algorithms-loss_functions-PreferenceLoss) | Preference Loss function. | +| [`PreferenceLossDataDict`](#nemo_rl-algorithms-loss_functions-PreferenceLossDataDict) | Required keys for the preference loss function. | +| [`SequencePackingLossWrapper`](#nemo_rl-algorithms-loss_functions-SequencePackingLossWrapper) | - | + +### Data + +[`Tensor`](#nemo_rl-algorithms-loss_functions-Tensor) + +### API + + + + + +```python +class nemo_rl.algorithms.loss_functions.ClippedPGLossConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.ClippedPGLossDataDict +``` + + + + + + +**Bases:** `typing.TypedDict` + +Required keys for the Clipped Policy Gradient loss function. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.ClippedPGLossFn( + cfg: nemo_rl.algorithms.loss_functions.ClippedPGLossConfig +) +``` + + + + + + +**Bases:** [LossFunction](/nemo-rl/nemo_rl/algorithms/interfaces#nemo_rl-algorithms-interfaces-LossFunction) + +Generalized Clipped Policy Gradient loss function w/ KL regularization. + +This implements: + +- PPO (Clipped) - https://arxiv.org/abs/1707.06347 +- GRPO - https://arxiv.org/abs/2402.03300 +- REINFORCE/RLOO (set disable_ppo_ratio = True and ignores ratio_clip_min/ratio_clip_max) - https://arxiv.org/abs/2402.14740 +- GSPO (set sequence_level_importance_ratios = True and token_level_loss = False) - https://arxiv.org/abs/2507.18071 +- Truly on-policy (set force_on_policy_ratio = True to force ratio = 1.0, requires one update per rollout) + +Formula: +L(θ) = E_t [ min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t) ] - β * KL(π_θ || π_ref) + +where: +- r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t) is the probability ratio +- A_t is the advantage estimate +- ε is the clip parameter (ratio_clip_min/ratio_clip_max) + - As proposed in the DAPO paper (https://arxiv.org/pdf/2503.14476), + we allow setting a distinct minimum and maximum value for the clip parameter (set to the same value for PPO/GRPO/etc.) + - ratio_clip_min: minimum value for the clip parameter + - ratio_clip_max: maximum value for the clip parameter +- β is the KL penalty coefficient (reference_policy_kl_penalty) +- KL(π_θ || π_ref) is the KL divergence between the current policy and reference policy (Schulman Approx.) + +For REINFORCE/RLOO (when disable_ppo_ratio=True), the formula simplifies to: +L(θ) = E_t [ π_θ(a_t|s_t) * A_t ] - β * KL(π_θ || π_ref) + +Also supports "Dual-Clipping" from https://arxiv.org/pdf/1912.09729, which +imposes an additional upper bound on the probability ratio when advantages are negative. +This prevents excessive policy updates. $rA << 0$ -> $cA$(clipped) +The loss function is modified to the following when A_t < 0: +L(θ) = E_t [ max(min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t), c * A_t) ] - β * KL(π_θ || π_ref) + +where: +- c is the dual-clip parameter (ratio_clip_c), which must be greater than 1 and is + usually set as 3 empirically. + +Due to potential numerical instability, we cast the logits to float32 before computing the loss. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.loss_functions.ClippedPGLossFn.__call__( + next_token_logits: nemo_rl.algorithms.loss_functions.Tensor, + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.ClippedPGLossDataDict], + global_valid_seqs: torch.Tensor, + global_valid_toks: torch.Tensor, + vocab_parallel_rank: typing.Optional[int] = None, + vocab_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None +) -> tuple[torch.Tensor, dict] +``` + + + + + + +Clipped Policy Gradient RL loss function. + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.DPOLossConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.DPOLossDataDict +``` + + + + + + +**Bases:** `typing.TypedDict` + +Required keys for the DPO loss function. + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.DPOLossFn( + cfg: nemo_rl.algorithms.loss_functions.DPOLossConfig +) +``` + + + + + + +**Bases:** [PreferenceLoss](#nemo_rl-algorithms-loss_functions-PreferenceLoss) + +Direct Preference Optimization (DPO) loss function. + +This loss function implements the DPO algorithm as described in: +"Direct Preference Optimization: Your Language Model is Secretly a Reward Model" +(https://arxiv.org/abs/2305.18290) + +The loss combines two main components: +1. Preference Loss: Optimizes the model to prefer chosen responses over rejected ones +2. SFT Loss (optional): Auxiliary supervised fine-tuning loss on chosen responses + +The total loss is computed as: +L(θ) = w_p * L_pref(θ) + w_s * L_sft(θ) + +where: +- w_p is the preference_loss_weight +- w_s is the sft_loss_weight +- L_pref(θ) is the preference loss term +- L_sft(θ) is the supervised fine-tuning loss term + +The preference loss term is computed as: +L_pref(θ) = -E[log(σ(β * (r_chosen - r_rejected)))] + +where: +- σ is the sigmoid function +- β is the reference_policy_kl_penalty +- r_chosen and r_rejected are the rewards for chosen and rejected responses +- The rewards are computed as the sum of log probability differences between + the current policy and reference policy + +If preference_average_log_probs is True, the rewards are averaged over tokens: +r = (1/n) * Σ_t (log π_θ(a_t|s_t) - log π_ref(a_t|s_t)) + +Otherwise, the rewards are summed over tokens. + +The SFT loss term is a standard negative log likelihood loss on the chosen responses. +If sft_average_log_probs is True, the loss is averaged over tokens. + +**Parameters:** + + +Configuration dictionary containing: +- reference_policy_kl_penalty (float): Strength of the KL penalty term (β) +- preference_loss_weight (float): Weight for the preference loss term (w_p) +- sft_loss_weight (float): Weight for the SFT loss term (w_s) +- preference_average_log_probs (bool): Whether to average log probs across tokens in preference loss +- sft_average_log_probs (bool): Whether to average log probs across tokens in SFT loss + + +**Returns:** + +tuple[torch.Tensor, dict]: A tuple containing: +- The total loss value +- A dictionary with metrics including: + - loss: Total loss value + - sft_loss: SFT loss component + - preference_loss: Preference loss component + - accuracy: Fraction of examples where chosen response has higher reward + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.loss_functions.DPOLossFn.__call__( + next_token_logits: nemo_rl.algorithms.loss_functions.Tensor, + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.DPOLossDataDict], + global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor, + global_valid_toks: nemo_rl.algorithms.loss_functions.Tensor | None, + vocab_parallel_rank: typing.Optional[int] = None, + vocab_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None +) -> tuple[torch.Tensor, dict[str, typing.Any]] +``` + + + + + + + + + + + + +```python +nemo_rl.algorithms.loss_functions.DPOLossFn._dpo_loss( + next_token_logits: nemo_rl.algorithms.loss_functions.Tensor, + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.DPOLossDataDict], + global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor, + vocab_parallel_rank: typing.Optional[int] = None, + vocab_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None +) -> tuple[nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor] +``` + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.DistillationLossConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.DistillationLossDataDict +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.DistillationLossFn( + cfg: nemo_rl.algorithms.loss_functions.DistillationLossConfig +) +``` + + + + + + +**Bases:** [LossFunction](/nemo-rl/nemo_rl/algorithms/interfaces#nemo_rl-algorithms-interfaces-LossFunction) + +Distillation loss function. + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.loss_functions.DistillationLossFn.__call__( + next_token_logits: torch.Tensor, + data: nemo_rl.algorithms.loss_functions.DistillationLossDataDict, + global_valid_seqs: torch.Tensor, + global_valid_toks: torch.Tensor, + vocab_parallel_rank: typing.Optional[int] = None, + vocab_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None +) -> tuple[torch.Tensor, dict[str, typing.Any]] +``` + + + + + + +Compute distillation loss between teacher and student logits. + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.NLLLoss() +``` + + + + + + +**Bases:** [LossFunction](/nemo-rl/nemo_rl/algorithms/interfaces#nemo_rl-algorithms-interfaces-LossFunction) + +Negative Log Likelihood Loss function. + + + + + + + + +```python +nemo_rl.algorithms.loss_functions.NLLLoss.__call__( + next_token_logits: nemo_rl.algorithms.loss_functions.Tensor, + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor | None, + global_valid_toks: nemo_rl.algorithms.loss_functions.Tensor, + vocab_parallel_rank: typing.Optional[int] = None, + vocab_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, + dpo_loss: bool = False, + dpo_average_log_probs: bool = False +) -> tuple[torch.Tensor, dict[str, typing.Any]] +``` + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.PreferenceLoss() +``` + + + + + + +**Bases:** [LossFunction](/nemo-rl/nemo_rl/algorithms/interfaces#nemo_rl-algorithms-interfaces-LossFunction) + +Preference Loss function. + +Optimizes the model to prefer chosen responses over rejected ones + +The preference loss is computed as: +L_pref(θ) = -E[log(σ(β * (r_chosen - r_rejected)))] + +where: +- σ is the sigmoid function +- β is a scaling factor (ex: `reference_policy_kl_penalty` in DPO) +- r_chosen and r_rejected are the rewards for chosen and rejected responses + +**Returns:** + +tuple[torch.Tensor, dict]: A tuple containing: +- The preference loss value +- A dictionary with metrics including: + - loss: Preference loss + - accuracy: Fraction of examples where chosen response has higher reward + + + + + + + + +```python +nemo_rl.algorithms.loss_functions.PreferenceLoss.__call__( + rewards: nemo_rl.algorithms.loss_functions.Tensor, + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.algorithms.loss_functions.PreferenceLossDataDict], + global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor, + global_valid_toks: nemo_rl.algorithms.loss_functions.Tensor | None +) -> tuple[torch.Tensor, dict[str, typing.Any]] +``` + + + + + + + + + + + + +```python +nemo_rl.algorithms.loss_functions.PreferenceLoss._preference_loss( + rewards: nemo_rl.algorithms.loss_functions.Tensor, + sample_mask: nemo_rl.algorithms.loss_functions.Tensor, + global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor, + beta: float = 1.0 +) -> tuple[nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor] +``` + + + + + + + + + + + + +```python +nemo_rl.algorithms.loss_functions.PreferenceLoss.split_output_tensor( + tensor: nemo_rl.algorithms.loss_functions.Tensor +) -> tuple[nemo_rl.algorithms.loss_functions.Tensor, nemo_rl.algorithms.loss_functions.Tensor] +``` + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.PreferenceLossDataDict +``` + + + + + + +**Bases:** `typing.TypedDict` + +Required keys for the preference loss function. + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.loss_functions.SequencePackingLossWrapper( + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + cu_seqlens_q: nemo_rl.algorithms.loss_functions.Tensor, + cu_seqlens_q_padded: typing.Optional[nemo_rl.algorithms.loss_functions.Tensor] = None +) +``` + + + + + + + + + + +```python +nemo_rl.algorithms.loss_functions.SequencePackingLossWrapper.__call__( + next_token_logits: nemo_rl.algorithms.loss_functions.Tensor, + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + global_valid_seqs: nemo_rl.algorithms.loss_functions.Tensor | None, + global_valid_toks: nemo_rl.algorithms.loss_functions.Tensor | None, + vocab_parallel_rank: typing.Optional[int] = None, + vocab_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: typing.Optional[torch.distributed.ProcessGroup] = None +) -> tuple[nemo_rl.algorithms.loss_functions.Tensor, dict[str, typing.Any]] +``` + + + + + + +Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid excessive padding. + + + + + + + + + +```python +nemo_rl.algorithms.loss_functions.Tensor = TypeVar('Tensor', bound=(torch.Tensor)) +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/reward_functions.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/reward_functions.mdx new file mode 100644 index 0000000..ffcae23 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/reward_functions.mdx @@ -0,0 +1,102 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/reward_functions +title: nemo_rl.algorithms.reward_functions +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`RewardShapingConfig`](#nemo_rl-algorithms-reward_functions-RewardShapingConfig) | Configuration for reward function processing. | + +### Functions + +| Name | Description | +|------|-------------| +| [`apply_reward_shaping`](#nemo_rl-algorithms-reward_functions-apply_reward_shaping) | Process rewards by applying penalties for responses exceeding max_response_length. Currently, this function only supports DAPO reward shaping as illustrated in the DAPO paper : https://arxiv.org/pdf/2503.14476. | + +### Data + +[`Tensor`](#nemo_rl-algorithms-reward_functions-Tensor) + +### API + + + + + +```python +class nemo_rl.algorithms.reward_functions.RewardShapingConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Configuration for reward function processing. + +This configuration enables custom reward shaping, currently supporting DAPO-style +penalties for responses that exceed the maximum response length threshold. + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.reward_functions.apply_reward_shaping( + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict, + cfg: nemo_rl.algorithms.reward_functions.RewardShapingConfig +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict +``` + + + + + + +Process rewards by applying penalties for responses exceeding max_response_length. Currently, this function only supports DAPO reward shaping as illustrated in the DAPO paper : https://arxiv.org/pdf/2503.14476. + +Nonetheless, it can be potentially extended to support any custom reward logic. + + + + + + + + +```python +nemo_rl.algorithms.reward_functions.Tensor = TypeVar('Tensor', bound=(torch.Tensor)) +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/rm.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/rm.mdx new file mode 100644 index 0000000..ed41f3a --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/rm.mdx @@ -0,0 +1,320 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/rm +title: nemo_rl.algorithms.rm +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MasterConfig`](#nemo_rl-algorithms-rm-MasterConfig) | - | +| [`RMConfig`](#nemo_rl-algorithms-rm-RMConfig) | - | +| [`RMSaveState`](#nemo_rl-algorithms-rm-RMSaveState) | - | +| [`RMValMetrics`](#nemo_rl-algorithms-rm-RMValMetrics) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_default_rm_save_state`](#nemo_rl-algorithms-rm-_default_rm_save_state) | - | +| [`rm_train`](#nemo_rl-algorithms-rm-rm_train) | - | +| [`setup`](#nemo_rl-algorithms-rm-setup) | Main entry point for running RM algorithm. | +| [`validate`](#nemo_rl-algorithms-rm-validate) | - | +| [`validate_one_dataset`](#nemo_rl-algorithms-rm-validate_one_dataset) | Run validation on one validation dataset. | + +### API + + + + + +```python +class nemo_rl.algorithms.rm.MasterConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.rm.RMConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.rm.RMSaveState +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.rm.RMValMetrics +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.rm._default_rm_save_state() -> nemo_rl.algorithms.rm.RMSaveState +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.rm.rm_train( + policy, + train_dataloader, + val_dataloader, + tokenizer, + loss_fn, + master_config, + logger, + checkpointer, + rm_save_state +) +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.rm.setup( + master_config: nemo_rl.algorithms.rm.MasterConfig, + tokenizer: transformers.AutoTokenizer, + train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset, + val_dataset: dict[str, nemo_rl.data.datasets.AllTaskProcessedDataset] +) -> tuple[nemo_rl.models.policy.lm_policy.Policy, nemo_rl.distributed.virtual_cluster.RayVirtualCluster, torchdata.stateful_dataloader.StatefulDataLoader, dict[str, torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.PreferenceLoss, nemo_rl.algorithms.rm.MasterConfig, nemo_rl.utils.logger.Logger, nemo_rl.data.interfaces.TaskDataSpec, nemo_rl.algorithms.rm.RMSaveState] +``` + + + + + + +Main entry point for running RM algorithm. + +**Returns:** `tuple[Policy, RayVirtualCluster, StatefulDataLoader, dict[str, StatefulDataLoader], PreferenceLoss, MasterConfig, Logger, TaskDataSpec, RMSaveState]` + +Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, master_config, logger + + + + + + + + +```python +nemo_rl.algorithms.rm.validate( + policy: nemo_rl.models.policy.interfaces.PolicyInterface, + val_dataloader: dict[str, torchdata.stateful_dataloader.StatefulDataLoader], + tokenizer, + loss_fn, + step: int, + master_config: nemo_rl.algorithms.rm.MasterConfig, + val_batches: int, + val_batch_size: int, + val_mbs: int, + logger: nemo_rl.utils.logger.Logger +) +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.rm.validate_one_dataset( + policy: nemo_rl.models.policy.interfaces.PolicyInterface, + val_dataloader: torchdata.stateful_dataloader.StatefulDataLoader, + loss_fn, + step: int, + master_config: nemo_rl.algorithms.rm.MasterConfig, + val_batches: int, + val_batch_size: int, + val_mbs: int, + dataset_name: str +) +``` + + + + + + +Run validation on one validation dataset. + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/sft.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/sft.mdx new file mode 100644 index 0000000..d9a3bd6 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/sft.mdx @@ -0,0 +1,258 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/sft +title: nemo_rl.algorithms.sft +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MasterConfig`](#nemo_rl-algorithms-sft-MasterConfig) | - | +| [`SFTConfig`](#nemo_rl-algorithms-sft-SFTConfig) | - | +| [`SFTSaveState`](#nemo_rl-algorithms-sft-SFTSaveState) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_default_sft_save_state`](#nemo_rl-algorithms-sft-_default_sft_save_state) | - | +| [`setup`](#nemo_rl-algorithms-sft-setup) | Main entry point for running SFT algorithm. | +| [`sft_train`](#nemo_rl-algorithms-sft-sft_train) | - | +| [`validate`](#nemo_rl-algorithms-sft-validate) | Run validation on the validation dataset. | + +### API + + + + + +```python +class nemo_rl.algorithms.sft.MasterConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.sft.SFTConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.algorithms.sft.SFTSaveState +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.algorithms.sft._default_sft_save_state() -> nemo_rl.algorithms.sft.SFTSaveState +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.sft.setup( + master_config: nemo_rl.algorithms.sft.MasterConfig, + tokenizer: transformers.AutoTokenizer, + train_dataset: nemo_rl.data.datasets.AllTaskProcessedDataset, + val_dataset: typing.Optional[nemo_rl.data.datasets.AllTaskProcessedDataset] +) -> tuple[nemo_rl.models.policy.lm_policy.Policy, nemo_rl.distributed.virtual_cluster.RayVirtualCluster, torchdata.stateful_dataloader.StatefulDataLoader, typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], nemo_rl.algorithms.loss_functions.NLLLoss, nemo_rl.utils.logger.Logger, nemo_rl.utils.checkpoint.CheckpointManager, nemo_rl.algorithms.sft.SFTSaveState, nemo_rl.algorithms.sft.MasterConfig] +``` + + + + + + +Main entry point for running SFT algorithm. + +**Returns:** `tuple[Policy, RayVirtualCluster, StatefulDataLoader, Optional[StatefulDataLoader], NLLLoss, Logger, CheckpointManager, SFTSaveState, MasterConfig]` + +Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, master_config, logger + + + + + + + + +```python +nemo_rl.algorithms.sft.sft_train( + policy, + train_dataloader, + val_dataloader, + tokenizer, + loss_fn, + master_config, + logger, + checkpointer, + sft_save_state: nemo_rl.algorithms.sft.SFTSaveState +) -> None +``` + + + + + + + + + + + + + +```python +nemo_rl.algorithms.sft.validate( + policy: nemo_rl.models.policy.interfaces.PolicyInterface, + val_dataloader: typing.Optional[torchdata.stateful_dataloader.StatefulDataLoader], + tokenizer, + loss_fn, + step: int, + master_config: nemo_rl.algorithms.sft.MasterConfig, + val_batches: int, + val_batch_size: int, + val_mbs: int +) +``` + + + + + + +Run validation on the validation dataset. + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/utils.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/utils.mdx new file mode 100644 index 0000000..200ecb0 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/algorithms/utils.mdx @@ -0,0 +1,379 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/algorithms/utils +title: nemo_rl.algorithms.utils +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`calculate_baseline_and_std_per_prompt`](#nemo_rl-algorithms-utils-calculate_baseline_and_std_per_prompt) | Function to compute a baseline for each (prompt, response) pair in the batch. | +| [`calculate_kl`](#nemo_rl-algorithms-utils-calculate_kl) | Calculates a per-token estimate of the KL Divergence between two logprobs. | +| [`get_tokenizer`](#nemo_rl-algorithms-utils-get_tokenizer) | Get the tokenizer and set pad token to eos token if it is not already set. | +| [`log_generation_metrics_to_wandb`](#nemo_rl-algorithms-utils-log_generation_metrics_to_wandb) | Log generation metrics to wandb. | +| [`masked_mean`](#nemo_rl-algorithms-utils-masked_mean) | Computes the mean of a microbatch, using a global statistic as the normalization factor. | +| [`maybe_pad_last_batch`](#nemo_rl-algorithms-utils-maybe_pad_last_batch) | Pads the given batch so that its size is divisible by (mbs * dp_size). | +| [`print_performance_metrics`](#nemo_rl-algorithms-utils-print_performance_metrics) | Print performance metrics for GRPO. | +| [`set_seed`](#nemo_rl-algorithms-utils-set_seed) | Sets the seed for python, numpy, and pytorch. | +| [`surpress_user_warnings`](#nemo_rl-algorithms-utils-surpress_user_warnings) | - | + +### API + + + + + +```python +nemo_rl.algorithms.utils.calculate_baseline_and_std_per_prompt( + prompts: torch.Tensor, + rewards: torch.Tensor, + valid_mask: torch.Tensor, + leave_one_out_baseline: bool = True +) -> tuple[torch.Tensor, torch.Tensor] +``` + + + + + + +Function to compute a baseline for each (prompt, response) pair in the batch. + +The same baseline is calculated for each prompt. Samples set to 0 in 'valid_mask' +are not included in the baseline calculation. + +prompts: tensor (b, s) Tensor of prompts the model used. May be on any device +rewards: tensor (b,) Float-valued rewards. May be on any device +valid_mask: tensor (b,) Vector of 0/1, where 0 is to ignore and 1 is to keep +leave_one_out_baseline: bool Compute an unbiased baseline by leaving out the sample that + the baseline is for (from RLOO https://arxiv.org/abs/2402.14740) + +Returns: +tensor (b,), tensor (b,) of baselines and std on the same device as 'rewards' + + + + + + + + +```python +nemo_rl.algorithms.utils.calculate_kl( + logprobs: torch.Tensor, + logprobs_reference: torch.Tensor, + kl_type: str = 'k3', + input_clamp_value: float | None = 20.0, + output_clamp_value: float | None = 10.0 +) -> torch.Tensor +``` + + + + + + +Calculates a per-token estimate of the KL Divergence between two logprobs. + +From Schulman 2020, http://joschu.net/blog/kl-approx.html. + +**Parameters:** + + +torch.Tensor (b, s) + + + +torch.Tensor (b, s) + + + +Type of KL approximation to use. Valid values: "k1", "k2", "k3". + + + +Optional clamping value for logr to prevent numerical instability. + If None, no clamping is applied. + + + +Optional clamping value for kl to prevent numerical instability. + If None, no clamping is applied. + + +**Returns:** `torch.Tensor` + +torch.Tensor: Per-token KL penalty values (b, s) + + + + + + + + +```python +nemo_rl.algorithms.utils.get_tokenizer( + tokenizer_config: nemo_rl.models.policy.TokenizerConfig, + get_processor: bool = False +) -> transformers.PreTrainedTokenizerBase +``` + + + + + + +Get the tokenizer and set pad token to eos token if it is not already set. + +This function initializes a tokenizer from the Hugging Face transformers library +and configures it with appropriate chat templates and padding tokens. + +**Parameters:** + + +A dictionary containing tokenizer configuration. +Required keys: + - name: The name or path of the pretrained tokenizer +Optional keys: + - chat_template: The chat template to use. Can be: + - None: Uses a passthrough template that just returns message content + - "default": Uses the tokenizer's default template + - A custom jinja2 template string + If not specified, the tokenizer's default template will be used. + + + +Whether to return a processor (via AutoProcessor) instead of a tokenizer. + + +**Returns:** `PreTrainedTokenizerBase` + +The configured tokenizer instance + +**Examples:** + + + +```python +>>> from transformers import AutoTokenizer +>>> from nemo_rl.algorithms.utils import get_tokenizer +>>> # not specifying a chat template uses the tokenizer's default +>>> config = {"name": "meta-llama/Llama-3.2-1B-Instruct"} +>>> tokenizer = get_tokenizer(config) +No chat template provided, using tokenizer's default +>>> messages = [ +... {"role": "system", "content": "You are a helpful AI assistant."}, +... {"role": "user", "content": "Hello!"} +... ] +>>> formatted = tokenizer.apply_chat_template(messages, tokenize=False) +>>> assert formatted == AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct").apply_chat_template(messages, tokenize=False) + +>>> # Using a passthrough template +>>> config = { +... "name": "meta-llama/Llama-3.2-1B-Instruct", +... "chat_template": None +... } +>>> tokenizer = get_tokenizer(config) +Using passthrough chat template +>>> formatted = tokenizer.apply_chat_template(messages, tokenize=False) +>>> assert formatted == "".join(msg["content"] for msg in messages) + +>>> # Using a custom template +>>> config = { +... "name": "meta-llama/Llama-3.2-1B-Instruct", +... "chat_template": "{% for message in messages %}{{ ' START: ' + message['content'] + ' END.' }}{% endfor %}" +... } +>>> tokenizer = get_tokenizer(config) +Using custom chat template +>>> formatted = tokenizer.apply_chat_template(messages, tokenize=False) +>>> assert formatted == " START: You are a helpful AI assistant. END. START: Hello! END." + +>>> # Requesting a processor (for multimodal models like Qwen-VL) +>>> config = {"name": "Qwen/Qwen2.5-VL-3B-Instruct"} +>>> processor = get_tokenizer(config, get_processor=True) +No chat template provided, using tokenizer's default +>>> messages = [ +... {"role": "system", "content": "You are a helpful AI assistant."}, +... {"role": "user", "content": "Hello!"} +... ] +>>> formatted = processor.tokenizer.apply_chat_template(messages, tokenize=False) +>>> assert formatted == AutoTokenizer.from_pretrained( +... "Qwen/Qwen2.5-VL-3B-Instruct", trust_remote_code=True +... ).apply_chat_template(messages, tokenize=False) +>>> assert processor.pad_token_id == processor.tokenizer.pad_token_id +>>> +``` + + + + + + + + + + +```python +nemo_rl.algorithms.utils.log_generation_metrics_to_wandb( + generation_logger_metrics: dict[str, dict[int, list[typing.Any]]], + step: int, + timeline_interval: float, + logger: nemo_rl.utils.logger.Logger +) -> None +``` + + + + + + +Log generation metrics to wandb. + +**Parameters:** + + +Dictionary of generation logger metrics + + + +Global step value + + + +Interval between timeline points (in seconds) + + + +Logger instance + + + + + + + + + +```python +nemo_rl.algorithms.utils.masked_mean( + values: torch.Tensor, + mask: torch.Tensor, + dim: typing.Optional[int] = None, + global_normalization_factor: typing.Optional[torch.Tensor | float] = None +) +``` + + + + + + +Computes the mean of a microbatch, using a global statistic as the normalization factor. + + + + + + + + +```python +nemo_rl.algorithms.utils.maybe_pad_last_batch( + batch: dict, + dp_size: int, + mbs: int +) -> dict +``` + + + + + + +Pads the given batch so that its size is divisible by (mbs * dp_size). + +**Parameters:** + + +The batch to pad. + + + +Data parallel size. + + + +Micro batch size. + + +**Returns:** `dict` + +The padded batch. + + + + + + + + +```python +nemo_rl.algorithms.utils.print_performance_metrics( + train_results: dict[str, float], + metrics: dict[str, typing.Any], + timing_metrics: dict[str, float], + master_config: dict +) -> dict[str, float] +``` + + + + + + +Print performance metrics for GRPO. + + + + + + + + +```python +nemo_rl.algorithms.utils.set_seed( + seed: int +) -> None +``` + + + + + + +Sets the seed for python, numpy, and pytorch. + + + + + + + + +```python +nemo_rl.algorithms.utils.surpress_user_warnings( + f +) +``` + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data.mdx new file mode 100644 index 0000000..3cafa95 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data.mdx @@ -0,0 +1,466 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data +title: nemo_rl.data +--- + +## Subpackages + +- **[`nemo_rl.data.datasets`](/nemo-rl/nemo_rl/data/datasets)** +- **[`nemo_rl.data.packing`](/nemo-rl/nemo_rl/data/packing)** + +## Submodules + +- **[`nemo_rl.data.chat_templates`](/nemo-rl/nemo_rl/data/chat_templates)** +- **[`nemo_rl.data.collate_fn`](/nemo-rl/nemo_rl/data/collate_fn)** +- **[`nemo_rl.data.interfaces`](/nemo-rl/nemo_rl/data/interfaces)** +- **[`nemo_rl.data.llm_message_utils`](/nemo-rl/nemo_rl/data/llm_message_utils)** +- **[`nemo_rl.data.multimodal_utils`](/nemo-rl/nemo_rl/data/multimodal_utils)** +- **[`nemo_rl.data.processors`](/nemo-rl/nemo_rl/data/processors)** +- **[`nemo_rl.data.utils`](/nemo-rl/nemo_rl/data/utils)** + +## Package Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AIMEEvalDataConfig`](#nemo_rl-data-AIMEEvalDataConfig) | Config for AIME datasets. | +| [`DataConfig`](#nemo_rl-data-DataConfig) | - | +| [`GPQAEvalDataConfig`](#nemo_rl-data-GPQAEvalDataConfig) | Config for GPQA datasets. | +| [`LocalMathEvalDataConfig`](#nemo_rl-data-LocalMathEvalDataConfig) | Config for local math datasets loaded from files. | +| [`MMLUEvalDataConfig`](#nemo_rl-data-MMLUEvalDataConfig) | Config for MMLU and multilingual MMLU datasets. | +| [`MMLUProEvalDataConfig`](#nemo_rl-data-MMLUProEvalDataConfig) | Config for MMLU Pro dataset. | +| [`MathEvalDataConfig`](#nemo_rl-data-MathEvalDataConfig) | Config for Math datasets. | +| [`PreferenceDatasetConfig`](#nemo_rl-data-PreferenceDatasetConfig) | - | +| [`ResponseDatasetConfig`](#nemo_rl-data-ResponseDatasetConfig) | - | + +### Data + +[`EvalDataConfigType`](#nemo_rl-data-EvalDataConfigType) + +### API + + + + + +```python +class nemo_rl.data.AIMEEvalDataConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Config for AIME datasets. + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.DataConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.GPQAEvalDataConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Config for GPQA datasets. + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.LocalMathEvalDataConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Config for local math datasets loaded from files. + +dataset_name can be a URL or local file path. +Requires additional fields: problem_key, solution_key, file_format, split. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.MMLUEvalDataConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Config for MMLU and multilingual MMLU datasets. + +Supports dataset_name: "mmlu" or "mmlu_{language}" where language is one of: +AR-XY, BN-BD, DE-DE, EN-US, ES-LA, FR-FR, HI-IN, ID-ID, IT-IT, JA-JP, +KO-KR, PT-BR, ZH-CN, SW-KE, YO-NG + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.MMLUProEvalDataConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Config for MMLU Pro dataset. + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.MathEvalDataConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Config for Math datasets. + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.PreferenceDatasetConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.ResponseDatasetConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.data.EvalDataConfigType = MMLUEvalDataConfig | MMLUProEvalDataConfig | AIMEEvalDataConfig | GPQAEvalDataCo... +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/chat_templates.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/chat_templates.mdx new file mode 100644 index 0000000..11e5f15 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/chat_templates.mdx @@ -0,0 +1,35 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/chat_templates +title: nemo_rl.data.chat_templates +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`COMMON_CHAT_TEMPLATES`](#nemo_rl-data-chat_templates-COMMON_CHAT_TEMPLATES) | - | + +### API + + + + + +```python +class nemo_rl.data.chat_templates.COMMON_CHAT_TEMPLATES() +``` + + + + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/collate_fn.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/collate_fn.mdx new file mode 100644 index 0000000..56b6bb7 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/collate_fn.mdx @@ -0,0 +1,166 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/collate_fn +title: nemo_rl.data.collate_fn +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`eval_collate_fn`](#nemo_rl-data-collate_fn-eval_collate_fn) | Collate function for evaluation. | +| [`preference_collate_fn`](#nemo_rl-data-collate_fn-preference_collate_fn) | Collate function for preference data training. | +| [`rl_collate_fn`](#nemo_rl-data-collate_fn-rl_collate_fn) | Collate function for RL training. | + +### Data + +[`TokenizerType`](#nemo_rl-data-collate_fn-TokenizerType) + +### API + + + + + +```python +nemo_rl.data.collate_fn.eval_collate_fn( + data_batch: list[nemo_rl.data.interfaces.DatumSpec] +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] +``` + + + + + + +Collate function for evaluation. + +Takes a list of data samples and combines them into a single batched dictionary +for model evaluation. + +Examples: + + +```python +>>> import torch +>>> from nemo_rl.data.collate_fn import eval_collate_fn +>>> from nemo_rl.data.interfaces import DatumSpec +>>> data_batch = [ +... DatumSpec( +... message_log=[{"role": "user", "content": "Hello", "token_ids": torch.tensor([1, 2, 3])}], +... extra_env_info={'ground_truth': '1'}, +... idx=0, +... ), +... DatumSpec( +... message_log=[{"role": "assistant", "content": "Hi there", "token_ids": torch.tensor([4, 5, 6, 7])}], +... extra_env_info={'ground_truth': '2'}, +... idx=1, +... ), +... ] +>>> output = eval_collate_fn(data_batch) +>>> output['message_log'][0] +[{'role': 'user', 'content': 'Hello', 'token_ids': tensor([1, 2, 3])}] +>>> output['message_log'][1] +[{'role': 'assistant', 'content': 'Hi there', 'token_ids': tensor([4, 5, 6, 7])}] +>>> output['extra_env_info'] +[{'ground_truth': '1'}, {'ground_truth': '2'}] +>>> output['idx'] +[0, 1] +``` + + + +**Parameters:** + + +List of data samples with message_log, extra_env_info, and idx fields. + + +**Returns:** `BatchedDataDict[Any]` + +BatchedDataDict with message_log, extra_env_info, and idx fields. + + + + + + + + +```python +nemo_rl.data.collate_fn.preference_collate_fn( + data_batch: list[nemo_rl.data.interfaces.PreferenceDatumSpec], + tokenizer: nemo_rl.data.collate_fn.TokenizerType, + make_sequence_length_divisible_by: int, + add_loss_mask: bool +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] +``` + + + + + + +Collate function for preference data training. + +This function separates the chosen and rejected responses to create +two examples per prompt. The chosen and rejected examples are interleaved +along the batch dimension, resulting in a batch size of 2 * len(data_batch). + +Returns: + BatchedDataDict with input_ids, input_lengths, token_mask (optional), and sample_mask fields. + +**Parameters:** + + +List of data samples with message_log_chosen, message_log_rejected, length_chosen, length_rejected, loss_multiplier, idx, and task_name fields. + + + +Tokenizer for text processing + + + +Make the sequence length divisible by this value + + + +Whether to add a token_mask to the returned data + + + + + + + + + +```python +nemo_rl.data.collate_fn.rl_collate_fn( + data_batch: list[nemo_rl.data.interfaces.DatumSpec] +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] +``` + + + + + + +Collate function for RL training. + + + + + + + + +```python +nemo_rl.data.collate_fn.TokenizerType = Union[PreTrainedTokenizerBase, AutoProcessor] +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets.mdx new file mode 100644 index 0000000..88450e5 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets.mdx @@ -0,0 +1,37 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets +title: nemo_rl.data.datasets +--- + +## Subpackages + +- **[`nemo_rl.data.datasets.eval_datasets`](/nemo-rl/nemo_rl/data/datasets/eval_datasets)** +- **[`nemo_rl.data.datasets.preference_datasets`](/nemo-rl/nemo_rl/data/datasets/preference_datasets)** +- **[`nemo_rl.data.datasets.response_datasets`](/nemo-rl/nemo_rl/data/datasets/response_datasets)** + +## Submodules + +- **[`nemo_rl.data.datasets.processed_dataset`](/nemo-rl/nemo_rl/data/datasets/processed_dataset)** +- **[`nemo_rl.data.datasets.raw_dataset`](/nemo-rl/nemo_rl/data/datasets/raw_dataset)** +- **[`nemo_rl.data.datasets.utils`](/nemo-rl/nemo_rl/data/datasets/utils)** + +## Package Contents + +### Data + +[`__all__`](#nemo_rl-data-datasets-__all__) + +### API + + + + + +```python +nemo_rl.data.datasets.__all__ = ['AllTaskProcessedDataset', 'load_eval_dataset', 'load_preference_dataset', 'loa... +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets.mdx new file mode 100644 index 0000000..433590f --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets.mdx @@ -0,0 +1,60 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/eval_datasets +title: nemo_rl.data.datasets.eval_datasets +--- + +## Submodules + +- **[`nemo_rl.data.datasets.eval_datasets.aime`](/nemo-rl/nemo_rl/data/datasets/eval_datasets/aime)** +- **[`nemo_rl.data.datasets.eval_datasets.gpqa`](/nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa)** +- **[`nemo_rl.data.datasets.eval_datasets.local_math_dataset`](/nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset)** +- **[`nemo_rl.data.datasets.eval_datasets.math`](/nemo-rl/nemo_rl/data/datasets/eval_datasets/math)** +- **[`nemo_rl.data.datasets.eval_datasets.mmlu`](/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu)** +- **[`nemo_rl.data.datasets.eval_datasets.mmlu_pro`](/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`load_eval_dataset`](#nemo_rl-data-datasets-eval_datasets-load_eval_dataset) | Loads evaluation dataset. | + +### Data + +[`__all__`](#nemo_rl-data-datasets-eval_datasets-__all__) + +### API + + + + + +```python +nemo_rl.data.datasets.eval_datasets.load_eval_dataset( + data_config +) +``` + + + + + + +Loads evaluation dataset. + + + + + + + + +```python +nemo_rl.data.datasets.eval_datasets.__all__ = ['AIMEDataset', 'GPQADataset', 'LocalMathDataset', 'MathDataset', 'MMLUDataset',... +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/aime.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/aime.mdx new file mode 100644 index 0000000..155c936 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/aime.mdx @@ -0,0 +1,64 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/aime +title: nemo_rl.data.datasets.eval_datasets.aime +--- + +AIME dataset. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AIMEDataset`](#nemo_rl-data-datasets-eval_datasets-aime-AIMEDataset) | - | + +### API + + + + + +```python +class nemo_rl.data.datasets.eval_datasets.aime.AIMEDataset( + variant: typing.Literal['2024', '2025'] = '2025', + prompt_file: typing.Optional[str] = None, + system_prompt_file: typing.Optional[str] = None +) +``` + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.eval_datasets.aime.AIMEDataset._rekey( + data: dict[str, typing.Any] +) +``` + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa.mdx new file mode 100644 index 0000000..d1ca3a9 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa.mdx @@ -0,0 +1,64 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/gpqa +title: nemo_rl.data.datasets.eval_datasets.gpqa +--- + +GPQA dataset and its variants. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`GPQADataset`](#nemo_rl-data-datasets-eval_datasets-gpqa-GPQADataset) | - | + +### API + + + + + +```python +class nemo_rl.data.datasets.eval_datasets.gpqa.GPQADataset( + variant: typing.Literal['diamond', 'main'] = 'diamond', + prompt_file: typing.Optional[str] = None, + system_prompt_file: typing.Optional[str] = None +) +``` + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.eval_datasets.gpqa.GPQADataset._rekey( + data: dict[str, typing.Any] +) +``` + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset.mdx new file mode 100644 index 0000000..e6d6754 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset.mdx @@ -0,0 +1,65 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/local_math_dataset +title: nemo_rl.data.datasets.eval_datasets.local_math_dataset +--- + +Local math dataset. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`LocalMathDataset`](#nemo_rl-data-datasets-eval_datasets-local_math_dataset-LocalMathDataset) | - | + +### API + + + + + +```python +class nemo_rl.data.datasets.eval_datasets.local_math_dataset.LocalMathDataset( + data_path: str, + problem_key: str, + solution_key: str, + split: typing.Optional[str] = None, + file_format: typing.Literal['csv', 'json'] = 'csv', + prompt_file: typing.Optional[str] = None, + system_prompt_file: typing.Optional[str] = None +) +``` + + + + + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.eval_datasets.local_math_dataset.LocalMathDataset._rekey( + data: dict[str, typing.Any] +) +``` + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/math.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/math.mdx new file mode 100644 index 0000000..c00f375 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/math.mdx @@ -0,0 +1,61 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/math +title: nemo_rl.data.datasets.eval_datasets.math +--- + +Math dataset and its variants. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MathDataset`](#nemo_rl-data-datasets-eval_datasets-math-MathDataset) | - | + +### API + + + + + +```python +class nemo_rl.data.datasets.eval_datasets.math.MathDataset( + variant: typing.Literal['math_test', 'math_500_test'] = 'math_test', + prompt_file: typing.Optional[str] = None, + system_prompt_file: typing.Optional[str] = None +) +``` + + + + + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.eval_datasets.math.MathDataset._rekey( + data: dict[str, typing.Any] +) +``` + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu.mdx new file mode 100644 index 0000000..1114133 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu.mdx @@ -0,0 +1,61 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu +title: nemo_rl.data.datasets.eval_datasets.mmlu +--- + +MMLU dataset and its variants. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MMLUDataset`](#nemo_rl-data-datasets-eval_datasets-mmlu-MMLUDataset) | - | + +### API + + + + + +```python +class nemo_rl.data.datasets.eval_datasets.mmlu.MMLUDataset( + language: typing.Literal['AR-XY', 'BN-BD', 'DE-DE', 'EN-US', 'ES-LA', 'FR-FR', 'HI-IN', 'ID-ID', 'IT-IT', 'JA-JP', 'KO-KR', 'PT-BR', 'ZH-CN', 'SW-KE', 'YO-NG'] = 'EN-US', + prompt_file: typing.Optional[str] = None, + system_prompt_file: typing.Optional[str] = None +) +``` + + + + + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.eval_datasets.mmlu.MMLUDataset._rekey( + data: dict[str, typing.Any] +) +``` + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro.mdx new file mode 100644 index 0000000..998a593 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro.mdx @@ -0,0 +1,60 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/eval_datasets/mmlu_pro +title: nemo_rl.data.datasets.eval_datasets.mmlu_pro +--- + +MMLU-Pro dataset. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MMLUProDataset`](#nemo_rl-data-datasets-eval_datasets-mmlu_pro-MMLUProDataset) | - | + +### API + + + + + +```python +class nemo_rl.data.datasets.eval_datasets.mmlu_pro.MMLUProDataset( + prompt_file: str, + system_prompt_file: typing.Optional[str] = None +) +``` + + + + + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.eval_datasets.mmlu_pro.MMLUProDataset._rekey( + data: dict[str, typing.Any] +) +``` + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets.mdx new file mode 100644 index 0000000..1b101aa --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets.mdx @@ -0,0 +1,72 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/preference_datasets +title: nemo_rl.data.datasets.preference_datasets +--- + +## Submodules + +- **[`nemo_rl.data.datasets.preference_datasets.binary_preference_dataset`](/nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset)** +- **[`nemo_rl.data.datasets.preference_datasets.helpsteer3`](/nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3)** +- **[`nemo_rl.data.datasets.preference_datasets.preference_dataset`](/nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset)** +- **[`nemo_rl.data.datasets.preference_datasets.tulu3`](/nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`load_preference_dataset`](#nemo_rl-data-datasets-preference_datasets-load_preference_dataset) | Loads preference dataset. | + +### Data + +[`DATASET_REGISTRY`](#nemo_rl-data-datasets-preference_datasets-DATASET_REGISTRY) + +[`__all__`](#nemo_rl-data-datasets-preference_datasets-__all__) + +### API + + + + + +```python +nemo_rl.data.datasets.preference_datasets.load_preference_dataset( + data_config: nemo_rl.data.PreferenceDatasetConfig +) +``` + + + + + + +Loads preference dataset. + + + + + + + + +```python +nemo_rl.data.datasets.preference_datasets.DATASET_REGISTRY = {'HelpSteer3': HelpSteer3Dataset, 'Tulu3Preference': Tulu3PreferenceDataset, 'Bi... +``` + + + + + + + + + +```python +nemo_rl.data.datasets.preference_datasets.__all__ = ['BinaryPreferenceDataset', 'HelpSteer3Dataset', 'PreferenceDataset', 'Tulu3Pref... +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.mdx new file mode 100644 index 0000000..762ddd7 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.mdx @@ -0,0 +1,102 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset +title: nemo_rl.data.datasets.preference_datasets.binary_preference_dataset +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BinaryPreferenceDataset`](#nemo_rl-data-datasets-preference_datasets-binary_preference_dataset-BinaryPreferenceDataset) | Dataset class for binary preference data which can be loaded from a JSON file. | + +### API + + + + + +```python +class nemo_rl.data.datasets.preference_datasets.binary_preference_dataset.BinaryPreferenceDataset( + data_path: str, + prompt_key: str = 'prompt', + chosen_key: str = 'chosen', + rejected_key: str = 'rejected', + subset: typing.Optional[str] = None, + split: typing.Optional[str] = None, + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Dataset class for binary preference data which can be loaded from a JSON file. + +This class handles loading of preference data for DPO and RM training. +It will be converted to the format of PreferenceDataset through the `to_preference_data_format` function. + +The input JSONL files should contain valid JSON objects formatted like this: +{ + prompt_key: str, # The input prompt/context + chosen_key: str, # The preferred/winning response + rejected_key: str, # The non-preferred/losing response +} +Please refer to https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/dpo.md#datasets for more details. + +**Parameters:** + + +Path to the dataset JSON file + + + +Key for the input prompt/context, default is "prompt" + + + +Key for the preferred/winning response, default is "chosen" + + + +Key for the non-preferred/losing response, default is "rejected" + + + +Optional subset name for the dataset, used for HuggingFace datasets + + + +Optional split name for the dataset, used for HuggingFace datasets + + + + + + + + + + + + +```python +nemo_rl.data.datasets.preference_datasets.binary_preference_dataset.BinaryPreferenceDataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3.mdx new file mode 100644 index 0000000..a88c29e --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3.mdx @@ -0,0 +1,66 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/helpsteer3 +title: nemo_rl.data.datasets.preference_datasets.helpsteer3 +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`HelpSteer3Dataset`](#nemo_rl-data-datasets-preference_datasets-helpsteer3-HelpSteer3Dataset) | HelpSteer3 preference dataset for DPO training. | + +### API + + + + + +```python +class nemo_rl.data.datasets.preference_datasets.helpsteer3.HelpSteer3Dataset( + split: str = 'train', + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +HelpSteer3 preference dataset for DPO training. + +**Parameters:** + + +Split name for the dataset, default is "train" + + + + + + + + + + + + +```python +nemo_rl.data.datasets.preference_datasets.helpsteer3.HelpSteer3Dataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset.mdx new file mode 100644 index 0000000..0264eec --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset.mdx @@ -0,0 +1,77 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/preference_dataset +title: nemo_rl.data.datasets.preference_datasets.preference_dataset +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`PreferenceDataset`](#nemo_rl-data-datasets-preference_datasets-preference_dataset-PreferenceDataset) | Dataset class for preference data which can be loaded from a JSON file. | + +### API + + + + + +```python +class nemo_rl.data.datasets.preference_datasets.preference_dataset.PreferenceDataset( + data_path: str, + subset: typing.Optional[str] = None, + split: typing.Optional[str] = None, + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Dataset class for preference data which can be loaded from a JSON file. + +This class handles loading of preference data for DPO and RM training. +The input JSONL files should contain valid JSON objects formatted like this: +{ + "context": list[dict], # The prompt message (including previous turns, if any) + "completions": [ # The list of completions + { + "rank": 0, # The rank of the completion (lower rank is preferred) + "completion": list[dict], # The completion message(s) + }, + { + "rank": 1, # The rank of the completion (lower rank is preferred) + "completion": list[dict], # The completion message(s) + }, + ... # More completions can be added if needed + ] +} +Please refer to https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/dpo.md#datasets for more details. + +**Parameters:** + + +Path to the dataset JSON file + + + +Optional subset name for the dataset, used for HuggingFace datasets + + + +Optional split name for the dataset, used for HuggingFace datasets + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3.mdx new file mode 100644 index 0000000..0a7c89c --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3.mdx @@ -0,0 +1,59 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/preference_datasets/tulu3 +title: nemo_rl.data.datasets.preference_datasets.tulu3 +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`Tulu3PreferenceDataset`](#nemo_rl-data-datasets-preference_datasets-tulu3-Tulu3PreferenceDataset) | Tulu3 preference dataset for DPO training. | + +### API + + + + + +```python +class nemo_rl.data.datasets.preference_datasets.tulu3.Tulu3PreferenceDataset( + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Tulu3 preference dataset for DPO training. + + + + + + + + + + + +```python +nemo_rl.data.datasets.preference_datasets.tulu3.Tulu3PreferenceDataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/processed_dataset.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/processed_dataset.mdx new file mode 100644 index 0000000..130991c --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/processed_dataset.mdx @@ -0,0 +1,135 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/processed_dataset +title: nemo_rl.data.datasets.processed_dataset +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AllTaskProcessedDataset`](#nemo_rl-data-datasets-processed_dataset-AllTaskProcessedDataset) | Dataset for processing single or multi-task data with task-specific tokenization and processing. | + +### Data + +[`TokenizerType`](#nemo_rl-data-datasets-processed_dataset-TokenizerType) + +### API + + + + + +```python +class nemo_rl.data.datasets.processed_dataset.AllTaskProcessedDataset( + dataset: datasets.Dataset | typing.Any, + tokenizer: nemo_rl.data.datasets.processed_dataset.TokenizerType, + default_task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + task_data_processors: dict[str, tuple[nemo_rl.data.interfaces.TaskDataSpec, nemo_rl.data.interfaces.TaskDataProcessFnCallable]] | nemo_rl.data.interfaces.TaskDataProcessFnCallable, + max_seq_length: typing.Optional[int] = None +) +``` + + + + + + +Dataset for processing single or multi-task data with task-specific tokenization and processing. + +**Parameters:** + + +Input dataset containing raw data + + + +Tokenizer for text processing + + + +Default task processing specifications. +In the case of single-task, this is the spec used for processing all entries. +In the case of multi-task, any values not specified in the task-specific specs will be taken from the default spec. + + + +Either a single TaskDataProcessFnCallable for single-task, +or a dict mapping task names to (TaskDataSpec, TaskDataProcessFnCallable) for multi-task + + + +Maximum sequence length for tokenized outputs + + + + + + + +```python +nemo_rl.data.datasets.processed_dataset.AllTaskProcessedDataset.__getitem__( + idx: int +) -> nemo_rl.data.interfaces.DatumSpec +``` + + + + + + +Return a single prompt. + + + + + + + +```python +nemo_rl.data.datasets.processed_dataset.AllTaskProcessedDataset.__len__() -> int +``` + + + + + + + + + + + + +```python +nemo_rl.data.datasets.processed_dataset.AllTaskProcessedDataset.encode_single( + text: typing.Union[str, list[str]] +) -> tuple[list[int] | torch.Tensor, int] +``` + + + + + + +Takes either a single string or a list of strings that represent multiple turns for the same conversation. + +Returns a single (concatenated) list of tokenized ids and the length of the tokenized ids. + + + + + + + + + +```python +nemo_rl.data.datasets.processed_dataset.TokenizerType = Union[PreTrainedTokenizerBase, AutoProcessor] +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/raw_dataset.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/raw_dataset.mdx new file mode 100644 index 0000000..af7a37b --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/raw_dataset.mdx @@ -0,0 +1,94 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/raw_dataset +title: nemo_rl.data.datasets.raw_dataset +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`RawDataset`](#nemo_rl-data-datasets-raw_dataset-RawDataset) | - | + +### API + + + + + +```python +class nemo_rl.data.datasets.raw_dataset.RawDataset() +``` + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.raw_dataset.RawDataset.set_processor() +``` + + + + + + + + + + + + +```python +nemo_rl.data.datasets.raw_dataset.RawDataset.set_task_spec( + data_config: nemo_rl.data.ResponseDatasetConfig | nemo_rl.data.PreferenceDatasetConfig +) +``` + + + + + + + + + + + + +```python +nemo_rl.data.datasets.raw_dataset.RawDataset.split_train_validation( + test_size: float, + seed: int +) +``` + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets.mdx new file mode 100644 index 0000000..3892488 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets.mdx @@ -0,0 +1,82 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets +title: nemo_rl.data.datasets.response_datasets +--- + +## Submodules + +- **[`nemo_rl.data.datasets.response_datasets.aime24`](/nemo-rl/nemo_rl/data/datasets/response_datasets/aime24)** +- **[`nemo_rl.data.datasets.response_datasets.clevr`](/nemo-rl/nemo_rl/data/datasets/response_datasets/clevr)** +- **[`nemo_rl.data.datasets.response_datasets.dapo_math`](/nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math)** +- **[`nemo_rl.data.datasets.response_datasets.deepscaler`](/nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler)** +- **[`nemo_rl.data.datasets.response_datasets.geometry3k`](/nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k)** +- **[`nemo_rl.data.datasets.response_datasets.helpsteer3`](/nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3)** +- **[`nemo_rl.data.datasets.response_datasets.nemogym_dataset`](/nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset)** +- **[`nemo_rl.data.datasets.response_datasets.oai_format_dataset`](/nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset)** +- **[`nemo_rl.data.datasets.response_datasets.oasst`](/nemo-rl/nemo_rl/data/datasets/response_datasets/oasst)** +- **[`nemo_rl.data.datasets.response_datasets.openmathinstruct2`](/nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2)** +- **[`nemo_rl.data.datasets.response_datasets.refcoco`](/nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco)** +- **[`nemo_rl.data.datasets.response_datasets.response_dataset`](/nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset)** +- **[`nemo_rl.data.datasets.response_datasets.squad`](/nemo-rl/nemo_rl/data/datasets/response_datasets/squad)** +- **[`nemo_rl.data.datasets.response_datasets.tulu3`](/nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`load_response_dataset`](#nemo_rl-data-datasets-response_datasets-load_response_dataset) | Loads response dataset. | + +### Data + +[`DATASET_REGISTRY`](#nemo_rl-data-datasets-response_datasets-DATASET_REGISTRY) + +[`__all__`](#nemo_rl-data-datasets-response_datasets-__all__) + +### API + + + + + +```python +nemo_rl.data.datasets.response_datasets.load_response_dataset( + data_config: nemo_rl.data.ResponseDatasetConfig +) +``` + + + + + + +Loads response dataset. + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.DATASET_REGISTRY = {'AIME2024': AIME2024Dataset, 'clevr-cogent': CLEVRCoGenTDataset, 'DAPOMath17K':... +``` + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.__all__ = ['AIME2024Dataset', 'CLEVRCoGenTDataset', 'DAPOMath17KDataset', 'DAPOMathAIME202... +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/aime24.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/aime24.mdx new file mode 100644 index 0000000..334fee4 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/aime24.mdx @@ -0,0 +1,66 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/aime24 +title: nemo_rl.data.datasets.response_datasets.aime24 +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AIME2024Dataset`](#nemo_rl-data-datasets-response_datasets-aime24-AIME2024Dataset) | Simple wrapper around the AIME2024 dataset with train split. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.aime24.AIME2024Dataset( + repeat: int = 16, + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the AIME2024 dataset with train split. + +**Parameters:** + + +Number of times to repeat the dataset, default is 16 + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.aime24.AIME2024Dataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/clevr.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/clevr.mdx new file mode 100644 index 0000000..2bf7236 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/clevr.mdx @@ -0,0 +1,97 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/clevr +title: nemo_rl.data.datasets.response_datasets.clevr +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`CLEVRCoGenTDataset`](#nemo_rl-data-datasets-response_datasets-clevr-CLEVRCoGenTDataset) | Simple wrapper around the CLEVR-CoGenT dataset. | + +### Functions + +| Name | Description | +|------|-------------| +| [`format_answer_fromtags`](#nemo_rl-data-datasets-response_datasets-clevr-format_answer_fromtags) | Extract content between <answer> tags and strip whitespace. | +| [`format_clevr_cogent_dataset`](#nemo_rl-data-datasets-response_datasets-clevr-format_clevr_cogent_dataset) | Format the CLEVR-CoGenT dataset into an OpenAI-API-like message log. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.clevr.CLEVRCoGenTDataset( + split: str = 'train', + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the CLEVR-CoGenT dataset. + +**Parameters:** + + +Split name for the dataset, default is "train" + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.clevr.format_answer_fromtags( + answer: str +) -> str +``` + + + + + + +Extract content between <answer> tags and strip whitespace. + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.clevr.format_clevr_cogent_dataset( + example: dict[str, typing.Any], + return_pil: bool = False +) -> dict[str, typing.Any] +``` + + + + + + +Format the CLEVR-CoGenT dataset into an OpenAI-API-like message log. + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math.mdx new file mode 100644 index 0000000..6866067 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math.mdx @@ -0,0 +1,84 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/dapo_math +title: nemo_rl.data.datasets.response_datasets.dapo_math +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`DAPOMath17KDataset`](#nemo_rl-data-datasets-response_datasets-dapo_math-DAPOMath17KDataset) | Simple wrapper around the DAPO Math 17K dataset with train split. | +| [`DAPOMathAIME2024Dataset`](#nemo_rl-data-datasets-response_datasets-dapo_math-DAPOMathAIME2024Dataset) | - | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.dapo_math.DAPOMath17KDataset( + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the DAPO Math 17K dataset with train split. + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.dapo_math.DAPOMath17KDataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + + + + + + +```python +class nemo_rl.data.datasets.response_datasets.dapo_math.DAPOMathAIME2024Dataset( + kwargs = {} +) +``` + + + + + + +**Bases:** [DAPOMath17KDataset](#nemo_rl-data-datasets-response_datasets-dapo_math-DAPOMath17KDataset) + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler.mdx new file mode 100644 index 0000000..e1a6e7d --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler.mdx @@ -0,0 +1,59 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/deepscaler +title: nemo_rl.data.datasets.response_datasets.deepscaler +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`DeepScalerDataset`](#nemo_rl-data-datasets-response_datasets-deepscaler-DeepScalerDataset) | Simple wrapper around the DeepScaler dataset with train split. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.deepscaler.DeepScalerDataset( + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the DeepScaler dataset with train split. + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.deepscaler.DeepScalerDataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k.mdx new file mode 100644 index 0000000..a4be5a2 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k.mdx @@ -0,0 +1,76 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/geometry3k +title: nemo_rl.data.datasets.response_datasets.geometry3k +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`Geometry3KDataset`](#nemo_rl-data-datasets-response_datasets-geometry3k-Geometry3KDataset) | Simple wrapper around the Geometry3K dataset. | + +### Functions + +| Name | Description | +|------|-------------| +| [`format_geometry3k_dataset`](#nemo_rl-data-datasets-response_datasets-geometry3k-format_geometry3k_dataset) | Format the Geometry3K dataset into an OpenAI-API-like message log. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.geometry3k.Geometry3KDataset( + split: str = 'train', + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the Geometry3K dataset. + +**Parameters:** + + +Split name for the dataset, default is "train" + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.geometry3k.format_geometry3k_dataset( + example: dict[str, typing.Any], + return_pil: bool = False +) -> dict[str, typing.Any] +``` + + + + + + +Format the Geometry3K dataset into an OpenAI-API-like message log. + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3.mdx new file mode 100644 index 0000000..7176bf4 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3.mdx @@ -0,0 +1,66 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/helpsteer3 +title: nemo_rl.data.datasets.response_datasets.helpsteer3 +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`HelpSteer3Dataset`](#nemo_rl-data-datasets-response_datasets-helpsteer3-HelpSteer3Dataset) | Simple wrapper around the HelpSteer3 dataset with preference subset. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.helpsteer3.HelpSteer3Dataset( + split: str = 'train', + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the HelpSteer3 dataset with preference subset. + +**Parameters:** + + +Split name for the dataset, default is "train" + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.helpsteer3.HelpSteer3Dataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset.mdx new file mode 100644 index 0000000..54915b0 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset.mdx @@ -0,0 +1,54 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/nemogym_dataset +title: nemo_rl.data.datasets.response_datasets.nemogym_dataset +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`NemoGymDataset`](#nemo_rl-data-datasets-response_datasets-nemogym_dataset-NemoGymDataset) | Simple wrapper around the Nemo Gym dataset. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.nemogym_dataset.NemoGymDataset( + data_path: str, + repeat: int = 1, + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the Nemo Gym dataset. + +**Parameters:** + + +Path to the dataset JSONL file + + + +Number of times to repeat the dataset, default is 1 + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset.mdx new file mode 100644 index 0000000..f1c75e2 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset.mdx @@ -0,0 +1,214 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/oai_format_dataset +title: nemo_rl.data.datasets.response_datasets.oai_format_dataset +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`OpenAIFormatDataset`](#nemo_rl-data-datasets-response_datasets-oai_format_dataset-OpenAIFormatDataset) | This class is used to load an SFT dataset in the OpenAI format. | +| [`PreservingDataset`](#nemo_rl-data-datasets-response_datasets-oai_format_dataset-PreservingDataset) | A dataset wrapper that preserves original dict structure without None-filling. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.oai_format_dataset.OpenAIFormatDataset( + data_path: str, + chat_key: str = 'messages', + system_key: str | None = None, + system_prompt: str | None = None, + tool_key: str | None = 'tools', + use_preserving_dataset: bool = False, + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +This class is used to load an SFT dataset in the OpenAI format. + +The dataset should be in the following format: +{ + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + {"role": "assistant", "content": "The capital of France is Paris."} + ] +} +Please refer to https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#openai-format-datasets-with-tool-calling-support for more details. + +**Parameters:** + + +Path to the dataset JSON file + + + +Key for the messages list in the dataset (default: "messages") + + + +Optional key for system prompt in the dataset + + + +Optional system prompt to add if not in the dataset + + + +Key for tools in the dataset (default: "tools") + + + +If True, uses PreservingDataset to maintain +heterogeneous schemas (e.g., for tool calls with varying argument +structures). If False, uses standard HuggingFace dataset loading. +Default is False for backward compatibility. + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.oai_format_dataset.OpenAIFormatDataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + + + + + + +```python +class nemo_rl.data.datasets.response_datasets.oai_format_dataset.PreservingDataset( + data: list[dict[str, typing.Any]] +) +``` + + + + + + +A dataset wrapper that preserves original dict structure without None-filling. + +Unlike HuggingFace's Dataset class which enforces schema uniformity across all samples +(filling missing keys with None), this class maintains the exact structure of each sample. +This is critical for heterogeneous data like tool calls where different samples may have +different argument structures. + + + + + + +```python +nemo_rl.data.datasets.response_datasets.oai_format_dataset.PreservingDataset.__getitem__( + idx: typing.Union[int, slice, list] +) -> typing.Union[dict[str, typing.Any], list[dict[str, typing.Any]]] +``` + + + + + + +Support integer indexing, slicing, and list indexing. + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.oai_format_dataset.PreservingDataset.__iter__() +``` + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.oai_format_dataset.PreservingDataset.__len__() -> int +``` + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.oai_format_dataset.PreservingDataset.map( + function: typing.Callable, + args = (), + kwargs = {} +) -> nemo_rl.data.datasets.response_datasets.oai_format_dataset.PreservingDataset +``` + + + + + + +Apply a function to each sample in the dataset. + +**Parameters:** + + +Function to apply to each sample + + + +If True, pass index as second argument to function + + +**Returns:** `PreservingDataset` + +New PreservingDataset with transformed samples + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oasst.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oasst.mdx new file mode 100644 index 0000000..8a92408 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/oasst.mdx @@ -0,0 +1,127 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/oasst +title: nemo_rl.data.datasets.response_datasets.oasst +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`OasstDataset`](#nemo_rl-data-datasets-response_datasets-oasst-OasstDataset) | Simple wrapper around the OASST dataset. | + +### Functions + +| Name | Description | +|------|-------------| +| [`get_data_records`](#nemo_rl-data-datasets-response_datasets-oasst-get_data_records) | - | +| [`parse_conversations`](#nemo_rl-data-datasets-response_datasets-oasst-parse_conversations) | Recusive function that returns all the sub converstaions in a list starting from node tree_obj. | + +### Data + +[`SYSTEM_PROMPT`](#nemo_rl-data-datasets-response_datasets-oasst-SYSTEM_PROMPT) + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.oasst.OasstDataset( + split_validation_size: float = 0.05, + seed: int = 42, + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the OASST dataset. + +**Parameters:** + + +Size of the validation data, default is 0.05 + + + +Seed for train/validation split when split_validation_size > 0, default is 42 + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.oasst.get_data_records( + objs, + task_name: str = 'oasst' +) +``` + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.oasst.parse_conversations( + tree_obj, + first: bool = False +) +``` + + + + + + +Recusive function that returns all the sub converstaions in a list starting from node tree_obj. + +**Parameters:** + + +current conversation node + + +**Returns:** + +a list of sub conversation threads including the current conversation node + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.oasst.SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. The ass... +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2.mdx new file mode 100644 index 0000000..6124a92 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2.mdx @@ -0,0 +1,84 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/openmathinstruct2 +title: nemo_rl.data.datasets.response_datasets.openmathinstruct2 +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`OpenMathInstruct2Dataset`](#nemo_rl-data-datasets-response_datasets-openmathinstruct2-OpenMathInstruct2Dataset) | Simple wrapper around the OpenMathInstruct2 dataset. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.openmathinstruct2.OpenMathInstruct2Dataset( + output_key: str = 'expected_answer', + split: str = 'train_1M', + split_validation_size: float = 0.05, + seed: int = 42, + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the OpenMathInstruct2 dataset. + +**Parameters:** + + +Key for the output text, default is "expected_answer" + + + +Split name for the dataset, default is "train_1M" + + + +Size of the validation data, default is 0.05 + + + +Seed for train/validation split when split_validation_size > 0, default is 42 + + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.openmathinstruct2.OpenMathInstruct2Dataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco.mdx new file mode 100644 index 0000000..34e0b8d --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco.mdx @@ -0,0 +1,160 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/refcoco +title: nemo_rl.data.datasets.response_datasets.refcoco +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`RefCOCODataset`](#nemo_rl-data-datasets-response_datasets-refcoco-RefCOCODataset) | Simple wrapper around the RefCOCO dataset. | + +### Functions + +| Name | Description | +|------|-------------| +| [`download_and_unzip`](#nemo_rl-data-datasets-response_datasets-refcoco-download_and_unzip) | Downloads a zip file from a given URL to a target directory and unzips it into a specified subdirectory within the target directory, showing download progress. | +| [`format_refcoco_dataset`](#nemo_rl-data-datasets-response_datasets-refcoco-format_refcoco_dataset) | Format the RefCOCO dataset from huggingface. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.refcoco.RefCOCODataset( + split: str = 'train', + download_dir: str = './coco_images', + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the RefCOCO dataset. + +**Parameters:** + + +Split name for the dataset, default is "train" + + + +Directory to download the dataset to, default is "./coco_images" + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.refcoco.RefCOCODataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.refcoco.download_and_unzip( + url: str, + target_directory: str, + subdir_name: str = '.' +) +``` + + + + + + +Downloads a zip file from a given URL to a target directory and unzips it into a specified subdirectory within the target directory, showing download progress. + +**Parameters:** + + +The URL of the zip file to download. + + + +The directory where the zip file will be downloaded + and unzipped. + + + +The name of the subdirectory within the target_directory + where the contents of the zip file will be unzipped. + Defaults to "train". + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.refcoco.format_refcoco_dataset( + example: dict[str, typing.Any], + width: int = 256, + height: int = 256, + caption_type: str = 'random' +) -> dict[str, typing.Any] +``` + + + + + + +Format the RefCOCO dataset from huggingface. + +This should be replaced with our own curated RefCOCO/+/g dataset soon + +**Parameters:** + + +The example to format. + + + +The width of the resized image. + + + +The height of the resized image. + + + +The type of caption to use. + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset.mdx new file mode 100644 index 0000000..09bbf07 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset.mdx @@ -0,0 +1,104 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/response_dataset +title: nemo_rl.data.datasets.response_datasets.response_dataset +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ResponseDataset`](#nemo_rl-data-datasets-response_datasets-response_dataset-ResponseDataset) | Dataset class for response data which can be loaded from a JSON file. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.response_dataset.ResponseDataset( + data_path: str, + input_key: str = 'input', + output_key: str = 'output', + subset: typing.Optional[str] = None, + split: typing.Optional[str] = None, + split_validation_size: float = 0, + seed: int = 42, + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Dataset class for response data which can be loaded from a JSON file. + +This class handles loading of response data for SFT and RL training. +The input JSONL files should contain valid JSON objects formatted like this: +{ + input_key: str, # The input prompt/context + output_key: str, # The output response/answer +} +Please refer to https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets for more details. + +**Parameters:** + + +Path to the dataset JSON file + + + +Key for the input text, default is "input" + + + +Key for the output text, default is "output" + + + +Optional subset name for the dataset, used for HuggingFace datasets + + + +Optional split name for the dataset, used for HuggingFace datasets + + + +Size of the validation data, default is 0 + + + +Seed for train/validation split when split_validation_size > 0, default is 42 + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.response_dataset.ResponseDataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/squad.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/squad.mdx new file mode 100644 index 0000000..f201b41 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/squad.mdx @@ -0,0 +1,66 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/squad +title: nemo_rl.data.datasets.response_datasets.squad +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`SquadDataset`](#nemo_rl-data-datasets-response_datasets-squad-SquadDataset) | Simple wrapper around the squad dataset. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.squad.SquadDataset( + split: str = 'train', + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the squad dataset. + +**Parameters:** + + +Split name for the dataset, default is "train" + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.squad.SquadDataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3.mdx new file mode 100644 index 0000000..68dfa11 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3.mdx @@ -0,0 +1,76 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/response_datasets/tulu3 +title: nemo_rl.data.datasets.response_datasets.tulu3 +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`Tulu3SftMixtureDataset`](#nemo_rl-data-datasets-response_datasets-tulu3-Tulu3SftMixtureDataset) | Simple wrapper around the Tulu3 SFT mixture dataset with train split. | + +### API + + + + + +```python +class nemo_rl.data.datasets.response_datasets.tulu3.Tulu3SftMixtureDataset( + split_validation_size: float = 0.05, + seed: int = 42, + max_samples: int | None = None, + kwargs = {} +) +``` + + + + + + +**Bases:** [RawDataset](/nemo-rl/nemo_rl/data/datasets/raw_dataset#nemo_rl-data-datasets-raw_dataset-RawDataset) + +Simple wrapper around the Tulu3 SFT mixture dataset with train split. + +**Parameters:** + + +Size of the validation data, default is 0.05 + + + +Seed for train/validation split when split_validation_size > 0, default is 42 + + + +Optional maximum number of samples to use from the dataset + + + + + + + + + + + + +```python +nemo_rl.data.datasets.response_datasets.tulu3.Tulu3SftMixtureDataset.format_data( + data: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/utils.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/utils.mdx new file mode 100644 index 0000000..d5a02db --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/datasets/utils.mdx @@ -0,0 +1,191 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/datasets/utils +title: nemo_rl.data.datasets.utils +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`assert_no_double_bos`](#nemo_rl-data-datasets-utils-assert_no_double_bos) | Assert that there are no double starting BOS tokens in the message. | +| [`extract_necessary_env_names`](#nemo_rl-data-datasets-utils-extract_necessary_env_names) | Extract the necessary environment names from the data config. | +| [`load_dataset_from_path`](#nemo_rl-data-datasets-utils-load_dataset_from_path) | Load a dataset from a local file, huggingface dataset, or Arrow dataset (saved with save_to_disk). | +| [`pil_to_base64`](#nemo_rl-data-datasets-utils-pil_to_base64) | Converts a PIL Image object to a base64 encoded string. | +| [`update_single_dataset_config`](#nemo_rl-data-datasets-utils-update_single_dataset_config) | Fill the single dataset config with default dataset config. | + +### Data + +[`TokenizerType`](#nemo_rl-data-datasets-utils-TokenizerType) + +### API + + + + + +```python +nemo_rl.data.datasets.utils.assert_no_double_bos( + token_ids: torch.Tensor, + tokenizer: nemo_rl.data.datasets.utils.TokenizerType +) -> None +``` + + + + + + +Assert that there are no double starting BOS tokens in the message. + +**Parameters:** + + +List of token IDs + + + +Tokenizer + + + + + + + + + +```python +nemo_rl.data.datasets.utils.extract_necessary_env_names( + data_config: dict +) -> list[str] +``` + + + + + + +Extract the necessary environment names from the data config. + +Some environments are set in env_configs but not used in the data config. +This function extracts the necessary environment names from the data config. + +**Parameters:** + + +The data config. + + +**Returns:** `list[str]` + +The necessary environment names. + + + + + + + + +```python +nemo_rl.data.datasets.utils.load_dataset_from_path( + data_path: str, + data_subset: typing.Optional[str] = None, + data_split: typing.Optional[str] = 'train' +) +``` + + + + + + +Load a dataset from a local file, huggingface dataset, or Arrow dataset (saved with save_to_disk). + +**Parameters:** + + +The path to the dataset. + + + +The subset to load from the dataset. Only supported for huggingface datasets. + + + +The split to load from the dataset. + + + + + + + + + +```python +nemo_rl.data.datasets.utils.pil_to_base64( + image: PIL.Image.Image, + format: str = 'PNG' +) -> str +``` + + + + + + +Converts a PIL Image object to a base64 encoded string. + +**Parameters:** + + +The PIL Image object to convert. + + + +The image format (e.g., "PNG", "JPEG"). Defaults to "PNG". + + +**Returns:** `str` + +A base64 encoded string representation of the image. + + + + + + + + +```python +nemo_rl.data.datasets.utils.update_single_dataset_config( + data_config: dict, + default_data_config: dict +) -> None +``` + + + + + + +Fill the single dataset config with default dataset config. + + + + + + + + +```python +nemo_rl.data.datasets.utils.TokenizerType = Union[PreTrainedTokenizerBase, AutoProcessor] +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/interfaces.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/interfaces.mdx new file mode 100644 index 0000000..b435173 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/interfaces.mdx @@ -0,0 +1,284 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/interfaces +title: nemo_rl.data.interfaces +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`DatumSpec`](#nemo_rl-data-interfaces-DatumSpec) | - | +| [`PreferenceDatumSpec`](#nemo_rl-data-interfaces-PreferenceDatumSpec) | - | +| [`TaskDataProcessFnCallable`](#nemo_rl-data-interfaces-TaskDataProcessFnCallable) | A callable that processes a loaded datum dictionary into a DatumSpec. | +| [`TaskDataSpec`](#nemo_rl-data-interfaces-TaskDataSpec) | - | + +### Data + +[`FlatMessagesType`](#nemo_rl-data-interfaces-FlatMessagesType) + +[`LLMMessageLogType`](#nemo_rl-data-interfaces-LLMMessageLogType) + +[`PathLike`](#nemo_rl-data-interfaces-PathLike) + +[`TokenizerType`](#nemo_rl-data-interfaces-TokenizerType) + +[`VLMMessageLogType`](#nemo_rl-data-interfaces-VLMMessageLogType) + +### API + + + + + +```python +class nemo_rl.data.interfaces.DatumSpec +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.interfaces.PreferenceDatumSpec +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.interfaces.TaskDataProcessFnCallable() +``` + + + + + + +Protocol + +A callable that processes a loaded datum dictionary into a DatumSpec. + + + + + + +```python +nemo_rl.data.interfaces.TaskDataProcessFnCallable.__call__( + datum_dict: dict[str, typing.Any], + task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + tokenizer: nemo_rl.data.interfaces.TokenizerType, + max_seq_length: int | None, + idx: int +) -> nemo_rl.data.interfaces.DatumSpec +``` + + + + + + + + + + + + + + +```python +class nemo_rl.data.interfaces.TaskDataSpec( + task_name: typing.Optional[str] = None, + prompt_file: typing.Optional[nemo_rl.data.interfaces.PathLike] = None, + system_prompt_file: typing.Optional[nemo_rl.data.interfaces.PathLike] = None +) +``` + + + + + + +Dataclass + + + + + + + + + + + + + +```python +nemo_rl.data.interfaces.TaskDataSpec.__post_init__() -> None +``` + + + + + + + + + + + + +```python +nemo_rl.data.interfaces.TaskDataSpec.copy_defaults( + from_spec: nemo_rl.data.interfaces.TaskDataSpec +) -> None +``` + + + + + + +Apply default values from another Task instance for any None attributes. + + + + + + + + + +```python +nemo_rl.data.interfaces.FlatMessagesType = dict[str, Union[list[str], torch.Tensor]] +``` + + + + + + + + + +```python +nemo_rl.data.interfaces.LLMMessageLogType = list[dict[str, Union[str, torch.Tensor]]] +``` + + + + + + + + + +```python +nemo_rl.data.interfaces.PathLike = Union[str, 'os.PathLike[Any]'] +``` + + + + + + + + + +```python +nemo_rl.data.interfaces.TokenizerType = PreTrainedTokenizerBase +``` + + + + + + + + + +```python +nemo_rl.data.interfaces.VLMMessageLogType = list[dict[str, Union[str, torch.Tensor, PackedTensor]]] +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/llm_message_utils.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/llm_message_utils.mdx new file mode 100644 index 0000000..49a124a --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/llm_message_utils.mdx @@ -0,0 +1,548 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/llm_message_utils +title: nemo_rl.data.llm_message_utils +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_pad_tensor`](#nemo_rl-data-llm_message_utils-_pad_tensor) | Pad a tensor to the specified length. | +| [`_validate_tensor_consistency`](#nemo_rl-data-llm_message_utils-_validate_tensor_consistency) | Validate that all tensors have consistent dtypes and devices. | +| [`add_loss_mask_to_message_log`](#nemo_rl-data-llm_message_utils-add_loss_mask_to_message_log) | Add token-level loss masks to each message in a message log. | +| [`batched_message_log_to_flat_message`](#nemo_rl-data-llm_message_utils-batched_message_log_to_flat_message) | Process and pad a batch of message logs for model input. | +| [`get_first_index_that_differs`](#nemo_rl-data-llm_message_utils-get_first_index_that_differs) | Get the first index that differs between two strings. | +| [`get_formatted_message_log`](#nemo_rl-data-llm_message_utils-get_formatted_message_log) | Format and tokenize chat messages using the specified template. | +| [`get_images_from_message`](#nemo_rl-data-llm_message_utils-get_images_from_message) | Get all images from a message log item. | +| [`get_keys_from_message_log`](#nemo_rl-data-llm_message_utils-get_keys_from_message_log) | Return a new LLMMessageLogType containing only the specified keys from each message. | +| [`message_log_shape`](#nemo_rl-data-llm_message_utils-message_log_shape) | Get the shape of the tensors in the message log. | +| [`message_log_to_flat_messages`](#nemo_rl-data-llm_message_utils-message_log_to_flat_messages) | Converts a message log (sequence of message turns) into a flattened representation. | +| [`remap_dataset_keys`](#nemo_rl-data-llm_message_utils-remap_dataset_keys) | Remap dataset keys as per mapping. | + +### Data + +[`Tensor`](#nemo_rl-data-llm_message_utils-Tensor) + +[`TokenizerType`](#nemo_rl-data-llm_message_utils-TokenizerType) + +### API + + + + + +```python +nemo_rl.data.llm_message_utils._pad_tensor( + tensor: nemo_rl.data.llm_message_utils.Tensor, + max_len: int, + pad_side: str, + pad_value: int = 0 +) -> nemo_rl.data.llm_message_utils.Tensor +``` + + + + + + +Pad a tensor to the specified length. + +**Parameters:** + + +Tensor to pad + + + +Length to pad to + + + +Whether to pad on the 'left' or 'right' + + + +Value to use for padding + + +**Returns:** `Tensor` + +torch.Tensor: Padded tensor + + + + + + + + +```python +nemo_rl.data.llm_message_utils._validate_tensor_consistency( + tensors: list[nemo_rl.data.llm_message_utils.Tensor] +) -> None +``` + + + + + + +Validate that all tensors have consistent dtypes and devices. + +**Parameters:** + + +List of tensors to validate + + +**Raises:** + +- `RuntimeError`: If tensors have different dtypes or devices + + + + + + + + +```python +nemo_rl.data.llm_message_utils.add_loss_mask_to_message_log( + batch_message_log: list[nemo_rl.data.interfaces.LLMMessageLogType], + roles_to_train_on: list[str] = ['assistant'], + only_unmask_final: bool = False +) -> None +``` + + + + + + +Add token-level loss masks to each message in a message log. + +**Parameters:** + + +List of message dictionaries containing token IDs and metadata + + + +List of strings indicating which speakers to unmask. Default: ["assistant"] + + + +If True, only unmask the final message in the log. Default: False + + + + + + + + + +```python +nemo_rl.data.llm_message_utils.batched_message_log_to_flat_message( + message_log_batch: list[nemo_rl.data.interfaces.LLMMessageLogType], + pad_value_dict: typing.Optional[dict[str, int]] = None, + make_sequence_length_divisible_by: int = 1 +) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.FlatMessagesType], nemo_rl.data.llm_message_utils.Tensor] +``` + + + + + + +Process and pad a batch of message logs for model input. + +For each message log in the batch: +1. Converts it to a flat representation using message_log_to_flat_messages +2. Pads all resulting tensors to the same length for batching +3. Returns a BatchedDataDict and sequence lengths tensor + +Padding is always applied to the right side of sequences. + +Examples: + + +```python +>>> import torch +>>> from nemo_rl.data.llm_message_utils import batched_message_log_to_flat_message +>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict +>>> # Create a batch of two message logs with different lengths +>>> message_log_batch = [ +... # First conversation +... [ +... {'role': 'user', 'content': 'What is 2+2?', 'token_ids': torch.tensor([1, 2, 3, 4, 5])}, +... {'role': 'assistant', 'content': '4', 'token_ids': torch.tensor([6, 7])} +... ], +... # Second conversation +... [ +... {'role': 'user', 'content': 'Solve x+10=15', 'token_ids': torch.tensor([1, 8, 9, 10, 11, 12])}, +... {'role': 'assistant', 'content': 'x=5', 'token_ids': torch.tensor([13, 14, 15])} +... ] +... ] +>>> pad_value_dict = {'token_ids': 0} +>>> batched_flat, input_lengths = batched_message_log_to_flat_message(message_log_batch, pad_value_dict) +>>> batched_flat['token_ids'][0].tolist() +[1, 2, 3, 4, 5, 6, 7, 0, 0] +>>> batched_flat['token_ids'][1].tolist() +[1, 8, 9, 10, 11, 12, 13, 14, 15] +>>> batched_flat['content'][0] +['What is 2+2?', '4'] +>>> batched_flat['content'][1] +['Solve x+10=15', 'x=5'] +>>> batched_flat['role'] +[['user', 'assistant'], ['user', 'assistant']] +>>> input_lengths +tensor([7, 9], dtype=torch.int32) +>>> +>>> # Multimodal example: include images on both conversations and verify packing +>>> from nemo_rl.data.multimodal_utils import PackedTensor +>>> mm_batch = [ +... [ +... {'role': 'user', 'content': 'look', 'token_ids': torch.tensor([1, 2, 3]), 'images': PackedTensor(torch.randn(2, 3, 4, 4), dim_to_pack=0)}, +... {'role': 'assistant', 'content': 'ok', 'token_ids': torch.tensor([4])} +... ], +... [ +... {'role': 'user', 'content': 'again', 'token_ids': torch.tensor([5, 6]), 'images': PackedTensor(torch.randn(1, 3, 4, 4), dim_to_pack=0)}, +... {'role': 'assistant', 'content': 'fine', 'token_ids': torch.tensor([7, 8])} +... ] +... ] +>>> mm_flat, mm_lengths = batched_message_log_to_flat_message(mm_batch, pad_value_dict={'token_ids': 0}) +>>> isinstance(mm_flat['images'], PackedTensor) +True +>>> tuple(mm_flat['images'].as_tensor().shape) # 2 + 1 images +(3, 3, 4, 4) +>>> mm_lengths +tensor([4, 4], dtype=torch.int32) +>>> +``` + + + +**Parameters:** + + +List of LLMMessageLogType (each a conversation with multiple turns) + + + +Dictionary mapping keys to padding values (default is 0) + + + +forces the data to be divisible by this value + + +**Returns:** `BatchedDataDict[FlatMessagesType]` + +BatchedDataDict[FlatMessagesType]: Dictionary containing padded stacked tensors + +**Raises:** + +- `RuntimeError`: If tensors have different dtypes or devices + + + + + + + + +```python +nemo_rl.data.llm_message_utils.get_first_index_that_differs( + str1: str, + str2: str +) -> int +``` + + + + + + +Get the first index that differs between two strings. + + + + + + + + +```python +nemo_rl.data.llm_message_utils.get_formatted_message_log( + message_log: nemo_rl.data.interfaces.LLMMessageLogType, + tokenizer: nemo_rl.data.llm_message_utils.TokenizerType, + task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + add_bos_token: bool = True, + add_eos_token: bool = True, + add_generation_prompt: bool = False, + tools: typing.Optional[list[dict[str, typing.Any]]] = None +) -> nemo_rl.data.interfaces.LLMMessageLogType +``` + + + + + + +Format and tokenize chat messages using the specified template. + +Returns: + The message log with updated 'token_ids' and 'content' fields. + +**Parameters:** + + +List of message dicts with 'role' and 'content' keys + + + +Tokenizer for converting text to token IDs + + + +Task spec for this dataset. + + + +Whether to add bos token to first message if it is not already present. Default: True + + + +Whether to add eos token to last message if it is not already present. Default: True + + + +Whether to include assistant's generation prompt in user messages. Default: False + + + +Optional list of tool/function definitions to pass to the chat template. Default: None + + + + + + + + + +```python +nemo_rl.data.llm_message_utils.get_images_from_message( + message: dict[str, typing.Any] +) -> list[typing.Any] +``` + + + + + + +Get all images from a message log item. + + + + + + + + +```python +nemo_rl.data.llm_message_utils.get_keys_from_message_log( + message_log: nemo_rl.data.interfaces.LLMMessageLogType, + keys: list[str] +) -> nemo_rl.data.interfaces.LLMMessageLogType +``` + + + + + + +Return a new LLMMessageLogType containing only the specified keys from each message. + +**Parameters:** + + +Original message log to extract keys from + + + +List of keys to keep in each message + + +**Returns:** `LLMMessageLogType` + +New list with only specified keys + + + + + + + + +```python +nemo_rl.data.llm_message_utils.message_log_shape( + message_log: nemo_rl.data.interfaces.LLMMessageLogType +) -> list[dict[str, torch.Size]] +``` + + + + + + +Get the shape of the tensors in the message log. + +This utility function examines each message in the message log and reports +the shape of tensor values or recursively processes list values. + +**Parameters:** + + +The message log to analyze + + +**Returns:** `list[dict[str, torch.Size]]` + +List of dictionaries containing tensor shapes for each key in messages + + + + + + + + +```python +nemo_rl.data.llm_message_utils.message_log_to_flat_messages( + message_log: nemo_rl.data.interfaces.LLMMessageLogType +) -> nemo_rl.data.interfaces.FlatMessagesType +``` + + + + + + +Converts a message log (sequence of message turns) into a flattened representation. + +This function takes a message log (list of dict messages with 'role', 'content', 'token_ids', etc.) +and converts it to a flat dictionary where all tensors of the same key are concatenated and +all strings of the same key are put into lists. + +Examples: + + +```python +>>> import torch +>>> from nemo_rl.data.llm_message_utils import message_log_to_flat_messages +>>> # Create a simple message log with two messages +>>> message_log = [ +... {'role': 'user', 'content': 'Hello', 'token_ids': torch.tensor([1, 2, 3])}, +... {'role': 'assistant', 'content': 'Hi there', 'token_ids': torch.tensor([4, 5, 6, 7])} +... ] +>>> flat_msgs = message_log_to_flat_messages(message_log) +>>> flat_msgs['role'] +['user', 'assistant'] +>>> flat_msgs['content'] +['Hello', 'Hi there'] +>>> flat_msgs['token_ids'] +tensor([1, 2, 3, 4, 5, 6, 7]) +>>> +>>> # Multimodal example: +>>> from nemo_rl.data.multimodal_utils import PackedTensor +>>> img1 = torch.randn(2, 3, 4, 4) +>>> img2 = torch.randn(3, 3, 4, 4) +>>> mm_log = [ +... {'role': 'user', 'content': 'see', 'token_ids': torch.tensor([1]), 'images': PackedTensor(img1, dim_to_pack=0)}, +... {'role': 'assistant', 'content': 'ok', 'token_ids': torch.tensor([2, 3]), 'images': PackedTensor(img2, dim_to_pack=0)}, +... ] +>>> flat_mm = message_log_to_flat_messages(mm_log) +>>> tuple(flat_mm['images'].as_tensor().shape) +(5, 3, 4, 4) +>>> +``` + + + +**Parameters:** + + +List of message dictionaries with 'role', 'content', and potentially 'token_ids' + + +**Returns:** `FlatMessagesType` + +Dictionary mapping keys to concatenated tensors and string lists + + + + + + + + +```python +nemo_rl.data.llm_message_utils.remap_dataset_keys( + dataset: datasets.Dataset, + mapping_dict: dict[str, str] +) -> datasets.Dataset +``` + + + + + + +Remap dataset keys as per mapping. + +**Parameters:** + + +The input dataset to remap keys in + + + +A dictionary mapping input keys to output keys + + +**Returns:** `Dataset` + +A new dataset with remapped keys + + + + + + + + +```python +nemo_rl.data.llm_message_utils.Tensor = torch.Tensor +``` + + + + + + + + + +```python +nemo_rl.data.llm_message_utils.TokenizerType = PreTrainedTokenizerBase +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/multimodal_utils.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/multimodal_utils.mdx new file mode 100644 index 0000000..89f71d0 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/multimodal_utils.mdx @@ -0,0 +1,298 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/multimodal_utils +title: nemo_rl.data.multimodal_utils +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`PackedTensor`](#nemo_rl-data-multimodal_utils-PackedTensor) | Wrapper around a list of torch tensors and a dimension along which to pack the tensors. | + +### Functions + +| Name | Description | +|------|-------------| +| [`get_dim_to_pack_along`](#nemo_rl-data-multimodal_utils-get_dim_to_pack_along) | Special considerations for packing certain keys from certain processors. | +| [`get_multimodal_keys_from_processor`](#nemo_rl-data-multimodal_utils-get_multimodal_keys_from_processor) | Get keys of the multimodal data that can be used as model inputs. | +| [`resolve_to_image`](#nemo_rl-data-multimodal_utils-resolve_to_image) | Resolve the image path to a PIL.Image object. | + +### API + + + + + +```python +class nemo_rl.data.multimodal_utils.PackedTensor( + tensors: typing.Union[torch.Tensor, list[typing.Optional[torch.Tensor]], list[None]], + dim_to_pack: int +) +``` + + + + + + +Wrapper around a list of torch tensors and a dimension along which to pack the tensors. + +This class is used to wrap a list of tensors along with a `dim_to_pack` parameter. +It can be used for data that can be packed along different dimensions (such as multimodal data). + +`dim_to_pack` is used to specify the dimension along which to pack the tensors. + +The list of tensors can be returned as a single packed tensor by calling `as_tensor` which will concatenate the tensors along the `dim_to_pack` dimension. + + + + + + + + +```python +nemo_rl.data.multimodal_utils.PackedTensor.__len__() -> int +``` + + + + + + + + + + + + +```python +nemo_rl.data.multimodal_utils.PackedTensor.as_tensor( + device: typing.Optional[torch.device] = None +) -> typing.Optional[torch.Tensor] +``` + + + + + + + + + + + + +```python +nemo_rl.data.multimodal_utils.PackedTensor.concat( + from_packed_tensors: list[nemo_rl.data.multimodal_utils.PackedTensor] +) -> nemo_rl.data.multimodal_utils.PackedTensor +``` + + + + + + +classmethod + +Concatenate a list of PackedTensor objects into a single PackedTensor. + +The underlying tensors from the PackedTensors are combined into a single list of tensors and used to create a new PackedTensor. + +Each batch must have the same dim_to_pack. + +Example: + + +```python +>>> import torch +>>> from nemo_rl.data.multimodal_utils import PackedTensor +>>> p1 = PackedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])], dim_to_pack=0) +>>> p2 = PackedTensor([torch.tensor([7, 8, 9])], dim_to_pack=0) +>>> p3 = PackedTensor.concat([p1, p2]) +>>> p3.tensors +[tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9])] +>>> p3.as_tensor() +tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]) +>>> +``` + + + + + + + + + +```python +nemo_rl.data.multimodal_utils.PackedTensor.empty_like( + other: nemo_rl.data.multimodal_utils.PackedTensor +) -> nemo_rl.data.multimodal_utils.PackedTensor +``` + + + + + + +classmethod + +Return a new PackedTensor with same length and dim_to_pack as `other`, with all entries None. + + + + + + + +```python +nemo_rl.data.multimodal_utils.PackedTensor.flattened_concat( + from_packed_tensors: list[nemo_rl.data.multimodal_utils.PackedTensor] +) -> nemo_rl.data.multimodal_utils.PackedTensor +``` + + + + + + +classmethod + +Given a list of PackedTensor objects, flattens each PackedTensor and then concatenates them into a single PackedTensor. + +Each PackedTensor is first flattened by packing along the PackedTensor's `dim_to_pack` dimension. Then, the resulting flattened tensors are used to create a new PackedTensor. + +This is different from `PackedTensor.concat` which simply extends the underlying list of tensors. This is important because the `slice` and `__len__` methods operate on the underlying list of tensors. Note, however, that calling `as_tensor` on the resulting PackedTensor will result in the same tensor as `concat`. + +Each batch must have the same dim_to_pack. + +Example: + + +```python +>>> import torch +>>> from nemo_rl.data.multimodal_utils import PackedTensor +>>> p1 = PackedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])], dim_to_pack=0) +>>> p2 = PackedTensor([torch.tensor([7, 8, 9])], dim_to_pack=0) +>>> p3 = PackedTensor.flattened_concat([p1, p2]) +>>> p3.tensors +[tensor([1, 2, 3, 4, 5, 6]), tensor([7, 8, 9])] +>>> p3.as_tensor() +tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]) +>>> +``` + + + + + + + + + +```python +nemo_rl.data.multimodal_utils.PackedTensor.slice( + indices: typing.Union[list[int], torch.Tensor] +) -> nemo_rl.data.multimodal_utils.PackedTensor +``` + + + + + + + + + + + + +```python +nemo_rl.data.multimodal_utils.PackedTensor.to( + device: str | torch.device +) -> nemo_rl.data.multimodal_utils.PackedTensor +``` + + + + + + + + + + + + + + +```python +nemo_rl.data.multimodal_utils.get_dim_to_pack_along( + processor, + key: str +) -> int +``` + + + + + + +Special considerations for packing certain keys from certain processors. + +In most cases, the packed items are along dim 0 + + + + + + + + +```python +nemo_rl.data.multimodal_utils.get_multimodal_keys_from_processor( + processor +) -> list[str] +``` + + + + + + +Get keys of the multimodal data that can be used as model inputs. + +This will be used in the data_processor function to determine which keys to use as model inputs. + + + + + + + + +```python +nemo_rl.data.multimodal_utils.resolve_to_image( + image_path_or_image: str | PIL.Image.Image +) -> PIL.Image.Image +``` + + + + + + +Resolve the image path to a PIL.Image object. + +image_path can be either: +- path to local file +- url to image +- base64 encoded image + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/packing.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/packing.mdx new file mode 100644 index 0000000..8161181 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/packing.mdx @@ -0,0 +1,30 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/packing +title: nemo_rl.data.packing +--- + +## Submodules + +- **[`nemo_rl.data.packing.algorithms`](/nemo-rl/nemo_rl/data/packing/algorithms)** +- **[`nemo_rl.data.packing.metrics`](/nemo-rl/nemo_rl/data/packing/metrics)** + +## Package Contents + +### Data + +[`__all__`](#nemo_rl-data-packing-__all__) + +### API + + + + + +```python +nemo_rl.data.packing.__all__ = ['PackingAlgorithm', 'SequencePacker', 'ConcatenativePacker', 'FirstFitDecreasin... +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/algorithms.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/algorithms.mdx new file mode 100644 index 0000000..7337f16 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/algorithms.mdx @@ -0,0 +1,791 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/packing/algorithms +title: nemo_rl.data.packing.algorithms +--- + +Sequence packing algorithms for efficient batching of variable-length sequences. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ConcatenativePacker`](#nemo_rl-data-packing-algorithms-ConcatenativePacker) | Concatenative packing algorithm. | +| [`FirstFitDecreasingPacker`](#nemo_rl-data-packing-algorithms-FirstFitDecreasingPacker) | First-Fit Decreasing (FFD) algorithm for sequence packing. | +| [`FirstFitPacker`](#nemo_rl-data-packing-algorithms-FirstFitPacker) | Base class for First-Fit algorithms. | +| [`FirstFitShufflePacker`](#nemo_rl-data-packing-algorithms-FirstFitShufflePacker) | First-Fit Shuffle algorithm for sequence packing. | +| [`ModifiedFirstFitDecreasingPacker`](#nemo_rl-data-packing-algorithms-ModifiedFirstFitDecreasingPacker) | Modified First-Fit Decreasing (MFFD) algorithm for sequence packing. | +| [`PackingAlgorithm`](#nemo_rl-data-packing-algorithms-PackingAlgorithm) | Enum for supported sequence packing algorithms. | +| [`SequencePacker`](#nemo_rl-data-packing-algorithms-SequencePacker) | Abstract base class for sequence packing algorithms. | + +### Functions + +| Name | Description | +|------|-------------| +| [`get_packer`](#nemo_rl-data-packing-algorithms-get_packer) | Factory function to get a sequence packer based on the algorithm. | + +### API + + + + + +```python +class nemo_rl.data.packing.algorithms.ConcatenativePacker() +``` + + + + + + +**Bases:** [SequencePacker](#nemo_rl-data-packing-algorithms-SequencePacker) + +Concatenative packing algorithm. + +This algorithm simply concatenates sequences in order until reaching the bin capacity, +then starts a new bin. It doesn't try to optimize the packing in any way. + +Time complexity: O(n) where n is the number of sequences. + +Example: + + +```python +>>> examples = { +... "sequence_lengths": [4, 1, 3, 2, 1, 3, 4, 5] +... } +>>> # If packed with seq_length=5: +... {"bins": [ [0, 1], [2, 3], [4, 5], [6], [7] ]} +>>> # If packed with seq_length=8: +... {"bins": [ [0, 1, 2], [3, 4, 5], [6], [7] ]} +``` + + + + + + + + + + +```python +nemo_rl.data.packing.algorithms.ConcatenativePacker._pack_implementation( + sequence_lengths: typing.List[int] +) -> typing.List[typing.List[int]] +``` + + + + + + +Pack sequences using the Concatenative algorithm. + +**Parameters:** + + +A list of sequence lengths to pack. + + +**Returns:** `List[List[int]]` + +A list of bins, where each bin is a list of indices into the original + + + + + + + + + +```python +class nemo_rl.data.packing.algorithms.FirstFitDecreasingPacker() +``` + + + + + + +**Bases:** [FirstFitPacker](#nemo_rl-data-packing-algorithms-FirstFitPacker) + +First-Fit Decreasing (FFD) algorithm for sequence packing. + +This algorithm sorts sequences by length in descending order and then +places each sequence into the first bin where it fits. + +Time complexity: O(n log n) for sorting + O(n * m) for packing, +where n is the number of sequences and m is the number of bins. + + + + + + +```python +nemo_rl.data.packing.algorithms.FirstFitDecreasingPacker._prepare_sequences( + sequence_lengths: typing.List[int] +) -> typing.List[typing.Tuple[int, int]] +``` + + + + + + +Prepare sequences for packing by sorting them in descending order. + +**Parameters:** + + +A list of sequence lengths to pack. + + +**Returns:** `List[Tuple[int, int]]` + +A list of (length, index) pairs sorted by length in descending order. + + + + + + + + + +```python +class nemo_rl.data.packing.algorithms.FirstFitPacker() +``` + + + + + + +**Bases:** [SequencePacker](#nemo_rl-data-packing-algorithms-SequencePacker) + +Base class for First-Fit algorithms. + +First-Fit algorithms place each sequence into the first bin where it fits. +If no bin can fit the sequence, a new bin is created. + +This is an abstract base class that provides the common implementation for +First-Fit variants. Subclasses must implement the _prepare_sequences method +to determine the order in which sequences are processed. + + + + + + +```python +nemo_rl.data.packing.algorithms.FirstFitPacker._pack_implementation( + sequence_lengths: typing.List[int] +) -> typing.List[typing.List[int]] +``` + + + + + + +Pack sequences using the First-Fit algorithm. + +**Parameters:** + + +A list of sequence lengths to pack. + + +**Returns:** `List[List[int]]` + +A list of bins, where each bin is a list of indices into the original + + + + + + + +```python +nemo_rl.data.packing.algorithms.FirstFitPacker._prepare_sequences( + sequence_lengths: typing.List[int] +) -> typing.List[typing.Tuple[int, int]] +``` + + + + + + +Prepare sequences for packing. + +This method determines the order in which sequences are processed. +Subclasses must override this method. + +**Parameters:** + + +A list of sequence lengths to pack. + + +**Returns:** `List[Tuple[int, int]]` + +A list of (length, index) pairs. + + + + + + + + + +```python +class nemo_rl.data.packing.algorithms.FirstFitShufflePacker() +``` + + + + + + +**Bases:** [FirstFitPacker](#nemo_rl-data-packing-algorithms-FirstFitPacker) + +First-Fit Shuffle algorithm for sequence packing. + +This algorithm randomly shuffles the sequences and then places each +sequence into the first bin where it fits. + +Time complexity: O(n * m) for packing, where n is the number of sequences +and m is the number of bins. + + + + + + +```python +nemo_rl.data.packing.algorithms.FirstFitShufflePacker._prepare_sequences( + sequence_lengths: typing.List[int] +) -> typing.List[typing.Tuple[int, int]] +``` + + + + + + +Prepare sequences for packing by randomly shuffling them. + +**Parameters:** + + +A list of sequence lengths to pack. + + +**Returns:** `List[Tuple[int, int]]` + +A list of (length, index) pairs in random order. + + + + + + + + + +```python +class nemo_rl.data.packing.algorithms.ModifiedFirstFitDecreasingPacker() +``` + + + + + + +**Bases:** [SequencePacker](#nemo_rl-data-packing-algorithms-SequencePacker) + +Modified First-Fit Decreasing (MFFD) algorithm for sequence packing. + +This algorithm implements the Johnson & Garey (1985) Modified First-Fit-Decreasing +heuristic. It classifies items into four categories (large, medium, small, tiny) +and uses a sophisticated 5-phase packing strategy to achieve better bin utilization +than standard First-Fit Decreasing. + +The algorithm phases: +1. Classify items by size relative to bin capacity +2. Create one bin per large item +3. Add medium items to large bins (forward pass) +4. Add pairs of small items to bins with medium items (backward pass) +5. Greedily fit remaining items +6. Apply FFD to any leftovers + +Time complexity: O(n log n) for sorting + O(n * m) for packing, +where n is the number of sequences and m is the number of bins. + + + + + + +```python +nemo_rl.data.packing.algorithms.ModifiedFirstFitDecreasingPacker._classify_items( + items: typing.List[typing.Tuple[int, int]] +) -> typing.Tuple[typing.List[typing.Tuple[int, int]], typing.List[typing.Tuple[int, int]], typing.List[typing.Tuple[int, int]], typing.List[typing.Tuple[int, int]]] +``` + + + + + + +Split items into large / medium / small / tiny classes. + +Follows the classification used by Johnson & Garey: + large : (C/2, C] + medium : (C/3, C/2] + small : (C/6, C/3] + tiny : (0 , C/6] + +**Parameters:** + + +List of (index, size) tuples + + +**Returns:** `Tuple[List[Tuple[int, int]], List[Tuple[int, int]], List[Tuple[int, int]], List[Tuple[int, int]]]` + +Tuple of four lists (large, medium, small, tiny) without additional sorting. + + + + + + + +```python +nemo_rl.data.packing.algorithms.ModifiedFirstFitDecreasingPacker._pack_implementation( + sequence_lengths: typing.List[int] +) -> typing.List[typing.List[int]] +``` + + + + + + +Pack sequences using the Modified First-Fit Decreasing algorithm. + +**Parameters:** + + +A list of sequence lengths to pack. + + +**Returns:** `List[List[int]]` + +A list of bins, where each bin is a list of indices into the original + + + + + + + + + +```python +class nemo_rl.data.packing.algorithms.PackingAlgorithm +``` + + + + + + +**Bases:** `enum.Enum` + +Enum for supported sequence packing algorithms. + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.data.packing.algorithms.SequencePacker( + bin_capacity: int, + collect_metrics: bool = False, + min_bin_count: typing.Optional[int] = None, + bin_count_multiple: typing.Optional[int] = None +) +``` + + + + + + +Abstract + +Abstract base class for sequence packing algorithms. + +Sequence packing is the process of efficiently arranging sequences of different +lengths into fixed-capacity bins (batches) to maximize computational efficiency. + + + + + + +```python +nemo_rl.data.packing.algorithms.SequencePacker._adjust_bin_count( + bins: typing.List[typing.List[int]] +) -> typing.List[typing.List[int]] +``` + + + + + + +Adjust the number of bins to meet minimum and multiple constraints. + +This method preserves the existing bin packing as much as possible and only +moves sequences one at a time to create additional bins when needed. + +**Parameters:** + + +The original bins from the packing algorithm. + + +**Returns:** `List[List[int]]` + +Adjusted bins with minimal changes to meet constraints. + +**Raises:** + +- `ValueError`: If there aren't enough sequences to fill the required number of bins. + + + + + + + +```python +nemo_rl.data.packing.algorithms.SequencePacker._create_indexed_lengths( + sequence_lengths: typing.List[int], + reverse: bool = False +) -> typing.List[typing.Tuple[int, int]] +``` + + + + + + +Create a list of (length, index) pairs from sequence lengths. + +**Parameters:** + + +A list of sequence lengths. + + + +Whether to sort in descending order (True) or ascending order (False). + + +**Returns:** `List[Tuple[int, int]]` + +A list of (length, index) pairs, optionally sorted. + + + + + + + +```python +nemo_rl.data.packing.algorithms.SequencePacker._estimate_bins_needed( + sequence_lengths: typing.List[int] +) -> int +``` + + + + + + +Estimate the number of bins needed based on total length. + +**Parameters:** + + +A list of sequence lengths. + + +**Returns:** `int` + +Estimated number of bins needed. + + + + + + + +```python +nemo_rl.data.packing.algorithms.SequencePacker._pack_implementation( + sequence_lengths: typing.List[int] +) -> typing.List[typing.List[int]] +``` + + + + + + +abstract + +Implementation of the packing algorithm. + +**Parameters:** + + +A list of sequence lengths to pack. + + +**Returns:** `List[List[int]]` + +A list of bins, where each bin is a list of indices into the original + + + + + + + +```python +nemo_rl.data.packing.algorithms.SequencePacker._validate_sequence_lengths( + sequence_lengths: typing.List[int] +) -> None +``` + + + + + + +Validate that all sequence lengths are within bin capacity. + +**Parameters:** + + +A list of sequence lengths to validate. + + +**Raises:** + +- `ValueError`: If any sequence length exceeds bin capacity. + + + + + + + +```python +nemo_rl.data.packing.algorithms.SequencePacker.compute_metrics( + sequence_lengths: typing.List[int], + bins: typing.List[typing.List[int]] +) -> typing.Dict[str, float] +``` + + + + + + +Calculate metrics for a packing solution without updating the metrics tracker. + +**Parameters:** + + +List of sequence lengths + + + +List of bins, where each bin is a list of indices + + +**Returns:** `Dict[str, float]` + +Dictionary of packing metrics + + + + + + + +```python +nemo_rl.data.packing.algorithms.SequencePacker.get_aggregated_metrics() -> typing.Dict[str, float] +``` + + + + + + +Get aggregated metrics across all packing operations. + +**Returns:** `Dict[str, float]` + +Dictionary of aggregated metrics, or empty dict if not collecting + + + + + + + +```python +nemo_rl.data.packing.algorithms.SequencePacker.pack( + sequence_lengths: typing.List[int] +) -> typing.List[typing.List[int]] +``` + + + + + + +Pack sequences into bins and update metrics if enabled. + +**Parameters:** + + +A list of sequence lengths to pack. + + +**Returns:** `List[List[int]]` + +A list of bins, where each bin is a list of indices into the original + + + + + + + +```python +nemo_rl.data.packing.algorithms.SequencePacker.print_metrics() -> None +``` + + + + + + +Print the current metrics in a formatted way. + + + + + + + +```python +nemo_rl.data.packing.algorithms.SequencePacker.reset_metrics() -> None +``` + + + + + + +Reset collected metrics. + + + + + + + + + +```python +nemo_rl.data.packing.algorithms.get_packer( + algorithm: typing.Union[nemo_rl.data.packing.algorithms.PackingAlgorithm, str], + bin_capacity: int, + collect_metrics: bool = False, + min_bin_count: typing.Optional[int] = None, + bin_count_multiple: typing.Optional[int] = None +) -> nemo_rl.data.packing.algorithms.SequencePacker +``` + + + + + + +Factory function to get a sequence packer based on the algorithm. + +**Parameters:** + + +The packing algorithm to use. Can be either a PackingAlgorithm enum value + or a string (case-insensitive) matching one of the enum names. + + + +The maximum capacity of each bin. + + + +Whether to collect metrics across multiple packing operations. + + + +Minimum number of bins to create, even if fewer would suffice. + If None, no minimum is enforced. + + + +The total number of bins must be a multiple of this value. + If None, no multiple constraint is enforced. + + +**Returns:** `SequencePacker` + +A SequencePacker instance for the specified algorithm. + +**Raises:** + +- `ValueError`: If the algorithm is not recognized. + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/metrics.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/metrics.mdx new file mode 100644 index 0000000..3f1b4d0 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/packing/metrics.mdx @@ -0,0 +1,177 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/packing/metrics +title: nemo_rl.data.packing.metrics +--- + +Metrics for evaluating sequence packing algorithms. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`PackingMetrics`](#nemo_rl-data-packing-metrics-PackingMetrics) | Class for tracking and computing metrics for sequence packing algorithms. | + +### API + + + + + +```python +class nemo_rl.data.packing.metrics.PackingMetrics() +``` + + + + + + +Class for tracking and computing metrics for sequence packing algorithms. + +This class provides methods to calculate various metrics that evaluate the +efficiency and effectiveness of sequence packing algorithms, such as bin +utilization, waste, and imbalance. + + + + + + +```python +nemo_rl.data.packing.metrics.PackingMetrics.calculate_stats_only( + sequence_lengths: typing.List[int], + bins: typing.List[typing.List[int]], + bin_capacity: int +) -> typing.Dict[str, float] +``` + + + + + + +Calculate metrics for a packing solution without updating the tracker. + +**Parameters:** + + +List of sequence lengths + + + +List of bins, where each bin is a list of indices + + + +Maximum capacity of each bin + + +**Returns:** `Dict[str, float]` + +Dictionary of metrics for this packing solution + + + + + + + +```python +nemo_rl.data.packing.metrics.PackingMetrics.get_aggregated_stats() -> typing.Dict[str, float] +``` + + + + + + +Get aggregated metrics across all packing operations. + +**Returns:** `Dict[str, float]` + +Dictionary of aggregated metrics + + + + + + + +```python +nemo_rl.data.packing.metrics.PackingMetrics.print_aggregated_stats() -> None +``` + + + + + + +Print the aggregated metrics in a formatted way. + + + + + + + +```python +nemo_rl.data.packing.metrics.PackingMetrics.reset() -> None +``` + + + + + + +Reset all metrics. + + + + + + + +```python +nemo_rl.data.packing.metrics.PackingMetrics.update( + sequence_lengths: typing.List[int], + bins: typing.List[typing.List[int]], + bin_capacity: int, + packing_time: typing.Optional[float] = None +) -> typing.Dict[str, float] +``` + + + + + + +Update metrics with a new packing solution. + +**Parameters:** + + +List of sequence lengths + + + +List of bins, where each bin is a list of indices + + + +Maximum capacity of each bin + + + +Optional time taken to compute the packing solution + + +**Returns:** `Dict[str, float]` + +Dictionary of metrics for this packing solution + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/processors.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/processors.mdx new file mode 100644 index 0000000..7660a0d --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/processors.mdx @@ -0,0 +1,353 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/processors +title: nemo_rl.data.processors +--- + +Contains data processors for evaluation. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_construct_multichoice_prompt`](#nemo_rl-data-processors-_construct_multichoice_prompt) | Construct prompt from question and options. | +| [`helpsteer3_data_processor`](#nemo_rl-data-processors-helpsteer3_data_processor) | Process a HelpSteer3 preference datum into a DatumSpec for GRPO training. | +| [`math_data_processor`](#nemo_rl-data-processors-math_data_processor) | Process a datum dictionary (directly loaded from dataset) into a DatumSpec for the Math Environment. | +| [`math_hf_data_processor`](#nemo_rl-data-processors-math_hf_data_processor) | Process a datum dictionary (directly loaded from data/hf_datasets/openmathinstruct2.py) into a DatumSpec for the Reward Model Environment. | +| [`multichoice_qa_processor`](#nemo_rl-data-processors-multichoice_qa_processor) | Process a datum dictionary (directly loaded from dataset) into a DatumSpec for multiple-choice problems. | +| [`nemo_gym_data_processor`](#nemo_rl-data-processors-nemo_gym_data_processor) | Process a datum dictionary (directly loaded from dataset) into a DatumSpec for Nemo Gym. | +| [`preference_preprocessor`](#nemo_rl-data-processors-preference_preprocessor) | Process a datum dictionary for RM/DPO training. | +| [`register_processor`](#nemo_rl-data-processors-register_processor) | - | +| [`sft_processor`](#nemo_rl-data-processors-sft_processor) | Process a datum dictionary for SFT training. | +| [`vlm_hf_data_processor`](#nemo_rl-data-processors-vlm_hf_data_processor) | Process a datum dictionary (directly loaded from response_datasets/<dataset_name>.py) into a DatumSpec for the VLM Environment. | + +### Data + +[`PROCESSOR_REGISTRY`](#nemo_rl-data-processors-PROCESSOR_REGISTRY) + +[`TokenizerType`](#nemo_rl-data-processors-TokenizerType) + +### API + + + + + +```python +nemo_rl.data.processors._construct_multichoice_prompt( + prompt: str, + question: str, + options: dict[str, str] +) -> str +``` + + + + + + +Construct prompt from question and options. + + + + + + + + +```python +nemo_rl.data.processors.helpsteer3_data_processor( + datum_dict: dict[str, typing.Any], + task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + tokenizer: nemo_rl.data.processors.TokenizerType, + max_seq_length: int, + idx: int +) -> nemo_rl.data.interfaces.DatumSpec +``` + + + + + + +Process a HelpSteer3 preference datum into a DatumSpec for GRPO training. + +This function converts HelpSteer3 preference data to work with GRPO by: +1. Using the context as the prompt +2. Using the preferred completion as the target response +3. Creating a reward signal based on preference scores + + + + + + + + +```python +nemo_rl.data.processors.math_data_processor( + datum_dict: dict[str, typing.Any], + task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + tokenizer: nemo_rl.data.processors.TokenizerType, + max_seq_length: int, + idx: int +) -> nemo_rl.data.interfaces.DatumSpec +``` + + + + + + +Process a datum dictionary (directly loaded from dataset) into a DatumSpec for the Math Environment. + + + + + + + + +```python +nemo_rl.data.processors.math_hf_data_processor( + datum_dict: dict[str, typing.Any], + task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + tokenizer: nemo_rl.data.processors.TokenizerType, + max_seq_length: int, + idx: int +) -> nemo_rl.data.interfaces.DatumSpec +``` + + + + + + +Process a datum dictionary (directly loaded from data/hf_datasets/openmathinstruct2.py) into a DatumSpec for the Reward Model Environment. + + + + + + + + +```python +nemo_rl.data.processors.multichoice_qa_processor( + datum_dict: dict[str, typing.Any], + task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + tokenizer: nemo_rl.data.processors.TokenizerType, + max_seq_length: int, + idx: int +) -> nemo_rl.data.interfaces.DatumSpec +``` + + + + + + +Process a datum dictionary (directly loaded from dataset) into a DatumSpec for multiple-choice problems. + + + + + + + + +```python +nemo_rl.data.processors.nemo_gym_data_processor( + datum_dict: dict[str, typing.Any], + task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + tokenizer: nemo_rl.data.processors.TokenizerType, + max_seq_length: int | None, + idx: int +) -> nemo_rl.data.interfaces.DatumSpec +``` + + + + + + +Process a datum dictionary (directly loaded from dataset) into a DatumSpec for Nemo Gym. + + + + + + + + +```python +nemo_rl.data.processors.preference_preprocessor( + datum_dict: dict[str, typing.Any], + task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + tokenizer, + max_seq_length: int, + idx: int +) -> nemo_rl.data.interfaces.PreferenceDatumSpec +``` + + + + + + +Process a datum dictionary for RM/DPO training. + +**Examples:** + + + +```python +>>> from transformers import AutoTokenizer +>>> from nemo_rl.data.interfaces import TaskDataSpec +>>> from nemo_rl.data.processors import preference_preprocessor +>>> +>>> # Initialize tokenizer and task spec +>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct") +>>> ## set a passthrough chat template for simplicity +>>> tokenizer.chat_template = "{% for message in messages %}{{ message['content'] }}{% endfor %}" +>>> task_spec = TaskDataSpec(task_name="test_preference") +>>> +>>> datum = { +... "context": [{"role": "user", "content": "What is 2+2?"}], +... "completions": [ +... {"rank": 0, "completion": [{"role": "assistant", "content": "4"}]}, +... {"rank": 1, "completion": [{"role": "assistant", "content": "5"}]} +... ] +... } +>>> +>>> processed = preference_preprocessor(datum, task_spec, tokenizer, max_seq_length=128, idx=0) # doctest: +ELLIPSIS + +... +>>> len(processed["message_log_chosen"]) +2 +>>> processed["message_log_chosen"][0]["content"] +'<|begin_of_text|>What is 2+2?' +>>> processed["message_log_chosen"][-1]["content"] +'4<|eot_id|>' +>>> processed["message_log_rejected"][-1]["content"] +'5<|eot_id|>' +>>> +>>> # context can also contain multiple turns +>>> datum = { +... "context": [{"role": "user", "content": "I have a question."}, {"role": "assistant", "content": "Sure!"}, {"role": "user", "content": "What is 2+2?"}], +... "completions": [ +... {"rank": 0, "completion": [{"role": "assistant", "content": "4"}]}, +... {"rank": 1, "completion": [{"role": "assistant", "content": "5"}]} +... ] +... } +>>> processed = preference_preprocessor(datum, task_spec, tokenizer, max_seq_length=128, idx=0) +>>> len(processed["message_log_chosen"]) +4 +>>> processed["message_log_chosen"][1]["content"] +'Sure!' +>>> processed["message_log_chosen"][-1]["content"] +'4<|eot_id|>' +>>> processed["message_log_rejected"][-1]["content"] +'5<|eot_id|>' +``` + + + + + + + + + + +```python +nemo_rl.data.processors.register_processor( + processor_name: str, + processor_function: nemo_rl.data.interfaces.TaskDataProcessFnCallable +) -> None +``` + + + + + + + + + + + + + +```python +nemo_rl.data.processors.sft_processor( + datum_dict: dict[str, typing.Any], + task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + tokenizer, + max_seq_length: int, + idx: int, + add_bos: bool = True, + add_eos: bool = True, + add_generation_prompt: bool = False +) -> nemo_rl.data.interfaces.DatumSpec +``` + + + + + + +Process a datum dictionary for SFT training. + + + + + + + + +```python +nemo_rl.data.processors.vlm_hf_data_processor( + datum_dict: dict[str, typing.Any], + task_data_spec: nemo_rl.data.interfaces.TaskDataSpec, + processor: transformers.AutoProcessor, + max_seq_length: int, + idx: int +) -> nemo_rl.data.interfaces.DatumSpec +``` + + + + + + +Process a datum dictionary (directly loaded from response_datasets/<dataset_name>.py) into a DatumSpec for the VLM Environment. + + + + + + + + +```python +nemo_rl.data.processors.PROCESSOR_REGISTRY: Dict[str, TaskDataProcessFnCallable] = cast(Dict[str, TaskDataProcessFnCallable], {'default': math_hf_data_processor, '... +``` + + + + + + + + + +```python +nemo_rl.data.processors.TokenizerType = PreTrainedTokenizerBase +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/utils.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/utils.mdx new file mode 100644 index 0000000..386880c --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/data/utils.mdx @@ -0,0 +1,104 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/data/utils +title: nemo_rl.data.utils +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`setup_preference_data`](#nemo_rl-data-utils-setup_preference_data) | Setup preference data. | +| [`setup_response_data`](#nemo_rl-data-utils-setup_response_data) | Setup data with environments. | + +### API + + + + + +```python +nemo_rl.data.utils.setup_preference_data( + tokenizer: transformers.AutoTokenizer, + data_config: nemo_rl.data.DataConfig +) +``` + + + + + + +Setup preference data. + +This function is used to setup the preference data for the training and validation datasets. + +**Parameters:** + + +Tokenizer. + + + +Data config for preference dataset. + + +**Returns:** + +A tuple of (train dataset, validation dataset). + + + + + + + + +```python +nemo_rl.data.utils.setup_response_data( + tokenizer: transformers.AutoProcessor | transformers.AutoTokenizer, + data_config: nemo_rl.data.DataConfig, + env_configs: typing.Optional[dict[str, typing.Any]] = None, + is_vlm: bool = False +) -> typing.Union[tuple[nemo_rl.data.datasets.AllTaskProcessedDataset, typing.Optional[nemo_rl.data.datasets.AllTaskProcessedDataset]], tuple[nemo_rl.data.datasets.AllTaskProcessedDataset, typing.Optional[nemo_rl.data.datasets.AllTaskProcessedDataset], dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], dict[str, nemo_rl.environments.interfaces.EnvironmentInterface]]] +``` + + + + + + +Setup data with environments. + +This function is used to setup the data and environments for the training and validation datasets. + +**Parameters:** + + +Tokenizer or processor. + + + +Data config. + + + +Environment configs. +If None, no environments will be created. This is used for: +- Algorithms like SFT which do not need environments. +- Environments like NeMo-Gym which need to handle the environment creation outside of this function. + + + +Whether to use VLM training or not. + + +**Returns:** `Union[tuple[AllTaskProcessedDataset, Optional[AllTaskProcessedDataset]], tuple[AllTaskProcessedDataset, Optional[AllTaskProcessedDataset], dict[str, EnvironmentInterface], dict[str, EnvironmentInterface]]]` + +If env_configs is not None: +A tuple of (train dataset, validation dataset, task to environment, task to validation environment). + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed.mdx new file mode 100644 index 0000000..d615e36 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed.mdx @@ -0,0 +1,17 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/distributed +title: nemo_rl.distributed +--- + +## Submodules + +- **[`nemo_rl.distributed.batched_data_dict`](/nemo-rl/nemo_rl/distributed/batched_data_dict)** +- **[`nemo_rl.distributed.collectives`](/nemo-rl/nemo_rl/distributed/collectives)** +- **[`nemo_rl.distributed.model_utils`](/nemo-rl/nemo_rl/distributed/model_utils)** +- **[`nemo_rl.distributed.named_sharding`](/nemo-rl/nemo_rl/distributed/named_sharding)** +- **[`nemo_rl.distributed.ray_actor_environment_registry`](/nemo-rl/nemo_rl/distributed/ray_actor_environment_registry)** +- **[`nemo_rl.distributed.stateless_process_group`](/nemo-rl/nemo_rl/distributed/stateless_process_group)** +- **[`nemo_rl.distributed.virtual_cluster`](/nemo-rl/nemo_rl/distributed/virtual_cluster)** +- **[`nemo_rl.distributed.worker_group_utils`](/nemo-rl/nemo_rl/distributed/worker_group_utils)** +- **[`nemo_rl.distributed.worker_groups`](/nemo-rl/nemo_rl/distributed/worker_groups)** diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/batched_data_dict.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/batched_data_dict.mdx new file mode 100644 index 0000000..c134df2 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/batched_data_dict.mdx @@ -0,0 +1,671 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/distributed/batched_data_dict +title: nemo_rl.distributed.batched_data_dict +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BatchedDataDict`](#nemo_rl-distributed-batched_data_dict-BatchedDataDict) | - | +| [`DynamicBatchingArgs`](#nemo_rl-distributed-batched_data_dict-DynamicBatchingArgs) | Configuration settings for dynamic batching. | +| [`SequencePackingArgs`](#nemo_rl-distributed-batched_data_dict-SequencePackingArgs) | Configuration settings for sequence packing. | +| [`SlicedDataDict`](#nemo_rl-distributed-batched_data_dict-SlicedDataDict) | A specialized subclass of BatchedDataDict that represents a slice or shard of a larger batch. | + +### Data + +[`DictT`](#nemo_rl-distributed-batched_data_dict-DictT) + +### API + + + + + +```python +class nemo_rl.distributed.batched_data_dict.BatchedDataDict( + args = (), + kwargs = {} +) +``` + + + + + + +**Bases:** `UserDict`, `Generic[DictT]` + + + + + +Get the batch size of the batch. + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.all_gather( + group: torch.distributed.ProcessGroup +) -> typing_extensions.Self +``` + + + + + + +Gathers batches with possibly jagged leading dimensions across the DP ranks. + +If using reshard, it will treat PP as DP ranks. +Works with data that is either tensors or string lists. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.chunk( + rank: int, + chunks: int +) -> nemo_rl.distributed.batched_data_dict.SlicedDataDict +``` + + + + + + +Chunks a global batch into 'chunks' splits and returns the 'rank'th split batch=[A A A B B B D D E], rank=2, chunks=3 -> [D D E]. + +Requires all leading dimensions of tensors and lengths of lists to be the same over the batch +and the chunks must divide batch size. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.from_batches( + batches: typing.Sequence[typing.Mapping[typing.Any, typing.Any]], + pad_value_dict: typing.Optional[dict[str, int | float]] = None +) -> typing_extensions.Self +``` + + + + + + +classmethod + +Given a list of batches, stack the tensors/lists within and put them in a single dictionary. + +Pad sequences to the max length in the batch using either 0(default) or a non-default value for a given key provided in pad_value_dict. + +**Parameters:** + + +A list of dictionaries, each containing a batch of data. + + + +An optional dict mapping keys to non-default(0) padding values. + + +**Returns:** `Self` + +A new BatchedDataDict containing the stacked data. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.get_batch( + batch_idx, + batch_size = None +) -> nemo_rl.distributed.batched_data_dict.SlicedDataDict +``` + + + + + + +Slices a subbatch from the batch. + +**Parameters:** + + +the batch index to slice + + + +the size of the batch to be sliced + + +**Returns:** `SlicedDataDict` + +A new BatchedDataDict containing the sliced data + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.get_dict() -> dict[typing.Any, typing.Any] +``` + + + + + + +Get the underlying data dictionary. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.get_microbatch_iterator_dynamic_shapes_len() -> int +``` + + + + + + +Get the length of the microbatch iterator for dynamic shapes. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.get_microbatch_iterator_for_packable_sequences_len() -> tuple[int, int] +``` + + + + + + +Get the length of the microbatch iterator for sequence packing and the max packed seqlen. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.get_multimodal_dict( + as_tensors: bool = False, + device: typing.Optional[torch.device] = None +) -> dict[str, typing.Any] +``` + + + + + + +Return a regular dict of tensors or packed multimodal data items. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.make_microbatch_iterator( + microbatch_size: int +) -> typing.Iterator[nemo_rl.distributed.batched_data_dict.SlicedDataDict] +``` + + + + + + +Make an iterator over the batch that yields microbatches of size microbatch_size. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.make_microbatch_iterator_for_packable_sequences() -> typing.Iterator[nemo_rl.distributed.batched_data_dict.SlicedDataDict] +``` + + + + + + +Make an iterator over the batch that yields microbatches that can be packed into a given max_tokens_per_microbatch. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.make_microbatch_iterator_with_dynamic_shapes( + sequence_dim: int = 1 +) -> typing.Iterator[nemo_rl.distributed.batched_data_dict.SlicedDataDict] +``` + + + + + + +Makes an iterator that yields microbatchs of dynamic batch and sequence sizes. + +**Parameters:** + + +the index of the sequence dim for all tensors in the data dict + + +**Returns:** `Iterator[SlicedDataDict]` + +Iterator["SlicedDataDict"]: An iterator that yield dynamic microbatches + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.reorder_data( + reorded_indices: list[int] +) +``` + + + + + + +Reorders the data along the batch dimension by the given indices. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.repeat_interleave( + num_repeats: int +) -> typing_extensions.Self +``` + + + + + + +Repeats the batch num_repeats times. + +For each element in the batch, repeat each value num_repeats times. +i.e: +{"key": torch.tensor([1, 2, 3]), "other_key": [1, 2, 3]} -> {"key": torch.tensor([1, 1, 2, 2, 3, 3]), "other_key": [1, 1, 2, 2, 3, 3]} + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.select_indices( + indices: typing.Union[list[int], torch.Tensor] +) -> typing_extensions.Self +``` + + + + + + +Selects specific rows from the batch based on indices. + +**Parameters:** + + +A list or tensor of integer indices to select. + + +**Returns:** `Self` + +A new BatchedDataDict containing only the selected rows. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.shard_by_batch_size( + shards: int, + batch_size: typing.Optional[int] = None, + allow_uneven_shards: bool = False, + dynamic_batching_args: typing.Optional[nemo_rl.distributed.batched_data_dict.DynamicBatchingArgs] = None, + sequence_packing_args: typing.Optional[nemo_rl.distributed.batched_data_dict.SequencePackingArgs] = None +) -> list[nemo_rl.distributed.batched_data_dict.SlicedDataDict] | tuple[list[nemo_rl.distributed.batched_data_dict.SlicedDataDict], list[int]] +``` + + + + + + +Shards a batch by first dividing it into chunks of size batch_size, then further dividing each chunk into shards equal parts. Finally aggregates the sub-shards by their position. + +If batch_size is None, there will be no chunking beforehand (will default to the total batch size). + +For example, with data [A A B B C C D D], batch_size=2, shards=2: +- Element 0: [A B C D] (first elements from each chunk) +- Element 1: [A B C D] (second elements from each chunk) + +Examples: + + +```python +>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict +>>> # Create a batch of two message logs with different lengths +>>> batch = BatchedDataDict({ +... 'problem_id': [0, 0, 1, 1, 2, 2, 3, 3], +... 'arbitrary_data': [1, 2, 3, 4, 5, 6, 7, 8] +... }) +>>> shards = batch.shard_by_batch_size(shards=2) +>>> shards +[{'problem_id': [0, 0, 1, 1], 'arbitrary_data': [1, 2, 3, 4]}, {'problem_id': [2, 2, 3, 3], 'arbitrary_data': [5, 6, 7, 8]}] +>>> # Now say that I'm training with a GBS of 4 and I want to take gradients steps on problems 0 and 1 before 2 and 3 (problems are repeated because GRPO) +>>> # In the current case, problems 0 and 2 will be trained on first since they're the first elements in each DP rank's batch. +>>> # So, we'll use the batch_size argument to split the batch into chunks of size 4 first. +>>> shards = batch.shard_by_batch_size(shards=2, batch_size=4) +>>> shards +[{'problem_id': [0, 0, 2, 2], 'arbitrary_data': [1, 2, 5, 6]}, {'problem_id': [1, 1, 3, 3], 'arbitrary_data': [3, 4, 7, 8]}] +>>> # Now, the ranks have 0 and 1 first so when they split their batches into microbatches (of size 2 since GBS=4 and DP=2), they'll train on 0 and 1 first. +>>> # Another way to use this function is with the 'allow_uneven_shards' flag, which allows the last shard to be smaller than the others when necessary. +>>> # This is necessary in multi-turn rollouts when some sequences terminate early, leaving unclean batch sizes. +>>> batch = BatchedDataDict({ +... 'problem_id': [0, 1, 2, 3, 4], +... 'arbitrary_data': [10, 11, 12, 13, 14] +... }) +>>> shards = batch.shard_by_batch_size(shards=2, allow_uneven_shards=True) +>>> shards +[{'problem_id': [0, 1, 2], 'arbitrary_data': [10, 11, 12]}, {'problem_id': [3, 4], 'arbitrary_data': [13, 14]}] +>>> # This is incompatible with the batch_size argument +``` + + + +**Parameters:** + + +The number of shards to divide each batch_size chunk into. + + + +The size of each initial chunk. + + + +Whether to allow shards to be unevenly sized. + If True, the last shard may be smaller than the others. + + + +If passed, preprocess batch for dynamic batching. This + dict requires four keys: + 1. max_tokens_per_microbatch (int): the maximum + number of tokens in a microbatch + 2. sequence_length_round (int): round each all + sequence lengths to this multiple + 3. input_key (str): the key in the batch + which holds input ids. + 4. input_lengths_key (str): the key in the batch + which holds the sequence length per value. + The sequence dim index is assumed to be 1. + Cannot be passed with sequence_packing_args. + + + +If passed, preprocess batch for sequence packing. This + dict requires five keys: + 1. max_tokens_per_microbatch (int): the maximum + number of tokens in a microbatch + 2. input_key (str): the key in the batch + which holds input ids. + 3. input_lengths_key (str): the key in the batch + which holds the sequence length per value. + The sequence dim index is assumed to be 1. + 4. algorithm (str): the algorithm to use for sequence packing. + 5. sequence_length_pad_multiple (int): the multiple to pad each sequence to. + With CP enabled, this should be set to a multiple of 2*CP and SP. + Cannot be passed with dynamic_batching_args. + + +**Returns:** `list[SlicedDataDict] | tuple[list[SlicedDataDict], list[int]]` + +list[BatchedDataDict]: A list of BatchedDataDicts, length equal to shards. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.slice( + start: int, + end: int +) -> nemo_rl.distributed.batched_data_dict.SlicedDataDict +``` + + + + + + +Slices the batch from start to end. + +**Parameters:** + + +Starting index (inclusive) + + + +Ending index (exclusive) + + +**Returns:** `SlicedDataDict` + +A new BatchedDataDict containing the sliced data + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.to( + device: str | torch.device +) -> typing_extensions.Self +``` + + + + + + +Move tensors in batched dict to device. + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.BatchedDataDict.truncate_tensors( + dim: int, + truncated_len: int +) +``` + + + + + + +Truncates tensors in this dict of a given dim to a given length. + + + + + + + + + +```python +class nemo_rl.distributed.batched_data_dict.DynamicBatchingArgs +``` + + + + + + +**Bases:** `typing.TypedDict` + +Configuration settings for dynamic batching. + +Pass this to 'shard_by_batch_size()' to preprocess batches for dynamic batching. + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.distributed.batched_data_dict.SequencePackingArgs +``` + + + + + + +**Bases:** `typing.TypedDict` + +Configuration settings for sequence packing. + +Pass this to 'shard_by_batch_size()' to preprocess batches for sequence packing. + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.distributed.batched_data_dict.SlicedDataDict() +``` + + + + + + +**Bases:** [BatchedDataDict](#nemo_rl-distributed-batched_data_dict-BatchedDataDict) + +A specialized subclass of BatchedDataDict that represents a slice or shard of a larger batch. + +This class provides a distinct type to differentiate between full batches and sliced/sharded batches, which can be helpful for +type checking. + + + + + + + + +```python +nemo_rl.distributed.batched_data_dict.DictT = TypeVar('DictT', bound=(Mapping[str, Any])) +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/collectives.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/collectives.mdx new file mode 100644 index 0000000..5d756ce --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/collectives.mdx @@ -0,0 +1,108 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/distributed/collectives +title: nemo_rl.distributed.collectives +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`gather_jagged_object_lists`](#nemo_rl-distributed-collectives-gather_jagged_object_lists) | Gathers jagged lists of picklable objects from all ranks and flattens them into a single list. | +| [`rebalance_nd_tensor`](#nemo_rl-distributed-collectives-rebalance_nd_tensor) | Takes tensors with variable leading sizes (at dim=0) and stacks them into a single tensor. | + +### Data + +[`T`](#nemo_rl-distributed-collectives-T) + +### API + + + + + +```python +nemo_rl.distributed.collectives.gather_jagged_object_lists( + local_objects: list[nemo_rl.distributed.collectives.T], + group: typing.Optional[torch.distributed.ProcessGroup] = None +) -> list[nemo_rl.distributed.collectives.T] +``` + + + + + + +Gathers jagged lists of picklable objects from all ranks and flattens them into a single list. + +This function handles the case where different GPUs have lists of different lengths +and combines them into a single list containing all objects from all ranks. + +For example, with 3 GPUs: + GPU0: [obj0, obj1] + GPU1: [obj2, obj3, obj4] + GPU2: [obj5] + +WARNING: synchronous + +**Parameters:** + + +List of objects to gather from current rank + + + +Optional process group + + +**Returns:** `list[T]` + +Flattened list of all objects from all ranks in order [rank0, rank1, ...] + + + + + + + + +```python +nemo_rl.distributed.collectives.rebalance_nd_tensor( + tensor: torch.Tensor, + group: typing.Optional[torch.distributed.ProcessGroup] = None +) -> torch.Tensor +``` + + + + + + +Takes tensors with variable leading sizes (at dim=0) and stacks them into a single tensor. + +This function handles the case where different GPUs have tensors with different batch sizes +and combines them into a single balanced tensor across all ranks. + +For example, with 3 GPUs: + GPU0: tensor of shape [3, D] + GPU1: tensor of shape [5, D] + GPU2: tensor of shape [2, D] + +NOTE: assumes all other (i.e., non-zero) dimensions are equal. + + + + + + + + +```python +nemo_rl.distributed.collectives.T = TypeVar('T') +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/model_utils.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/model_utils.mdx new file mode 100644 index 0000000..57539ab --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/model_utils.mdx @@ -0,0 +1,851 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/distributed/model_utils +title: nemo_rl.distributed.model_utils +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AllGatherCPTensor`](#nemo_rl-distributed-model_utils-AllGatherCPTensor) | - | +| [`ChunkedDistributedEntropy`](#nemo_rl-distributed-model_utils-ChunkedDistributedEntropy) | Compute H_all = sum_v p_v log p_v across TP with chunking over sequence. | +| [`ChunkedDistributedGatherLogprob`](#nemo_rl-distributed-model_utils-ChunkedDistributedGatherLogprob) | Compute distributed log-softmax once and gather logprobs at given global indices. | +| [`ChunkedDistributedLogprob`](#nemo_rl-distributed-model_utils-ChunkedDistributedLogprob) | Custom autograd function for computing log probabilities in a distributed setting. | +| [`DistributedLogprob`](#nemo_rl-distributed-model_utils-DistributedLogprob) | Custom autograd function for computing log probabilities in a distributed setting. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_compute_distributed_log_softmax`](#nemo_rl-distributed-model_utils-_compute_distributed_log_softmax) | Compute a stable distributed log softmax across tensor parallel workers. | +| [`_get_tokens_on_this_cp_rank`](#nemo_rl-distributed-model_utils-_get_tokens_on_this_cp_rank) | Get tokens on this context parallelism rank. | +| [`allgather_cp_sharded_tensor`](#nemo_rl-distributed-model_utils-allgather_cp_sharded_tensor) | - | +| [`distributed_vocab_topk`](#nemo_rl-distributed-model_utils-distributed_vocab_topk) | Compute global top-k over TP-sharded vocabulary logits. | +| [`dtensor_from_parallel_logits_to_logprobs`](#nemo_rl-distributed-model_utils-dtensor_from_parallel_logits_to_logprobs) | Get log probabilities from TP+CP sharded vocab logits. | +| [`from_parallel_logits_to_logprobs`](#nemo_rl-distributed-model_utils-from_parallel_logits_to_logprobs) | Get log probabilities from TP+CP sharded vocab logits. | +| [`from_parallel_logits_to_logprobs_packed_sequences`](#nemo_rl-distributed-model_utils-from_parallel_logits_to_logprobs_packed_sequences) | Get log probabilities from TP sharded vocab logits for packed sequences. | +| [`gather_logits_at_global_indices`](#nemo_rl-distributed-model_utils-gather_logits_at_global_indices) | Gather student logits at given global token indices under TP+CP sharding. | +| [`get_logprobs_from_vocab_parallel_logits`](#nemo_rl-distributed-model_utils-get_logprobs_from_vocab_parallel_logits) | Computes log probabilities from vocabulary-parallel logits. | + +### API + + + + + +```python +class nemo_rl.distributed.model_utils.AllGatherCPTensor() +``` + + + + + + +**Bases:** `Function` + + + + + +```python +nemo_rl.distributed.model_utils.AllGatherCPTensor.backward( + ctx, + grad_output +) +``` + + + + + + + + + + + + +```python +nemo_rl.distributed.model_utils.AllGatherCPTensor.forward( + ctx, + tensor, + cp_group: torch.distributed.ProcessGroup, + seq_dim = 1 +) +``` + + + + + + + + + + + + + + +```python +class nemo_rl.distributed.model_utils.ChunkedDistributedEntropy() +``` + + + + + + +**Bases:** `Function` + +Compute H_all = sum_v p_v log p_v across TP with chunking over sequence. + +Forward returns [B, S] tensor of global entropy; backward propagates through logits. + + + + + + +```python +nemo_rl.distributed.model_utils.ChunkedDistributedEntropy.backward( + ctx: typing.Any, + grad_outputs: torch.Tensor = () +) -> tuple[torch.Tensor, None, None, None] +``` + + + + + + +staticmethod + + + + + + + +```python +nemo_rl.distributed.model_utils.ChunkedDistributedEntropy.forward( + ctx: typing.Any, + vocab_parallel_logits: torch.Tensor, + chunk_size: int, + tp_group: torch.distributed.ProcessGroup, + inference_only: bool = False +) -> torch.Tensor +``` + + + + + + +staticmethod + + + + + + + + + +```python +class nemo_rl.distributed.model_utils.ChunkedDistributedGatherLogprob() +``` + + + + + + +**Bases:** `Function` + +Compute distributed log-softmax once and gather logprobs at given global indices. + +Forward computes per-chunk distributed log-softmax across TP, gathers selected +log probabilities at the provided global indices (shape [B, S, K]), and returns +a tensor of shape [B, S, K]. + + + + + + +```python +nemo_rl.distributed.model_utils.ChunkedDistributedGatherLogprob.backward( + ctx: typing.Any, + grad_outputs: torch.Tensor = () +) -> tuple[torch.Tensor, None, None, None, None, None, None] +``` + + + + + + +staticmethod + + + + + + + +```python +nemo_rl.distributed.model_utils.ChunkedDistributedGatherLogprob.forward( + ctx: typing.Any, + vocab_parallel_logits: torch.Tensor, + global_indices: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + chunk_size: int, + tp_group: torch.distributed.ProcessGroup, + inference_only: bool = False +) -> torch.Tensor +``` + + + + + + +staticmethod + + + + + + + + + +```python +class nemo_rl.distributed.model_utils.ChunkedDistributedLogprob() +``` + + + + + + +**Bases:** `Function` + +Custom autograd function for computing log probabilities in a distributed setting. + +The log probabilities computation is chunked in the sequence dimension +to mitigate GPU OOM (especially during backward pass). +In addition, logits casting from float16 or bfloat16 -> float32 is performed +inside the chunk loop to avoid materializing a whole float32 logits tensor. + +Adapted from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286 + + + + + + +```python +nemo_rl.distributed.model_utils.ChunkedDistributedLogprob.backward( + ctx: typing.Any, + grad_outputs: torch.Tensor = () +) -> tuple[torch.Tensor, None, None, None, None, None, None] +``` + + + + + + +staticmethod + + + + + + + +```python +nemo_rl.distributed.model_utils.ChunkedDistributedLogprob.forward( + ctx: typing.Any, + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + chunk_size: int, + tp_group: torch.distributed.ProcessGroup, + inference_only: bool = False +) -> torch.Tensor +``` + + + + + + +staticmethod + + + + + + + + + +```python +class nemo_rl.distributed.model_utils.DistributedLogprob() +``` + + + + + + +**Bases:** `Function` + +Custom autograd function for computing log probabilities in a distributed setting. + +Taken from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286 + + + + + + +```python +nemo_rl.distributed.model_utils.DistributedLogprob.backward( + ctx: typing.Any, + grad_outputs: torch.Tensor = () +) -> tuple[torch.Tensor, None, None, None, None, None, None] +``` + + + + + + +staticmethod + + + + + + + +```python +nemo_rl.distributed.model_utils.DistributedLogprob.forward( + ctx: typing.Any, + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + group: torch.distributed.ProcessGroup, + inference_only: bool = False +) -> torch.Tensor +``` + + + + + + +staticmethod + + + + + + + + + +```python +nemo_rl.distributed.model_utils._compute_distributed_log_softmax( + vocab_parallel_logits: torch.Tensor, + group: torch.distributed.ProcessGroup +) -> torch.Tensor +``` + + + + + + +Compute a stable distributed log softmax across tensor parallel workers. + +Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L265 + +**Parameters:** + + +Logits tensor with shape [batch_size, seq_length, vocab_size//TP] +where TP is the tensor parallel size. + + + +Process group for the all-reduce operations. + + +**Returns:** `torch.Tensor` + +torch.Tensor: Log softmax output with the same shape as input, but values represent +log probabilities normalized across the full vocabulary dimension. + + + + + + + + +```python +nemo_rl.distributed.model_utils._get_tokens_on_this_cp_rank( + input_ids: torch.Tensor, + cp_rank: int, + cp_size: int, + seq_dim: int = 1 +) -> torch.Tensor +``` + + + + + + +Get tokens on this context parallelism rank. + +Assumes that input_ids are already padded to a multiple of cp_size * 2 or cp_size == 1. + +**Parameters:** + + +Input token IDs [seq_length, ] + + + +Context parallelism rank + + + +Context parallelism size + + +**Returns:** `torch.Tensor` + +Tokens on this context parallelism rank [1, seq_length // cp_size] + + + + + + + + +```python +nemo_rl.distributed.model_utils.allgather_cp_sharded_tensor( + tensor, + cp_group, + seq_dim = 1 +) +``` + + + + + + + + + + + + + +```python +nemo_rl.distributed.model_utils.distributed_vocab_topk( + vocab_parallel_logits: torch.Tensor, + k: int, + tp_group: torch.distributed.ProcessGroup, + vocab_start_index: int, + vocab_end_index: int, + chunk_size: typing.Optional[int] = None +) -> tuple[torch.Tensor, torch.Tensor] +``` + + + + + + +Compute global top-k over TP-sharded vocabulary logits. + +**Parameters:** + + +[B, S, V_local] + + + +number of top tokens to select globally + + + +tensor-parallel process group + + + +global vocab start for this rank (inclusive) + + + +global vocab end for this rank (exclusive) + + + +optional chunk along sequence dim to bound memory + + +**Returns:** `torch.Tensor` + +[B, S, k] + + + + + + + + +```python +nemo_rl.distributed.model_utils.dtensor_from_parallel_logits_to_logprobs( + vocab_parallel_logits: torch.Tensor, + target: torch.distributed.tensor.DTensor | torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + tp_group: torch.distributed.ProcessGroup, + inference_only: bool = False, + seq_index: typing.Optional[torch.Tensor] = None, + chunk_size: typing.Optional[int] = None +) -> torch.Tensor +``` + + + + + + +Get log probabilities from TP+CP sharded vocab logits. + +**Parameters:** + + +Logits distributed across tensor parallel workers, +with shape [batch_size, seq_len, vocab_size/tp_size]. + + + +Target token indices with shape [batch_size, seq_len]. +NOTE: Must be the unmodified targets as this function will shift them internally. + + + +Starting vocabulary index for this worker's partition. + + + +Ending vocabulary index for this worker's partition. + + + +Process group for distributed communication. + + + +If True, tensors won't be saved for backward pass. Defaults to False. + + + +Sequence index tensor with shape [seq_len]. +It is only provided for cp sharded logits. It represents how tensor is sharded across the sequence dimension. + + + +Sequence dimension chunk size for computing the log probabilities. + + +**Returns:** `torch.Tensor` + +torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1]. +The sequence dimension is reduced by 1 due to the target shifting. + + + + + + + + +```python +nemo_rl.distributed.model_utils.from_parallel_logits_to_logprobs( + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + tp_group: torch.distributed.ProcessGroup, + inference_only: bool = False, + cp_group: typing.Optional[torch.distributed.ProcessGroup] = None, + chunk_size: typing.Optional[int] = None +) -> torch.Tensor +``` + + + + + + +Get log probabilities from TP+CP sharded vocab logits. + +Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L354 + +**Parameters:** + + +Logits tensor with shape [batch_size, seq_len // CP, vocab_size // TP] +where TP is the tensor parallel size. + + + +Target token indices with shape [batch_size, seq_len]. +NOTE: Must be the unmodified targets as this function will shift them internally. + + + +Starting vocabulary index for this worker's partition. + + + +Ending vocabulary index for this worker's partition. + + + +Process group for distributed communication. + + + +If True, tensors won't be saved for backward pass. Defaults to False. + + + +Context parallelism process group. Defaults to None. + + + +Sequence dimension chunk size for computing the log probabilities. + + +**Returns:** `torch.Tensor` + +torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1]. +The sequence dimension is reduced by 1 due to the target shifting. + + + + + + + + +```python +nemo_rl.distributed.model_utils.from_parallel_logits_to_logprobs_packed_sequences( + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + cu_seqlens_padded: torch.Tensor, + unpacked_seqlen: int, + vocab_start_index: int, + vocab_end_index: int, + group: torch.distributed.ProcessGroup, + inference_only: bool = False, + cp_group: typing.Optional[torch.distributed.ProcessGroup] = None, + chunk_size: typing.Optional[int] = None +) -> torch.Tensor +``` + + + + + + +Get log probabilities from TP sharded vocab logits for packed sequences. + +**Parameters:** + + +Packed logits tensor with shape [1, T // CP, vocab_size//TP] +where T is the total number of tokens across all packed sequences. + + + +Packed target token indices with shape [1, T]. +NOTE: Must be the unmodified targets as this function will shift them internally. + + + +Cumulative sequence lengths tensor with shape [batch_size + 1]. +cu_seqlens[i] indicates the start position of sequence i in the packed format. + + + +The length of the unpacked sequence tensor. + + + +Starting vocabulary index for this worker's partition. + + + +Ending vocabulary index for this worker's partition. + + + +Process group for distributed communication. + + + +If True, tensors won't be saved for backward pass. Defaults to False. + + + +Context parallelism process group. Defaults to None. + + + +Sequence dimension chunk size for computing the log probabilities. + + +**Returns:** `torch.Tensor` + +torch.Tensor: Unpacked log probabilities tensor with shape [batch_size, unpacked_seqlen-1]. +The total length is reduced by batch_size due to target shifting (one token per sequence). + + + + + + + + +```python +nemo_rl.distributed.model_utils.gather_logits_at_global_indices( + vocab_parallel_logits: torch.Tensor, + global_indices: torch.Tensor, + tp_group: typing.Optional[torch.distributed.ProcessGroup] = None, + cp_group: typing.Optional[torch.distributed.ProcessGroup] = None, + vocab_start_index: int, + vocab_end_index: int, + chunk_size: typing.Optional[int] = None +) -> torch.Tensor +``` + + + + + + +Gather student logits at given global token indices under TP+CP sharding. + +Differentiable w.r.t. vocab_parallel_logits. + +**Parameters:** + + +[B, S_cp, V_local] where S_cp is CP sharded sequence length + + + +[B, S_full, k] where S_full is full sequence length + + + +Optional tensor-parallel process group. If None, treats logits as full-vocab (no TP) and skips TP all-reduce. + + + +global vocab start for this rank (inclusive) + + + +global vocab end for this rank (exclusive) + + + +optional chunk along sequence dim to bound memory + + + +Optional context-parallel process group + + +**Returns:** `torch.Tensor` + +[B, S_full, k] + + + + + + + + +```python +nemo_rl.distributed.model_utils.get_logprobs_from_vocab_parallel_logits( + vocab_parallel_logits: torch.distributed.tensor.DTensor, + input_ids: torch.Tensor | torch.distributed.tensor.DTensor, + seq_index: typing.Optional[torch.Tensor] = None, + chunk_size: typing.Optional[int] = None +) +``` + + + + + + +Computes log probabilities from vocabulary-parallel logits. + +This function takes logits that are sharded across the vocabulary dimension (tensor parallel) +and computes the log probabilities for the given input IDs. + +**Parameters:** + + +Logits distributed across tensor parallel workers, +with shape [batch_size, seq_len, vocab_size/tp_size]. + + + +Input token IDs for which to compute log probabilities, +with shape [batch_size, seq_len]. + + + +Sequence index for the input IDs, +with shape [sequence_length]. + + + +Sequence dimension chunk size for computing log probabilities. + + +**Returns:** + +torch.Tensor: Log probabilities for the given input IDs. + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/named_sharding.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/named_sharding.mdx new file mode 100644 index 0000000..20d0bde --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/named_sharding.mdx @@ -0,0 +1,236 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/distributed/named_sharding +title: nemo_rl.distributed.named_sharding +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`NamedSharding`](#nemo_rl-distributed-named_sharding-NamedSharding) | Represents an N-dimensional arrangement of ranks with named axes, facilitating data sharding, replication, and collection based on these axes. | + +### API + + + + + +```python +class nemo_rl.distributed.named_sharding.NamedSharding( + layout: typing.Sequence[typing.Any] | numpy.ndarray, + names: list[str] +) +``` + + + + + + +Represents an N-dimensional arrangement of ranks with named axes, facilitating data sharding, replication, and collection based on these axes. + + + + + + + + + + + + +Returns the underlying NumPy array representing the layout. + + + +Returns the names of the axes. + + + +Returns the number of dimensions. + + + +Returns the shape of the rank layout. + + + +Returns the total number of ranks. + + + + + +```python +nemo_rl.distributed.named_sharding.NamedSharding.__eq__( + other: object +) -> bool +``` + + + + + + + + + + + + +```python +nemo_rl.distributed.named_sharding.NamedSharding.__repr__() -> str +``` + + + + + + + + + + + + +```python +nemo_rl.distributed.named_sharding.NamedSharding.get_axis_index( + name: str +) -> int +``` + + + + + + +Gets the numerical index of a named axis. + + + + + + + +```python +nemo_rl.distributed.named_sharding.NamedSharding.get_axis_size( + name: str +) -> int +``` + + + + + + +Gets the size of a named axis. + + + + + + + +```python +nemo_rl.distributed.named_sharding.NamedSharding.get_ranks( + kwargs: int = {} +) -> typing.Union[nemo_rl.distributed.named_sharding.NamedSharding, int] +``` + + + + + + +Gets the ranks corresponding to specific indices along named axes. + +**Parameters:** + + +Keyword arguments where the key is the axis name (e.g., "dp", "tp") + and the value is the index along that axis. + + +**Returns:** `Union[NamedSharding, int]` + +A new NamedSharding instance representing the subset of ranks. + +**Raises:** + +- `ValueError`: If an invalid axis name is provided or if an index is out of bounds. + + + + + + + +```python +nemo_rl.distributed.named_sharding.NamedSharding.get_ranks_by_coord( + coords: int = {} +) -> list[int] +``` + + + + + + +Gets all ranks that match the specified coordinates for named axes. + +**Parameters:** + + +Keyword arguments where the key is the axis name (e.g., "dp", "tp") + and the value is the integer coordinate along that axis. + Axes not specified will match all coordinates along that axis. + + +**Returns:** `list[int]` + +A sorted list of unique rank integers that match the given coordinate criteria. + +**Raises:** + +- `ValueError`: If an invalid axis name is provided. + + + + + + + +```python +nemo_rl.distributed.named_sharding.NamedSharding.get_worker_coords( + worker_id: int +) -> dict[str, int] +``` + + + + + + +Gets the coordinates of a specific worker ID in the sharding layout. + +**Parameters:** + + +The integer ID of the worker. + + +**Returns:** `dict[str, int]` + +A dictionary mapping axis names to their integer coordinates for the given worker_id. + +**Raises:** + +- `ValueError`: If the worker_id is not found in the layout. + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/ray_actor_environment_registry.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/ray_actor_environment_registry.mdx new file mode 100644 index 0000000..5396879 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/ray_actor_environment_registry.mdx @@ -0,0 +1,105 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/distributed/ray_actor_environment_registry +title: nemo_rl.distributed.ray_actor_environment_registry +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`get_actor_python_env`](#nemo_rl-distributed-ray_actor_environment_registry-get_actor_python_env) | - | + +### Data + +[`ACTOR_ENVIRONMENT_REGISTRY`](#nemo_rl-distributed-ray_actor_environment_registry-ACTOR_ENVIRONMENT_REGISTRY) + +[`MCORE_EXECUTABLE`](#nemo_rl-distributed-ray_actor_environment_registry-MCORE_EXECUTABLE) + +[`SGLANG_EXECUTABLE`](#nemo_rl-distributed-ray_actor_environment_registry-SGLANG_EXECUTABLE) + +[`USE_SYSTEM_EXECUTABLE`](#nemo_rl-distributed-ray_actor_environment_registry-USE_SYSTEM_EXECUTABLE) + +[`VLLM_EXECUTABLE`](#nemo_rl-distributed-ray_actor_environment_registry-VLLM_EXECUTABLE) + +### API + + + + + +```python +nemo_rl.distributed.ray_actor_environment_registry.get_actor_python_env( + actor_class_fqn: str +) -> str +``` + + + + + + + + + + + + + +```python +nemo_rl.distributed.ray_actor_environment_registry.ACTOR_ENVIRONMENT_REGISTRY: dict[str, str] = {'nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker': VLLM_EXECUTA... +``` + + + + + + + + + +```python +nemo_rl.distributed.ray_actor_environment_registry.MCORE_EXECUTABLE = PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.MCORE +``` + + + + + + + + + +```python +nemo_rl.distributed.ray_actor_environment_registry.SGLANG_EXECUTABLE = PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.SGLANG +``` + + + + + + + + + +```python +nemo_rl.distributed.ray_actor_environment_registry.USE_SYSTEM_EXECUTABLE = os.environ.get('NEMO_RL_PY_EXECUTABLES_SYSTEM', '0') == '1' +``` + + + + + + + + + +```python +nemo_rl.distributed.ray_actor_environment_registry.VLLM_EXECUTABLE = PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.VLLM +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/stateless_process_group.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/stateless_process_group.mdx new file mode 100644 index 0000000..ebd4125 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/stateless_process_group.mdx @@ -0,0 +1,73 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/distributed/stateless_process_group +title: nemo_rl.distributed.stateless_process_group +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`StatelessProcessGroup`](#nemo_rl-distributed-stateless_process_group-StatelessProcessGroup) | - | + +### API + + + + + +```python +class nemo_rl.distributed.stateless_process_group.StatelessProcessGroup( + master_address: str, + port: int, + rank: int, + world_size: int +) +``` + + + + + + + + + + + + +```python +nemo_rl.distributed.stateless_process_group.StatelessProcessGroup.broadcast( + tensor: torch.Tensor, + src: int, + stream: typing.Optional[torch.cuda.Stream] = None +) +``` + + + + + + + + + + + + +```python +nemo_rl.distributed.stateless_process_group.StatelessProcessGroup.init_nccl_communicator( + device: int +) +``` + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/virtual_cluster.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/virtual_cluster.mdx new file mode 100644 index 0000000..108df41 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/virtual_cluster.mdx @@ -0,0 +1,514 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/distributed/virtual_cluster +title: nemo_rl.distributed.virtual_cluster +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ClusterConfig`](#nemo_rl-distributed-virtual_cluster-ClusterConfig) | - | +| [`GetGPUIDActor`](#nemo_rl-distributed-virtual_cluster-GetGPUIDActor) | Util actor class to return GPU id of the current worker. | +| [`PY_EXECUTABLES`](#nemo_rl-distributed-virtual_cluster-PY_EXECUTABLES) | - | +| [`RayVirtualCluster`](#nemo_rl-distributed-virtual_cluster-RayVirtualCluster) | Creates a virtual distributed cluster using Ray placement groups. | +| [`ResourceInsufficientError`](#nemo_rl-distributed-virtual_cluster-ResourceInsufficientError) | Exception raised when the cluster does not have enough resources to satisfy the requested configuration. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_free_port_local`](#nemo_rl-distributed-virtual_cluster-_get_free_port_local) | - | +| [`_get_node_ip_and_free_port`](#nemo_rl-distributed-virtual_cluster-_get_node_ip_and_free_port) | - | +| [`_get_node_ip_local`](#nemo_rl-distributed-virtual_cluster-_get_node_ip_local) | - | +| [`init_ray`](#nemo_rl-distributed-virtual_cluster-init_ray) | Initialise Ray. | + +### Data + +[`dir_path`](#nemo_rl-distributed-virtual_cluster-dir_path) + +[`git_root`](#nemo_rl-distributed-virtual_cluster-git_root) + +[`logger`](#nemo_rl-distributed-virtual_cluster-logger) + +### API + + + + + +```python +class nemo_rl.distributed.virtual_cluster.ClusterConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.distributed.virtual_cluster.GetGPUIDActor() +``` + + + + + + +Util actor class to return GPU id of the current worker. + + + + + + +```python +nemo_rl.distributed.virtual_cluster.GetGPUIDActor.get_gpu_id() +``` + + + + + + + + + + + + + + +```python +class nemo_rl.distributed.virtual_cluster.PY_EXECUTABLES() +``` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.distributed.virtual_cluster.RayVirtualCluster( + bundle_ct_per_node_list: list[int], + use_gpus: bool = True, + max_colocated_worker_groups: int = 1, + num_gpus_per_node: int = 8, + name: str = '', + placement_group_strategy: str = 'SPREAD' +) +``` + + + + + + +Creates a virtual distributed cluster using Ray placement groups. + +This class simplifies distributed training setup by: +- Creating placement groups that represent logical compute nodes +- Allocating GPU and CPU resources for distributed workers +- Managing communication between distributed processes + +- Bundle: A resource allocation unit (ex: 4 GPUs on a single node) +- Worker: A process that performs computation (model training/inference) +- Node: A physical or virtual machine containing multiple bundles + + + + + + + + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.RayVirtualCluster.__del__() -> None +``` + + + + + + +Shutsdown the virtual cluster when the object is deleted or is garbage collected. + +This is an extra safety net in case the user forgets to call shutdown and the pointer to +the cluster is lost due to leaving a function scope. It's always recommended that the +user calls shutdown(). + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.RayVirtualCluster._create_placement_groups_internal( + strategy: str, + use_unified_pg: bool = False +) -> list[ray.util.placement_group.PlacementGroup] +``` + + + + + + +Internal method to create placement groups without retry logic. + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.RayVirtualCluster._get_sorted_bundle_indices() -> typing.Optional[list[int]] +``` + + + + + + +Gets the sorted bundle indices for the placement groups. + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.RayVirtualCluster._init_placement_groups( + strategy: str | None = None, + use_unified_pg: bool = False +) -> list[ray.util.placement_group.PlacementGroup] +``` + + + + + + +Creates placement groups based on whether cross-node model parallelism is needed. + +**Parameters:** + + +Ray placement group strategy (defaults to self.placement_group_strategy) + + + +If True, create a single unified placement group. + If False, create per-node placement groups. + + +**Returns:** `list[PlacementGroup]` + +List of placement groups + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.RayVirtualCluster.get_available_address_and_port( + pg_idx: int, + bundle_idx: int +) -> tuple[str, int] +``` + + + + + + +Gets an available address and port for the given placement group index and bundle index. + +**Returns:** `tuple[str, int]` + +Tuple of (address, port) + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.RayVirtualCluster.get_master_address_and_port() -> tuple[str, int] +``` + + + + + + +Gets the master address and port for the distributed training setup. + +**Returns:** `tuple[str, int]` + +Tuple of (address, port) + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.RayVirtualCluster.get_placement_groups() -> list[ray.util.placement_group.PlacementGroup] +``` + + + + + + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.RayVirtualCluster.node_count() -> int +``` + + + + + + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.RayVirtualCluster.shutdown() -> bool +``` + + + + + + +Cleans up and releases all resources associated with this virtual cluster. + +This includes removing all placement groups and resetting the internal state. + +This method is idempotent and can be safely called multiple times. + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.RayVirtualCluster.world_size() -> int +``` + + + + + + + + + + + + + + +```python +class nemo_rl.distributed.virtual_cluster.ResourceInsufficientError() +``` + + + + + + +Exception + +**Bases:** `Exception` + +Exception raised when the cluster does not have enough resources to satisfy the requested configuration. + + + + + + + + +```python +nemo_rl.distributed.virtual_cluster._get_free_port_local() -> int +``` + + + + + + + + + + + + + +```python +nemo_rl.distributed.virtual_cluster._get_node_ip_and_free_port() -> tuple[str, int] +``` + + + + + + + + + + + + + +```python +nemo_rl.distributed.virtual_cluster._get_node_ip_local() -> str +``` + + + + + + + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.init_ray( + log_dir: typing.Optional[str] = None +) -> None +``` + + + + + + +Initialise Ray. + +Try to attach to an existing local cluster. +If that cluster uses the same CUDA_VISIBLE_DEVICES or Slurm managed tag we will reuse it. +Otherwise, we will detach and start a fresh local cluster. + +**Parameters:** + + +Optional directory to store Ray logs and temp files. + + + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.dir_path = os.path.dirname(os.path.abspath(__file__)) +``` + + + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.git_root = os.path.abspath(os.path.join(dir_path, '../..')) +``` + + + + + + + + + +```python +nemo_rl.distributed.virtual_cluster.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_group_utils.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_group_utils.mdx new file mode 100644 index 0000000..8519d0f --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_group_utils.mdx @@ -0,0 +1,81 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/distributed/worker_group_utils +title: nemo_rl.distributed.worker_group_utils +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`get_nsight_config_if_pattern_matches`](#nemo_rl-distributed-worker_group_utils-get_nsight_config_if_pattern_matches) | Check if worker name matches patterns in NRL_NSYS_WORKER_PATTERNS and return nsight config. | +| [`recursive_merge_options`](#nemo_rl-distributed-worker_group_utils-recursive_merge_options) | Recursively merge extra options into default options using OmegaConf. | + +### API + + + + + +```python +nemo_rl.distributed.worker_group_utils.get_nsight_config_if_pattern_matches( + worker_name: str +) -> dict[str, typing.Any] +``` + + + + + + +Check if worker name matches patterns in NRL_NSYS_WORKER_PATTERNS and return nsight config. + +**Parameters:** + + +Name of the worker to check against patterns + + +**Returns:** `dict[str, Any]` + +Dictionary containing {"nsight": config} if pattern matches, empty dict otherwise + + + + + + + + +```python +nemo_rl.distributed.worker_group_utils.recursive_merge_options( + default_options: dict[str, typing.Any], + extra_options: dict[str, typing.Any] +) -> dict[str, typing.Any] +``` + + + + + + +Recursively merge extra options into default options using OmegaConf. + +**Parameters:** + + +Default options dictionary (lower precedence) + + + +Extra options provided by the caller (higher precedence) + + +**Returns:** `dict[str, Any]` + +Merged options dictionary with extra_options taking precedence over default_options + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_groups.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_groups.mdx new file mode 100644 index 0000000..e0205d5 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/distributed/worker_groups.mdx @@ -0,0 +1,603 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/distributed/worker_groups +title: nemo_rl.distributed.worker_groups +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MultiWorkerFuture`](#nemo_rl-distributed-worker_groups-MultiWorkerFuture) | Container for Ray futures with associated worker information. | +| [`RayWorkerBuilder`](#nemo_rl-distributed-worker_groups-RayWorkerBuilder) | - | +| [`RayWorkerGroup`](#nemo_rl-distributed-worker_groups-RayWorkerGroup) | Manages a group of distributed Ray worker/actor processes that execute tasks in parallel. | + +### API + + + + + +```python +class nemo_rl.distributed.worker_groups.MultiWorkerFuture( + futures: list[ray.ObjectRef], + return_from_workers: typing.Optional[list[int]] = None, + called_workers: typing.Optional[list[int]] = None +) +``` + + + + + + +Dataclass + +Container for Ray futures with associated worker information. + + + + + + + + + + + + + + +```python +nemo_rl.distributed.worker_groups.MultiWorkerFuture.get_results( + worker_group: nemo_rl.distributed.worker_groups.RayWorkerGroup, + return_generators_as_proxies: bool = False +) -> list[typing.Any] +``` + + + + + + +Get results from the futures, optionally respecting tied workers. + +The method uses worker_group.worker_to_tied_group_index to identify which tied +worker group each worker belongs to, then selects only the first result from each group. + +**Parameters:** + + +The RayWorkerGroup that spawned the futures. The +mapping contained in worker_group.worker_to_tied_group_index +is required for the deduplication path. + + + +If True, and a future is an ObjectRefGenerator, + return the ObjectRefGenerator itself instead of consuming it. + + +**Returns:** `list[Any]` + +List of results + + + + + + + + + +```python +class nemo_rl.distributed.worker_groups.RayWorkerBuilder( + ray_actor_class_fqn: str, + args = (), + kwargs = {} +) +``` + + + + + + + + + + +```python +nemo_rl.distributed.worker_groups.RayWorkerBuilder.__call__( + placement_group: ray.util.placement_group.PlacementGroup, + placement_group_bundle_index: int, + num_gpus: float | int, + bundle_indices: typing.Optional[tuple[int, list[int]]] = None, + extra_options: typing.Any = {} +) -> ray.actor.ActorHandle +``` + + + + + + +Create a Ray worker with the specified configuration. + +Order of precedence for worker options configuration (from lowest to highest): +1. Options passed by the user to __call__ (extra_options) +2. Options required by the worker via configure_worker (may override user options with warning) +3. Options set by the RayWorkerBuilder.__call__ (specifically scheduling strategy) + +If the worker needs to override user-provided options, it should log a warning +to inform the user about the change and the reason for it. + +**Parameters:** + + +Ray placement group for resource allocation + + + +Index of the bundle in the placement group + + + +Number of GPUs to allocate to this worker (can be fractional) + + + +Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable) + + + +Additional options to pass to the Ray actor (may be overridden by actor's configure_worker(...) method) + + +**Returns:** `ray.actor.ActorHandle` + +A Ray actor reference to the created worker + + + + + + + +```python +nemo_rl.distributed.worker_groups.RayWorkerBuilder.create_worker_async( + placement_group: ray.util.placement_group.PlacementGroup, + placement_group_bundle_index: int, + num_gpus: float | int, + bundle_indices: typing.Optional[tuple[int, list[int]]] = None, + extra_options: typing.Any = {} +) -> tuple[ray.ObjectRef, ray.actor.ActorHandle] +``` + + + + + + +Create a Ray worker asynchronously, returning futures. + +This method returns immediately with futures that can be awaited later. + +**Parameters:** + + +Ray placement group for resource allocation + + + +Index of the bundle in the placement group + + + +Number of GPUs to allocate to this worker (can be fractional) + + + +Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable) + + + +Additional options to pass to the Ray actor + + +**Returns:** `tuple[ray.ObjectRef, ray.actor.ActorHandle]` + +Tuple of (worker_future, initializer_actor): +- worker_future: A Ray ObjectRef that will resolve to the worker actor +- initializer_actor: The initializer actor (needed to prevent GC) + + + + + + + + + +```python +class nemo_rl.distributed.worker_groups.RayWorkerGroup( + cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster, + remote_worker_builder: nemo_rl.distributed.worker_groups.RayWorkerBuilder, + workers_per_node: typing.Optional[typing.Union[int, list[int]]] = None, + name_prefix: str = '', + bundle_indices_list: typing.Optional[list[tuple[int, list[int]]]] = None, + sharding_annotations: typing.Optional[nemo_rl.distributed.named_sharding.NamedSharding] = None, + env_vars: dict[str, str] = {} +) +``` + + + + + + +Manages a group of distributed Ray worker/actor processes that execute tasks in parallel. + +This class creates and manages Ray actor instances that run on resources +allocated by a RayVirtualCluster. It handles: +- Worker creation and placement on specific GPU resources +- Setting up distributed training environment variables (rank, world size, etc.) +- Executing methods across all workers in parallel +- Collecting and aggregating results +- Support for tied worker groups where multiple workers process the same data + + + + + + + + + + + + +Number of data parallel shards. + + + + + + + + + + + +```python +nemo_rl.distributed.worker_groups.RayWorkerGroup._create_workers_from_bundle_indices( + remote_worker_builder: nemo_rl.distributed.worker_groups.RayWorkerBuilder, + bundle_indices_list: list[tuple[int, list[int]]], + env_vars: dict[str, str] = {} +) -> None +``` + + + + + + +Create workers based on explicit bundle indices for tied worker groups. + +**Parameters:** + + +Builder function for Ray actors + + + +List of (node_idx, local_bundle_indices) tuples, where each tuple + specifies a tied group with its node and local bundle indices. If the local_bundle_indices + spans multiple nodes, the node_idx will be the first node's index in the tied group. + + + + + + + + +```python +nemo_rl.distributed.worker_groups.RayWorkerGroup.get_all_worker_results( + future_bundle: nemo_rl.distributed.worker_groups.MultiWorkerFuture, + return_generators_as_proxies: bool = False +) -> list[typing.Any] +``` + + + + + + +Get results from all workers, optionally filtering to get just one result per tied worker group. + +**Parameters:** + + +MultiWorkerFuture containing futures and worker information. + + + +If True, and a future in the bundle is an ObjectRefGenerator, + return the ObjectRefGenerator itself instead of consuming it. + + +**Returns:** `list[Any]` + +List of results, deduplicated as specified in the future_bundle + + + + + + + +```python +nemo_rl.distributed.worker_groups.RayWorkerGroup.get_dp_leader_worker_idx( + dp_shard_idx: int +) -> int +``` + + + + + + +Returns the index of the primary worker for a given data parallel shard. + + + + + + + +```python +nemo_rl.distributed.worker_groups.RayWorkerGroup.run_all_workers_multiple_data( + method_name: str, + args = (), + run_rank_0_only_axes: list[str] | None = None, + common_kwargs: typing.Optional[dict[str, typing.Any]] = None, + kwargs = {} +) -> list[ray.ObjectRef] +``` + + + + + + +Run a method on all workers in parallel with different data. + +**Parameters:** + + +Name of the method to call on each worker + + + +List of arguments to pass to workers/groups + e.g. [[arg1_for_worker_1, arg1_for_worker_2], [arg2_for_worker_1, arg2_for_worker_2]] + + + +List of named axes for which only rank 0 should run the method. + + + +Keyword arguments to pass to all workers + + + +Keyword arguments to pass to workers/groups + e.g. {"key1": [value_for_worker_1, value_for_worker_2], "key2": [value_for_worker_1, value_for_worker_2]} + + +**Returns:** `list[ray.ObjectRef]` + +list[ray.ObjectRef]: A list of ray futures + + + + + + + +```python +nemo_rl.distributed.worker_groups.RayWorkerGroup.run_all_workers_sharded_data( + method_name: str, + args = (), + in_sharded_axes: list[str] | None = None, + replicate_on_axes: list[str] | None = None, + output_is_replicated: list[str] | None = None, + make_dummy_calls_to_free_axes: bool = False, + common_kwargs: typing.Optional[dict[str, typing.Any]] = None, + kwargs = {} +) -> nemo_rl.distributed.worker_groups.MultiWorkerFuture +``` + + + + + + +Run a method on all workers in parallel with sharded data. + +Axes in in_sharded_axes: Data is already split across these axes, so we just send the appropriate slice to each worker (along this axis) +Axes in replicate_on_axes: Data is replicated to all workers along these dimensions +Free axes (axes not in either list): Data is only sent to workers at index 0 of these axes + +**Parameters:** + + +Name of the method to call on each worker + + + +List of arguments to pass to workers/groups + e.g. [[arg1_for_worker_1, arg1_for_worker_2], [arg2_for_worker_1, arg2_for_worker_2]] + + + +List of axes that are sharded + + + +List of axes that are to be replicated + + + +List of axes along which the output is replicated (and we should just return the first result). + We also just return from rank 0 of free axes. + + + +Whether to make dummy calls (with None) to workers that + aren't rank 0 on 'free axes' (axes not in in_sharded_axes or replicate_on_axes). + + + +Keyword arguments to pass to all workers + + + +Keyword arguments to pass to workers/groups + e.g. {"key1": [value_for_worker_1, value_for_worker_2], "key2": [value_for_worker_1, value_for_worker_2]} + + +**Returns:** `MultiWorkerFuture` + +Object containing futures and their associated worker information + + + + + + + +```python +nemo_rl.distributed.worker_groups.RayWorkerGroup.run_all_workers_single_data( + method_name: str, + args = (), + run_rank_0_only_axes: list[str] | None = None, + kwargs = {} +) -> list[ray.ObjectRef] +``` + + + + + + +Run a method on all workers in parallel with the same data. + +**Parameters:** + + +Name of the method to call on each worker + + + +Arguments to pass to the method + + + +List of named axes for which only rank 0 should run the method. + + +**Returns:** `list[ray.ObjectRef]` + +list[ray.ObjectRef]: A list of ray futures + + + + + + + +```python +nemo_rl.distributed.worker_groups.RayWorkerGroup.run_single_worker_single_data( + method_name: str, + worker_idx: int, + args = (), + kwargs = {} +) -> ray.ObjectRef +``` + + + + + + +Run a method on a single, specific worker. + +**Parameters:** + + +Name of the method to call on the worker. + + + +The index of the worker to run the method on. + + + +Arguments to pass to the method. + + +**Returns:** `ray.ObjectRef` + +ray.ObjectRef: A Ray future for the result. + + + + + + + +```python +nemo_rl.distributed.worker_groups.RayWorkerGroup.shutdown( + cleanup_method: typing.Optional[str] = None, + timeout: typing.Optional[float] = 30.0, + force: bool = False +) -> bool +``` + + + + + + +Shutdown all workers in the worker group. + +**Parameters:** + + +Optional method name to call on each worker before termination. + If provided, this method will be called on each worker to allow + for graceful cleanup. + + + +Timeout in seconds for graceful shutdown. Only applicable if cleanup_method is provided. + If None, wait indefinitely for workers to complete their cleanup. + + + +If True, forcefully terminate workers with ray.kill() even if cleanup_method is provided. + If cleanup_method is None, workers are always forcefully terminated. + + +**Returns:** `bool` + +True if all workers were successfully shut down + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments.mdx new file mode 100644 index 0000000..de15010 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments.mdx @@ -0,0 +1,19 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments +title: nemo_rl.environments +--- + +## Submodules + +- **[`nemo_rl.environments.code_environment`](/nemo-rl/nemo_rl/environments/code_environment)** +- **[`nemo_rl.environments.code_jaccard_environment`](/nemo-rl/nemo_rl/environments/code_jaccard_environment)** +- **[`nemo_rl.environments.dapo_math_verifier`](/nemo-rl/nemo_rl/environments/dapo_math_verifier)** +- **[`nemo_rl.environments.interfaces`](/nemo-rl/nemo_rl/environments/interfaces)** +- **[`nemo_rl.environments.math_environment`](/nemo-rl/nemo_rl/environments/math_environment)** +- **[`nemo_rl.environments.metrics`](/nemo-rl/nemo_rl/environments/metrics)** +- **[`nemo_rl.environments.nemo_gym`](/nemo-rl/nemo_rl/environments/nemo_gym)** +- **[`nemo_rl.environments.reward_model_environment`](/nemo-rl/nemo_rl/environments/reward_model_environment)** +- **[`nemo_rl.environments.rewards`](/nemo-rl/nemo_rl/environments/rewards)** +- **[`nemo_rl.environments.utils`](/nemo-rl/nemo_rl/environments/utils)** +- **[`nemo_rl.environments.vlm_environment`](/nemo-rl/nemo_rl/environments/vlm_environment)** diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_environment.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_environment.mdx new file mode 100644 index 0000000..5c46941 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_environment.mdx @@ -0,0 +1,290 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/code_environment +title: nemo_rl.environments.code_environment +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`CodeEnvConfig`](#nemo_rl-environments-code_environment-CodeEnvConfig) | - | +| [`CodeEnvMetadata`](#nemo_rl-environments-code_environment-CodeEnvMetadata) | - | +| [`CodeEnvironment`](#nemo_rl-environments-code_environment-CodeEnvironment) | Code execution environment that maintains state between steps. | +| [`CodeExecutionWorker`](#nemo_rl-environments-code_environment-CodeExecutionWorker) | Helper class to process individual code execution steps. | + +### API + + + + + +```python +class nemo_rl.environments.code_environment.CodeEnvConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.environments.code_environment.CodeEnvMetadata +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.environments.code_environment.CodeEnvironment( + cfg: nemo_rl.environments.code_environment.CodeEnvConfig +) +``` + + + + + + +**Bases:** [EnvironmentInterface](/nemo-rl/nemo_rl/environments/interfaces#nemo_rl-environments-interfaces-EnvironmentInterface) + +Code execution environment that maintains state between steps. + + + + + + + + + + + + + + +```python +nemo_rl.environments.code_environment.CodeEnvironment.global_post_process_and_metrics( + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict +) -> typing.Tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict, dict] +``` + + + + + + +Compute metrics for the batch. + + + + + + + +```python +nemo_rl.environments.code_environment.CodeEnvironment.shutdown() +``` + + + + + + + + + + + + +```python +nemo_rl.environments.code_environment.CodeEnvironment.step( + message_log_batch: typing.List[nemo_rl.data.interfaces.LLMMessageLogType], + metadata_batch: typing.List[nemo_rl.environments.code_environment.CodeEnvMetadata], + return_extracted_answer: bool = False +) -> nemo_rl.environments.interfaces.EnvironmentReturn +``` + + + + + + +Process a batch of code execution steps. + + + + + + + + + +```python +class nemo_rl.environments.code_environment.CodeExecutionWorker() +``` + + + + + + +Helper class to process individual code execution steps. + + + + + + + + +```python +nemo_rl.environments.code_environment.CodeExecutionWorker.chdir( + dir: str +) +``` + + + + + + +Change to temporary directory for file operations. + + + + + + + +```python +nemo_rl.environments.code_environment.CodeExecutionWorker.execute( + message_batch: str, + metadata_batch: typing.List[nemo_rl.environments.code_environment.CodeEnvMetadata] +) -> typing.Tuple[typing.List[typing.Dict[str, str]], typing.List[bool], typing.List[typing.Any]] +``` + + + + + + +Execute code in a sandboxed environment. + + + + + + + +```python +nemo_rl.environments.code_environment.CodeExecutionWorker.format_result( + result: typing.Any, + code: typing.Optional[str] = None, + lookahead: typing.Optional[str] = None +) -> str +``` + + + + + + + + + + + + +```python +nemo_rl.environments.code_environment.CodeExecutionWorker.safe_import( + name: str, + args = (), + kwargs = {} +) +``` + + + + + + +Safe version of import that blocks risky modules. + + + + + + + +```python +nemo_rl.environments.code_environment.CodeExecutionWorker.safe_open( + file: str, + args = (), + kwargs = {} +) +``` + + + + + + +Safe version of open() that only allows access to temporary directory. + + + + + + + +```python +nemo_rl.environments.code_environment.CodeExecutionWorker.sanitize( + obj: typing.Any +) -> typing.Any +``` + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_jaccard_environment.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_jaccard_environment.mdx new file mode 100644 index 0000000..0bdc82a --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/code_jaccard_environment.mdx @@ -0,0 +1,268 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/code_jaccard_environment +title: nemo_rl.environments.code_jaccard_environment +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`CodeJaccardEnvConfig`](#nemo_rl-environments-code_jaccard_environment-CodeJaccardEnvConfig) | - | +| [`CodeJaccardEnvironment`](#nemo_rl-environments-code_jaccard_environment-CodeJaccardEnvironment) | Environment for evaluating code responses using Jaccard similarity. | +| [`CodeJaccardEnvironmentMetadata`](#nemo_rl-environments-code_jaccard_environment-CodeJaccardEnvironmentMetadata) | - | +| [`CodeJaccardVerifyWorker`](#nemo_rl-environments-code_jaccard_environment-CodeJaccardVerifyWorker) | Worker for evaluating code responses using Jaccard-based similarity. | + +### API + + + + + +```python +class nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironment( + cfg: nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvConfig +) +``` + + + + + + +**Bases:** [EnvironmentInterface[CodeJaccardEnvironmentMetadata]](/nemo-rl/nemo_rl/environments/interfaces#nemo_rl-environments-interfaces-EnvironmentInterface) + +Environment for evaluating code responses using Jaccard similarity. + + + + + + + + + + + +```python +nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironment.global_post_process_and_metrics( + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] +) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], dict[str, float | int]] +``` + + + + + + +Post-process batch and compute metrics for CodeJaccard. + + + + + + + +```python +nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironment.shutdown() -> None +``` + + + + + + +Shutdown all workers. + + + + + + + +```python +nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironment.step( + message_log_batch: list[nemo_rl.data.interfaces.LLMMessageLogType], + metadata: list[nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironmentMetadata], + return_extracted_answer: bool = False +) -> nemo_rl.environments.interfaces.EnvironmentReturn[nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironmentMetadata] +``` + + + + + + +Runs a step in the CodeJaccard environment. + +**Parameters:** + + +Batch of OpenAI-API-like message logs. + + + +Batch of CodeJaccardEnvironmentMetadata with ground truth. + + + +Whether to return extracted answers. + + +**Returns:** `EnvironmentReturn[CodeJaccardEnvironmentMetadata]` + +Tuple containing observations, metadata, stop strings, rewards, and done flags. + + + + + + + + + +```python +class nemo_rl.environments.code_jaccard_environment.CodeJaccardEnvironmentMetadata +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +class nemo_rl.environments.code_jaccard_environment.CodeJaccardVerifyWorker() +``` + + + + + + +Worker for evaluating code responses using Jaccard-based similarity. + + + + + + +```python +nemo_rl.environments.code_jaccard_environment.CodeJaccardVerifyWorker._calculate_preference_score( + response: str, + ground_truth: str +) -> float +``` + + + + + + +Calculate a Jaccard-based alignment score between response and ground truth. + +This is a simplified scoring function. In practice, you might want to use: +- Semantic similarity models +- BLEU/ROUGE scores +- Tokenize both texts into sets A and B (here we use whitespace tokenization). +- Compute intersection size |A ∩ B| and union size |A ∪ B|. +- J(A, B) = |A ∩ B| / |A ∪ B|, with guards for union=0 -> 0.0. +- Optionally combine with a length-ratio penalty to discourage degenerate very short/long matches. + +Complexity: +- Tokenization: O(n + m) +- Set ops: O(n + m) average (hash sets) + +**Parameters:** + + +The model's response + + +**Returns:** `float` + +Score between 0.0 and 1.0 + + + + + + + +```python +nemo_rl.environments.code_jaccard_environment.CodeJaccardVerifyWorker.verify( + pred_responses: list[str], + ground_truths: list[str], + return_extracted_answer: bool = False +) -> typing.Union[list[float], tuple[list[float], list[str | None]]] +``` + + + + + + +Verify code responses against ground-truth solutions using Jaccard-based similarity. + +We use a simple text similarity approach (Jaccard over tokenized words) +to evaluate how well the model's response aligns with the ground truth. + +**Parameters:** + + +list[str]. The predicted responses from the LLM. + + + +list[str]. The ground-truth solutions. + + + +bool. Whether to return extracted answers (here, the full response). + + +**Returns:** `Union[list[float], tuple[list[float], list[str | None]]]` + +Union[list[float], tuple[list[float], list[str | None]]]. + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/dapo_math_verifier.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/dapo_math_verifier.mdx new file mode 100644 index 0000000..ecea315 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/dapo_math_verifier.mdx @@ -0,0 +1,316 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/dapo_math_verifier +title: nemo_rl.environments.dapo_math_verifier +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`compute_score`](#nemo_rl-environments-dapo_math_verifier-compute_score) | Compute the reward score for a solution. | +| [`is_correct_minerva`](#nemo_rl-environments-dapo_math_verifier-is_correct_minerva) | Check if the solution is correct according to Minerva criteria. | +| [`is_correct_strict_box`](#nemo_rl-environments-dapo_math_verifier-is_correct_strict_box) | Check if the prediction is correct using strict boxed answer criteria. | +| [`last_boxed_only_string`](#nemo_rl-environments-dapo_math_verifier-last_boxed_only_string) | Extract the last LaTeX boxed expression from a string. | +| [`normalize_final_answer`](#nemo_rl-environments-dapo_math_verifier-normalize_final_answer) | Normalize a final answer to a quantitative reasoning question. | +| [`remove_boxed`](#nemo_rl-environments-dapo_math_verifier-remove_boxed) | Remove the LaTeX boxed command from a string. | +| [`verify`](#nemo_rl-environments-dapo_math_verifier-verify) | Verify if the solution is correct. | + +### Data + +[`REMOVED_EXPRESSIONS`](#nemo_rl-environments-dapo_math_verifier-REMOVED_EXPRESSIONS) + +[`SUBSTITUTIONS`](#nemo_rl-environments-dapo_math_verifier-SUBSTITUTIONS) + +### API + + + + + +```python +nemo_rl.environments.dapo_math_verifier.compute_score( + solution_str: str, + ground_truth: str, + strict_box_verify: bool = False, + pause_tokens_index: typing.Optional[list[int]] = None +) -> float +``` + + + + + + +Compute the reward score for a solution. + +**Parameters:** + + +The solution string + + + +The ground truth answer + + + +Whether to use strict box verification + + + +Indices of pause tokens + + +**Returns:** `float` + +Reward score (1.0 for correct, 0.0 for incorrect) + + + + + + + + +```python +nemo_rl.environments.dapo_math_verifier.is_correct_minerva( + solution_str: str, + gt: str, + gt_need_extract: bool = False, + answer_pattern: str = '(?i)Answer\\s*:\\s*([^\\n]+)' +) -> tuple[bool, str] +``` + + + + + + +Check if the solution is correct according to Minerva criteria. + +**Parameters:** + + +The solution string to check + + + +The ground truth answer + + + +Whether the ground truth needs extraction + + + +Regex pattern to extract the answer + + +**Returns:** `tuple[bool, str]` + +Tuple of (is_correct, normalized_prediction) + + + + + + + + +```python +nemo_rl.environments.dapo_math_verifier.is_correct_strict_box( + pred: str, + gt: str, + pause_tokens_index: typing.Optional[list[int]] = None +) -> tuple[int, typing.Optional[str]] +``` + + + + + + +Check if the prediction is correct using strict boxed answer criteria. + +**Parameters:** + + +The prediction string + + + +The ground truth answer + + + +Indices of pause tokens + + +**Returns:** `tuple[int, Optional[str]]` + +Tuple of (score, extracted_prediction) + + + + + + + + +```python +nemo_rl.environments.dapo_math_verifier.last_boxed_only_string( + string: str +) -> typing.Optional[str] +``` + + + + + + +Extract the last LaTeX boxed expression from a string. + +**Parameters:** + + +Input string containing LaTeX code + + +**Returns:** `Optional[str]` + +The last boxed expression or None if not found + + + + + + + + +```python +nemo_rl.environments.dapo_math_verifier.normalize_final_answer( + final_answer: str +) -> str +``` + + + + + + +Normalize a final answer to a quantitative reasoning question. + +**Parameters:** + + +The answer string to normalize + + +**Returns:** `str` + +Normalized answer string + + + + + + + + +```python +nemo_rl.environments.dapo_math_verifier.remove_boxed( + s: str +) -> str +``` + + + + + + +Remove the LaTeX boxed command from a string. + +**Parameters:** + + +String with format "\\boxed{content}" + + +**Returns:** `str` + +The content inside the boxed command + + + + + + + + +```python +nemo_rl.environments.dapo_math_verifier.verify( + solution_str: str, + answer: str, + strict_box_verify: bool = False, + pause_tokens_index: typing.Optional[list[int]] = None +) -> bool +``` + + + + + + +Verify if the solution is correct. + +**Parameters:** + + +The solution string to verify + + + +The ground truth answer + + + +Whether to use strict box verification + + + +Indices of pause tokens + + +**Returns:** `bool` + +True if the solution is correct, False otherwise + + + + + + + + +```python +nemo_rl.environments.dapo_math_verifier.REMOVED_EXPRESSIONS = ['square', 'ways', 'integers', 'dollars', 'mph', 'inches', 'hours', 'km', 'units... +``` + + + + + + + + + +```python +nemo_rl.environments.dapo_math_verifier.SUBSTITUTIONS = [('an ', ''), ('a ', ''), ('.$', '$'), ('\\$', ''), ('\\ ', ''), (' ', ''), ('mb... +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/interfaces.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/interfaces.mdx new file mode 100644 index 0000000..487b356 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/interfaces.mdx @@ -0,0 +1,151 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/interfaces +title: nemo_rl.environments.interfaces +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`EnvironmentInterface`](#nemo_rl-environments-interfaces-EnvironmentInterface) | - | +| [`EnvironmentReturn`](#nemo_rl-environments-interfaces-EnvironmentReturn) | Standard batched return type for environment step methods. | + +### Data + +[`MetadataT`](#nemo_rl-environments-interfaces-MetadataT) + +### API + + + + + +```python +class nemo_rl.environments.interfaces.EnvironmentInterface() +``` + + + + + + +Abstract + +**Bases:** `ABC`, `Generic[MetadataT]` + + + + + +```python +nemo_rl.environments.interfaces.EnvironmentInterface.global_post_process_and_metrics( + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict +) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict, dict] +``` + + + + + + +abstract + +Post processing function after all rollouts are done for the batch and returns metrics. + + + + + + + +```python +nemo_rl.environments.interfaces.EnvironmentInterface.step( + message_log_batch: list[nemo_rl.data.interfaces.LLMMessageLogType], + metadata: list[nemo_rl.environments.interfaces.MetadataT] +) -> nemo_rl.environments.interfaces.EnvironmentReturn[nemo_rl.environments.interfaces.MetadataT] +``` + + + + + + +abstract + +Runs a step in the environment. Allows for asynchrony with remote servers, but it's not required (this function is a ray remote). + +metadata: batch of whatever the environment needs to keep track of. I.e. + math solutions, code unit tests, or agent states. Can be None if episode terminated. + +Returns: +- EnvironmentReturn NamedTuple containing observations, metadata, next_stop_strings, rewards, and terminateds flags. + + + + + + + + + +```python +class nemo_rl.environments.interfaces.EnvironmentReturn() +``` + + + + + + +**Bases:** `NamedTuple`, `Generic[MetadataT]` + +Standard batched return type for environment step methods. + +**All elements are batched.** +observations: New observation from the environment. + It's a (batched) 'message' type, which is a dict + with keys 'role' and 'content'. +metadata: Updated metadata from the environment. +next_stop_strings: The stop strings for the next turn. + If your environment is a game or similar, + you may want to return a list of stop strings + that are valid actions for the next turn or + similar. This field lets you control this per turn. +rewards: the rewards for this turn. +terminateds: whether the episode ended this turn. +answers: the answers for this turn. + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.environments.interfaces.MetadataT = TypeVar('MetadataT') +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/math_environment.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/math_environment.mdx new file mode 100644 index 0000000..eccdd26 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/math_environment.mdx @@ -0,0 +1,356 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/math_environment +title: nemo_rl.environments.math_environment +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`EnglishMultichoiceVerifyWorker`](#nemo_rl-environments-math_environment-EnglishMultichoiceVerifyWorker) | - | +| [`HFVerifyWorker`](#nemo_rl-environments-math_environment-HFVerifyWorker) | - | +| [`MathEnvConfig`](#nemo_rl-environments-math_environment-MathEnvConfig) | - | +| [`MathEnvironment`](#nemo_rl-environments-math_environment-MathEnvironment) | - | +| [`MathEnvironmentMetadata`](#nemo_rl-environments-math_environment-MathEnvironmentMetadata) | - | +| [`MultilingualMultichoiceVerifyWorker`](#nemo_rl-environments-math_environment-MultilingualMultichoiceVerifyWorker) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_mute_output`](#nemo_rl-environments-math_environment-_mute_output) | - | + +### API + + + + + +```python +class nemo_rl.environments.math_environment.EnglishMultichoiceVerifyWorker() +``` + + + + + + + + + + +```python +nemo_rl.environments.math_environment.EnglishMultichoiceVerifyWorker.verify( + pred_responses: list[str], + ground_truths: list[str], + return_extracted_answer: bool = False, + kwargs = {} +) -> typing.Union[list[float], tuple[list[float], list[str | None]]] +``` + + + + + + +Verify the correctness of the predicted responses against the ground truth. + +**Parameters:** + + +list[str]. The predicted responses from the LLM. + + + +list[str]. The ground truth responses. + + +**Returns:** `Union[list[float], tuple[list[float], list[str | None]]]` + +Union[list[float], tuple[list[float], list[str | None]]]. + + + + + + + + + +```python +class nemo_rl.environments.math_environment.HFVerifyWorker() +``` + + + + + + + + + + + + +```python +nemo_rl.environments.math_environment.HFVerifyWorker.verify( + pred_responses: list[str], + ground_truths: list[str], + return_extracted_answer: bool = False, + kwargs = {} +) -> typing.Union[list[float], tuple[list[float], list[str | None]]] +``` + + + + + + +Verify the correctness of the predicted responses against the ground truth. + +**Parameters:** + + +list[str]. The predicted responses from the LLM. + + + +list[str]. The ground truth responses. + + +**Returns:** `Union[list[float], tuple[list[float], list[str | None]]]` + +Union[list[float], tuple[list[float], list[str | None]]]. + + + + + + + + + +```python +class nemo_rl.environments.math_environment.MathEnvConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.environments.math_environment.MathEnvironment( + cfg: nemo_rl.environments.math_environment.MathEnvConfig +) +``` + + + + + + +**Bases:** [EnvironmentInterface[MathEnvironmentMetadata]](/nemo-rl/nemo_rl/environments/interfaces#nemo_rl-environments-interfaces-EnvironmentInterface) + + + + + + + + + + +```python +nemo_rl.environments.math_environment.MathEnvironment.global_post_process_and_metrics( + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] +) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], dict[str, float | int]] +``` + + + + + + +Computes metrics for this environment given a global rollout batch. + +Every rank will run this function, so you're free to use distributed +calculations if you'd prefer for heavy metrics. + + + + + + + +```python +nemo_rl.environments.math_environment.MathEnvironment.shutdown() -> None +``` + + + + + + + + + + + + +```python +nemo_rl.environments.math_environment.MathEnvironment.step( + message_log_batch: list[nemo_rl.data.interfaces.LLMMessageLogType], + metadata: list[nemo_rl.environments.math_environment.MathEnvironmentMetadata], + return_extracted_answer: bool = False +) -> nemo_rl.environments.interfaces.EnvironmentReturn[nemo_rl.environments.math_environment.MathEnvironmentMetadata] +``` + + + + + + +Runs a step in the math environment. + +**Parameters:** + + +list[list[dict[str, str]]]. A batch of OpenAI-API-like message logs that represent interactions with the LLM. + + + +list[MathEnvironmentMetadata]. The grader will use the 'ground_truth' key to evaluate correctness. The extracted answer will be stored to caculate cons@k. + + +**Returns:** `EnvironmentReturn[MathEnvironmentMetadata]` + +A tuple containing: +- list[dict[str, str]]: Observations/responses batch +- list[dict]: Updated metadata +- list[str]: Next stop strings for the next turn +- Tensor: Rewards tensor +- Tensor: Done flags tensor + + + + + + + + + +```python +class nemo_rl.environments.math_environment.MathEnvironmentMetadata +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.environments.math_environment.MultilingualMultichoiceVerifyWorker() +``` + + + + + + + + + + +```python +nemo_rl.environments.math_environment.MultilingualMultichoiceVerifyWorker.verify( + pred_responses: list[str], + ground_truths: list[str], + return_extracted_answer: bool = False, + kwargs = {} +) -> typing.Union[list[float], tuple[list[float], list[str | None]]] +``` + + + + + + +Verify the correctness of the predicted responses against the ground truth. + +**Parameters:** + + +list[str]. The predicted responses from the LLM. + + + +list[str]. The ground truth responses. + + +**Returns:** `Union[list[float], tuple[list[float], list[str | None]]]` + +Union[list[float], tuple[list[float], list[str | None]]]. + + + + + + + + + +```python +nemo_rl.environments.math_environment._mute_output() +``` + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/metrics.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/metrics.mdx new file mode 100644 index 0000000..12386b2 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/metrics.mdx @@ -0,0 +1,42 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/metrics +title: nemo_rl.environments.metrics +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`calculate_pass_rate_per_prompt`](#nemo_rl-environments-metrics-calculate_pass_rate_per_prompt) | Function to compute fraction of prompts that have at least one correct answer (reward > 0). | + +### API + + + + + +```python +nemo_rl.environments.metrics.calculate_pass_rate_per_prompt( + prompts: torch.Tensor, + is_correct: torch.Tensor +) -> float +``` + + + + + + +Function to compute fraction of prompts that have at least one correct answer (reward > 0). + +prompts: tensor (b, s) Tensor of prompts the model used. May be on any device +is_correct: tensor (b,) bool-valued label. May be on any device + +Returns: +pass rate: float + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/nemo_gym.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/nemo_gym.mdx new file mode 100644 index 0000000..dcc121c --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/nemo_gym.mdx @@ -0,0 +1,213 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/nemo_gym +title: nemo_rl.environments.nemo_gym +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`NemoGym`](#nemo_rl-environments-nemo_gym-NemoGym) | This environment class isn't really used for training. It's really meant as an integration wrapper around NeMo-Gym that hooks into the existing NeMo RL resource management via ray. So there is still one source of truth for resource management in NeMo RL. | +| [`NemoGymConfig`](#nemo_rl-environments-nemo_gym-NemoGymConfig) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`setup_nemo_gym_config`](#nemo_rl-environments-nemo_gym-setup_nemo_gym_config) | - | + +### API + + + + + +```python +class nemo_rl.environments.nemo_gym.NemoGym( + cfg: nemo_rl.environments.nemo_gym.NemoGymConfig +) +``` + + + + + + +**Bases:** [EnvironmentInterface](/nemo-rl/nemo_rl/environments/interfaces#nemo_rl-environments-interfaces-EnvironmentInterface) + +This environment class isn't really used for training. It's really meant as an integration wrapper around NeMo-Gym that hooks into the existing NeMo RL resource management via ray. So there is still one source of truth for resource management in NeMo RL. + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.environments.nemo_gym.NemoGym._postprocess_nemo_gym_to_nemo_rl_result( + nemo_gym_result: dict, + tokenizer: transformers.PreTrainedTokenizerBase +) -> dict +``` + + + + + + + + + + + + +```python +nemo_rl.environments.nemo_gym.NemoGym.global_post_process_and_metrics( + batch +) +``` + + + + + + + + + + + + +```python +nemo_rl.environments.nemo_gym.NemoGym.health_check() -> bool +``` + + + + + + + + + + + + +```python +nemo_rl.environments.nemo_gym.NemoGym.run_rollouts( + nemo_gym_examples: list[dict], + tokenizer: transformers.PreTrainedTokenizerBase, + timer_prefix: str +) -> list[dict] +``` + + + + + + +async + + + + + + + +```python +nemo_rl.environments.nemo_gym.NemoGym.shutdown() -> None +``` + + + + + + + + + + + + +```python +nemo_rl.environments.nemo_gym.NemoGym.step( + message_log_batch, + metadata +) +``` + + + + + + + + + + + + + + +```python +class nemo_rl.environments.nemo_gym.NemoGymConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.environments.nemo_gym.setup_nemo_gym_config( + config, + tokenizer +) -> None +``` + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/reward_model_environment.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/reward_model_environment.mdx new file mode 100644 index 0000000..67f23d3 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/reward_model_environment.mdx @@ -0,0 +1,276 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/reward_model_environment +title: nemo_rl.environments.reward_model_environment +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`RewardModelEnvironment`](#nemo_rl-environments-reward_model_environment-RewardModelEnvironment) | Environment that uses a reward model to score conversations. | +| [`RewardModelEnvironmentConfig`](#nemo_rl-environments-reward_model_environment-RewardModelEnvironmentConfig) | Configuration for RewardModelEnvironment. | + +### API + + + + + +```python +class nemo_rl.environments.reward_model_environment.RewardModelEnvironment( + config: typing.Dict[str, typing.Any] +) +``` + + + + + + +**Bases:** [EnvironmentInterface](/nemo-rl/nemo_rl/environments/interfaces#nemo_rl-environments-interfaces-EnvironmentInterface) + +Environment that uses a reward model to score conversations. + +This environment implements a reward model-based scoring system for reinforcement +learning tasks. It takes conversation logs as input and returns rewards based on +the quality of the assistant's responses as judged by a pre-trained reward model. + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.environments.reward_model_environment.RewardModelEnvironment.__del__() +``` + + + + + + +Destructor that ensures proper cleanup when the object is garbage collected. + +This is an extra safety net in case the user forgets to call shutdown() and +the pointer to the object is lost due to leaving a function scope. It's always +recommended that the user calls shutdown() explicitly for better resource +management. + + + + + + + +```python +nemo_rl.environments.reward_model_environment.RewardModelEnvironment.global_post_process_and_metrics( + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict +) -> typing.Tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict, dict] +``` + + + + + + +Post processing function after all rollouts are done for the batch and returns metrics. + +This method computes aggregate statistics and metrics from the processed batch. +It provides insights into reward distribution and processing statistics. + +**Parameters:** + + +The batch data dictionary containing processed conversations and rewards. + + +**Returns:** `BatchedDataDict` + +Tuple of (processed_batch, metrics_dict) where: + + + + + + + +```python +nemo_rl.environments.reward_model_environment.RewardModelEnvironment.preprocess_data( + message_logs: typing.List[nemo_rl.data.interfaces.LLMMessageLogType] +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec] +``` + + + + + + +Preprocess the message logs for the reward model. + +This method tokenizes and formats conversation logs into the format expected +by the reward model. It handles: +- Tokenization of user and assistant messages +- Formatting with proper special tokens +- Batching and padding for efficient processing +- Sequence length validation and truncation + +**Parameters:** + + +List of conversation message logs, where each log contains + a list of messages with 'role' and 'content' fields. + + +**Returns:** `BatchedDataDict[GenerationDatumSpec]` + +BatchedDataDict containing tokenized and formatted data ready for + + + + + + + +```python +nemo_rl.environments.reward_model_environment.RewardModelEnvironment.shutdown() +``` + + + + + + +Shutdown the reward model worker and virtual cluster. + +This method properly cleans up resources by shutting down the reward model +policy and virtual cluster. It should be called when the environment is +no longer needed to prevent resource leaks. + + + + + + + +```python +nemo_rl.environments.reward_model_environment.RewardModelEnvironment.step( + message_logs: typing.List[nemo_rl.data.interfaces.LLMMessageLogType], + env_infos: typing.List[typing.Dict[str, typing.Any]] +) -> nemo_rl.environments.interfaces.EnvironmentReturn +``` + + + + + + +Calculate rewards for the given message logs using the reward model. + +This method processes conversation logs through the reward model to compute +quality scores for each conversation. The rewards are based on the reward +model's assessment of how well the assistant's responses align with human +preferences. + +**Parameters:** + + +List of conversation message logs to be scored. + Each log should contain alternating user and assistant messages. + + + +List of environment info dictionaries (currently unused + but required by the interface). + + +**Returns:** `EnvironmentReturn` + +EnvironmentReturn containing: + + + + + + + + + +```python +class nemo_rl.environments.reward_model_environment.RewardModelEnvironmentConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Configuration for RewardModelEnvironment. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/rewards.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/rewards.mdx new file mode 100644 index 0000000..aaa41ba --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/rewards.mdx @@ -0,0 +1,180 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/rewards +title: nemo_rl.environments.rewards +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`bbox_giou_reward`](#nemo_rl-environments-rewards-bbox_giou_reward) | Given [x1, y1, x2, y2] normalized bounding box coordinates within the <{answer_tag}> tags, compute the GIoU between the ground truth and the response. | +| [`combine_reward_functions`](#nemo_rl-environments-rewards-combine_reward_functions) | Returns a callable function that takes (ground_truth, response) and collects multiple reward functions in sequence. | +| [`exact_answer_alphanumeric_reward`](#nemo_rl-environments-rewards-exact_answer_alphanumeric_reward) | Reward the agent when the answer within the <{answer_tag}> tags is the same as the ground truth (case-insensitive). | +| [`format_reward`](#nemo_rl-environments-rewards-format_reward) | Reward the agent when the response follows the format: (.*) <think> (.*) </think> <answer> (.*) </answer>. | +| [`math_expression_reward`](#nemo_rl-environments-rewards-math_expression_reward) | Reward the agent when the answer within the <{tag}> tags is the same expression as the ground truth. | + +### Data + +[`boxed`](#nemo_rl-environments-rewards-boxed) + +[`math_verify_func`](#nemo_rl-environments-rewards-math_verify_func) + +### API + + + + + +```python +nemo_rl.environments.rewards.bbox_giou_reward( + ground_truth: str, + response: str, + giou_penalty_thres: float = 10.0, + answer_tag: str = 'answer' +) -> tuple[float, bool] +``` + + + + + + +Given [x1, y1, x2, y2] normalized bounding box coordinates within the <{answer_tag}> tags, compute the GIoU between the ground truth and the response. + +The `answer_tag` is customizable and must be specified as part of the user COT prompt text file. + + + + + + + + +```python +nemo_rl.environments.rewards.combine_reward_functions( + reward_functions: list[tuple[typing.Callable[[str, str], tuple[float, bool]], float]] +) -> typing.Callable[[str, str], tuple[float, bool]] +``` + + + + + + +Returns a callable function that takes (ground_truth, response) and collects multiple reward functions in sequence. + +The reward functions are weighted by the second element of the tuple. +This information can be provided in the YAML config file and resolved in the VLMEnvironment class. + +**Parameters:** + + +list[tuple[Callable[[str, str], tuple[float, bool]], float]]. A list of reward functions and their weights. + + +**Returns:** `Callable[[str, str], tuple[float, bool]]` + +Callable[[str, str], tuple[float, bool]]: A callable function that takes (ground_truth, response) and collects multiple reward functions in sequence + + + + + + + + +```python +nemo_rl.environments.rewards.exact_answer_alphanumeric_reward( + ground_truth: str, + response: str, + answer_tag: str = 'answer' +) -> tuple[float, bool] +``` + + + + + + +Reward the agent when the answer within the <{answer_tag}> tags is the same as the ground truth (case-insensitive). + +The `answer_tag` is customizable and must be specified as part of the user COT prompt text file. + + + + + + + + +```python +nemo_rl.environments.rewards.format_reward( + ground_truth: str, + response: str, + think_tag: str = 'think', + answer_tag: str = 'answer' +) -> tuple[float, typing.Optional[bool]] +``` + + + + + + +Reward the agent when the response follows the format: (.*) <think> (.*) </think> <answer> (.*) </answer>. + +The `think_tag` and `answer_tag` are customizable and must be specified as part of the user COT prompt text file. + + + + + + + + +```python +nemo_rl.environments.rewards.math_expression_reward( + ground_truth: str, + response: str, + tag: str = 'answer' +) -> tuple[float, bool] +``` + + + + + + +Reward the agent when the answer within the <{tag}> tags is the same expression as the ground truth. + +The `tag` is customizable and must be specified as part of the user COT prompt text file. + + + + + + + + +```python +nemo_rl.environments.rewards.boxed = lambda x: '\\boxed{' + x + '}' if not x.startswith('\\boxed{') else x +``` + + + + + + + + + +```python +nemo_rl.environments.rewards.math_verify_func = math_metric(gold_extraction_target=(LatexExtractionConfig(),), pred_extraction_t... +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/utils.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/utils.mdx new file mode 100644 index 0000000..67e63b2 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/utils.mdx @@ -0,0 +1,152 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/utils +title: nemo_rl.environments.utils +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`EnvRegistryEntry`](#nemo_rl-environments-utils-EnvRegistryEntry) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`chunk_list_to_workers`](#nemo_rl-environments-utils-chunk_list_to_workers) | Chunk a list into a list of lists, where each sublist is assigned to a worker. Keeps ordering of elements. | +| [`create_env`](#nemo_rl-environments-utils-create_env) | - | +| [`register_env`](#nemo_rl-environments-utils-register_env) | - | + +### Data + +[`ENV_REGISTRY`](#nemo_rl-environments-utils-ENV_REGISTRY) + +### API + + + + + +```python +class nemo_rl.environments.utils.EnvRegistryEntry +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +nemo_rl.environments.utils.chunk_list_to_workers( + to_chunk: list[typing.Any], + num_workers: int +) -> list[list[typing.Any]] +``` + + + + + + +Chunk a list into a list of lists, where each sublist is assigned to a worker. Keeps ordering of elements. + +If the list is not divisible by the number of workers, the last worker may have fewer elements. +If there are more workers than elements, the first len(list) workers will have a single element each, +and the remaining workers will have empty lists. + +Examples: + + +```python +>>> from nemo_rl.environments.utils import chunk_list_to_workers +>>> chunk_list_to_workers([1, 2, 3, 4, 5], 3) +[[1, 2], [3, 4], [5]] +``` + + + +**Parameters:** + + +The list to be chunked. + + + +The number of workers to distribute the list to. + + +**Returns:** `list[list[Any]]` + +A list of lists, where each sublist contains elements assigned to a worker. + + + + + + + + +```python +nemo_rl.environments.utils.create_env( + env_name: str, + env_config: dict +) -> nemo_rl.environments.interfaces.EnvironmentInterface +``` + + + + + + + + + + + + + +```python +nemo_rl.environments.utils.register_env( + env_name: str, + actor_class_fqn: str +) -> None +``` + + + + + + + + + + + + + +```python +nemo_rl.environments.utils.ENV_REGISTRY: Dict[str, EnvRegistryEntry] = {'math_default': {'actor_class_fqn': 'nemo_rl.environments.math_environment.Math... +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/vlm_environment.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/vlm_environment.mdx new file mode 100644 index 0000000..40b7a2a --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/environments/vlm_environment.mdx @@ -0,0 +1,243 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/environments/vlm_environment +title: nemo_rl.environments.vlm_environment +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`VLMEnvConfig`](#nemo_rl-environments-vlm_environment-VLMEnvConfig) | - | +| [`VLMEnvironment`](#nemo_rl-environments-vlm_environment-VLMEnvironment) | - | +| [`VLMEnvironmentMetadata`](#nemo_rl-environments-vlm_environment-VLMEnvironmentMetadata) | - | +| [`VLMVerifyWorker`](#nemo_rl-environments-vlm_environment-VLMVerifyWorker) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_mute_output`](#nemo_rl-environments-vlm_environment-_mute_output) | - | + +### API + + + + + +```python +class nemo_rl.environments.vlm_environment.VLMEnvConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.environments.vlm_environment.VLMEnvironment( + cfg: nemo_rl.environments.vlm_environment.VLMEnvConfig +) +``` + + + + + + +**Bases:** [EnvironmentInterface](/nemo-rl/nemo_rl/environments/interfaces#nemo_rl-environments-interfaces-EnvironmentInterface) + + + + + + + + + + +```python +nemo_rl.environments.vlm_environment.VLMEnvironment.global_post_process_and_metrics( + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] +) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], dict[str, float | int]] +``` + + + + + + +Computes metrics for this environment given a global rollout batch. + +Every rank will run this function, so you're free to use distributed +calculations if you'd prefer for heavy metrics. + + + + + + + +```python +nemo_rl.environments.vlm_environment.VLMEnvironment.shutdown() -> None +``` + + + + + + + + + + + + +```python +nemo_rl.environments.vlm_environment.VLMEnvironment.step( + message_log_batch: list[list[dict[str, str]]], + metadata: list[nemo_rl.environments.vlm_environment.VLMEnvironmentMetadata] +) -> nemo_rl.environments.interfaces.EnvironmentReturn +``` + + + + + + +Runs a step in the vlm environment. + +**Parameters:** + + +list[list[dict[str, str]]]. A batch of OpenAI-API-like message logs that represent interactions with the VLM. + + + +list[VLMEnvironmentMetadata]. The grader will use the 'ground_truth' key to evaluate correctness. + + +**Returns:** `EnvironmentReturn` + +A tuple containing: +- list[dict[str, str]]: Observations/responses batch +- list[dict]: Updated metadata +- list[str]: Next stop strings for the next turn +- Tensor: Rewards tensor +- Tensor: Done flags tensor + + + + + + + + + +```python +class nemo_rl.environments.vlm_environment.VLMEnvironmentMetadata +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +class nemo_rl.environments.vlm_environment.VLMVerifyWorker( + cfg: nemo_rl.environments.vlm_environment.VLMEnvConfig +) +``` + + + + + + + + + + + + +```python +nemo_rl.environments.vlm_environment.VLMVerifyWorker.verify( + pred_responses: list[str], + ground_truths: list[str] +) -> list[float] +``` + + + + + + +Verify the correctness of the predicted responses against the ground truth. + +**Parameters:** + + +list[str]. The predicted responses from the LLM. + + + +list[str]. The ground truth responses. + + +**Returns:** `list[float]` + +list[float]. The rewards for each predicted response. + + + + + + + + + +```python +nemo_rl.environments.vlm_environment._mute_output() +``` + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/evals.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/evals.mdx new file mode 100644 index 0000000..e7d333c --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/evals.mdx @@ -0,0 +1,10 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/evals +title: nemo_rl.evals +--- + +## Submodules + +- **[`nemo_rl.evals.answer_parsing`](/nemo-rl/nemo_rl/evals/answer_parsing)** +- **[`nemo_rl.evals.eval`](/nemo-rl/nemo_rl/evals/eval)** diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/evals/answer_parsing.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/evals/answer_parsing.mdx new file mode 100644 index 0000000..9126f64 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/evals/answer_parsing.mdx @@ -0,0 +1,86 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/evals/answer_parsing +title: nemo_rl.evals.answer_parsing +--- + +Contains utility functions for answer parsing. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`normalize_extracted_answer`](#nemo_rl-evals-answer_parsing-normalize_extracted_answer) | - | +| [`normalize_response`](#nemo_rl-evals-answer_parsing-normalize_response) | Normalize the response by removing markdown and LaTeX formatting that may prevent a match. | + +### Data + +[`MULTILINGUAL_ANSWER_PATTERN_TEMPLATE`](#nemo_rl-evals-answer_parsing-MULTILINGUAL_ANSWER_PATTERN_TEMPLATE) + +[`MULTILINGUAL_ANSWER_REGEXES`](#nemo_rl-evals-answer_parsing-MULTILINGUAL_ANSWER_REGEXES) + +### API + + + + + +```python +nemo_rl.evals.answer_parsing.normalize_extracted_answer( + extracted_answer: str +) -> str +``` + + + + + + + + + + + + + +```python +nemo_rl.evals.answer_parsing.normalize_response( + response: str +) -> str +``` + + + + + + +Normalize the response by removing markdown and LaTeX formatting that may prevent a match. + + + + + + + + +```python +nemo_rl.evals.answer_parsing.MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = '(?i){}[ \t]*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])' +``` + + + + + + + + + +```python +nemo_rl.evals.answer_parsing.MULTILINGUAL_ANSWER_REGEXES = ['Answer\\s*:', 'Answer\\s*:\u200b\u200b\u200b\u200b\u200b\u200b', 'উত্তর\\s*:',... +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/evals/eval.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/evals/eval.mdx new file mode 100644 index 0000000..f2712a9 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/evals/eval.mdx @@ -0,0 +1,399 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/evals/eval +title: nemo_rl.evals.eval +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`EvalConfig`](#nemo_rl-evals-eval-EvalConfig) | - | +| [`MasterConfig`](#nemo_rl-evals-eval-MasterConfig) | - | +| [`_PassThroughMathConfig`](#nemo_rl-evals-eval-_PassThroughMathConfig) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_generate_texts`](#nemo_rl-evals-eval-_generate_texts) | Generate texts using either sync or async method. | +| [`_print_results`](#nemo_rl-evals-eval-_print_results) | Print evaluation results. | +| [`_run_env_eval_impl`](#nemo_rl-evals-eval-_run_env_eval_impl) | Unified implementation for both sync and async evaluation. | +| [`_save_evaluation_data_to_json`](#nemo_rl-evals-eval-_save_evaluation_data_to_json) | Save evaluation data to a JSON file. | +| [`eval_cons_k`](#nemo_rl-evals-eval-eval_cons_k) | Evaluate cons@k score using an unbiased estimator. | +| [`eval_pass_k`](#nemo_rl-evals-eval-eval_pass_k) | Evaluate pass@k score using an unbiased estimator. | +| [`run_env_eval`](#nemo_rl-evals-eval-run_env_eval) | Main entry point for running evaluation using environment. | +| [`setup`](#nemo_rl-evals-eval-setup) | Set up components for model evaluation. | + +### API + + + + + +```python +class nemo_rl.evals.eval.EvalConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.evals.eval.MasterConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.evals.eval._PassThroughMathConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +nemo_rl.evals.eval._generate_texts( + vllm_generation, + inputs, + use_async +) +``` + + + + + + +async + +Generate texts using either sync or async method. + + + + + + + + +```python +nemo_rl.evals.eval._print_results( + master_config, + generation_config, + score, + dataset_size, + metric, + k_value, + num_tests_per_prompt +) +``` + + + + + + +Print evaluation results. + + + + + + + + +```python +nemo_rl.evals.eval._run_env_eval_impl( + vllm_generation, + dataloader, + env, + master_config, + use_async = False +) +``` + + + + + + +async + +Unified implementation for both sync and async evaluation. + + + + + + + + +```python +nemo_rl.evals.eval._save_evaluation_data_to_json( + evaluation_data, + master_config, + save_path +) +``` + + + + + + +Save evaluation data to a JSON file. + +**Parameters:** + + +List of evaluation samples + + + +Configuration dictionary + + + +Path to save evaluation results. Set to null to disable saving. + Example: "results/eval_output" or "/path/to/evaluation_results" + + + + + + + + + +```python +nemo_rl.evals.eval.eval_cons_k( + rewards: torch.Tensor, + num_tests_per_prompt: int, + k: int, + extracted_answers: list[str | None] +) -> float +``` + + + + + + +Evaluate cons@k score using an unbiased estimator. + +**Parameters:** + + +Tensor of shape (batch_size * num_tests_per_prompt) + + + +int + + + +int + + + +list[str| None] + + +**Returns:** `float` + +float + + + + + + + + +```python +nemo_rl.evals.eval.eval_pass_k( + rewards: torch.Tensor, + num_tests_per_prompt: int, + k: int +) -> float +``` + + + + + + +Evaluate pass@k score using an unbiased estimator. + +Reference: https://github.com/huggingface/evaluate/blob/32546aafec25cdc2a5d7dd9f941fc5be56ba122f/metrics/code_eval/code_eval.py#L198-L213 +Args: + rewards: Tensor of shape (batch_size * num_tests_per_prompt) + k: int (pass@k value) + +**Returns:** `float` + +float + + + + + + + + +```python +nemo_rl.evals.eval.run_env_eval( + vllm_generation, + dataloader, + env, + master_config +) +``` + + + + + + +Main entry point for running evaluation using environment. + +Generates model responses and evaluates them by env. + +**Parameters:** + + +Model for generating responses. + + + +Data loader with evaluation samples. + + + +Environment that scores responses. + + + +Configuration settings. + + + + + + + + + +```python +nemo_rl.evals.eval.setup( + master_config: nemo_rl.evals.eval.MasterConfig, + tokenizer: transformers.AutoTokenizer, + dataset: nemo_rl.data.datasets.AllTaskProcessedDataset +) -> tuple[nemo_rl.models.generation.vllm.VllmGeneration, torch.utils.data.DataLoader, nemo_rl.evals.eval.MasterConfig] +``` + + + + + + +Set up components for model evaluation. + +Initializes the VLLM model and data loader. + +**Parameters:** + + +Configuration settings. + + + +Dataset to evaluate on. + + +**Returns:** `tuple[VllmGeneration, DataLoader, MasterConfig]` + +VLLM model, data loader, and config. + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/experience.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/experience.mdx new file mode 100644 index 0000000..eacff25 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/experience.mdx @@ -0,0 +1,9 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/experience +title: nemo_rl.experience +--- + +## Submodules + +- **[`nemo_rl.experience.rollouts`](/nemo-rl/nemo_rl/experience/rollouts)** diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/experience/rollouts.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/experience/rollouts.mdx new file mode 100644 index 0000000..a0cb3e5 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/experience/rollouts.mdx @@ -0,0 +1,489 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/experience/rollouts +title: nemo_rl.experience.rollouts +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AsyncNemoGymRolloutResult`](#nemo_rl-experience-rollouts-AsyncNemoGymRolloutResult) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_calculate_single_metric`](#nemo_rl-experience-rollouts-_calculate_single_metric) | - | +| [`_tensorize_by_key`](#nemo_rl-experience-rollouts-_tensorize_by_key) | - | +| [`async_generate_response_for_sample_turn`](#nemo_rl-experience-rollouts-async_generate_response_for_sample_turn) | Generate a response for a single sample's turn using async generation. | +| [`calculate_rewards`](#nemo_rl-experience-rollouts-calculate_rewards) | Calculate rewards for generated responses and get environment feedback. | +| [`generate_responses`](#nemo_rl-experience-rollouts-generate_responses) | Generate responses from policy using synchronous generation. | +| [`generate_responses_async`](#nemo_rl-experience-rollouts-generate_responses_async) | Async version of generate_responses that properly calls generate_async. | +| [`run_async_multi_turn_rollout`](#nemo_rl-experience-rollouts-run_async_multi_turn_rollout) | Run multi-turn rollouts with sample-level processing. | +| [`run_async_nemo_gym_rollout`](#nemo_rl-experience-rollouts-run_async_nemo_gym_rollout) | Run multi-turn rollouts with NeMo-Gym. Please refer to the `run_async_multi_turn_rollout` docs for more information on the parameters. | +| [`run_multi_turn_rollout`](#nemo_rl-experience-rollouts-run_multi_turn_rollout) | Runs a multi-turn rollout loop, interacting with the environment. | +| [`run_sample_multi_turn_rollout`](#nemo_rl-experience-rollouts-run_sample_multi_turn_rollout) | Run a multi-turn rollout for a single sample. | + +### Data + +[`TokenizerType`](#nemo_rl-experience-rollouts-TokenizerType) + +### API + + + + + +```python +class nemo_rl.experience.rollouts.AsyncNemoGymRolloutResult( + input_ids: torch.Tensor, + final_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], + rollout_metrics: dict[str, typing.Any] +) +``` + + + + + + +Dataclass + + + + + + + + + + + + + + + +```python +nemo_rl.experience.rollouts._calculate_single_metric( + values: list[float], + batch_size: int, + key_name: str +) -> dict +``` + + + + + + + + + + + + + +```python +nemo_rl.experience.rollouts._tensorize_by_key( + message_logs: list, + key: str +) +``` + + + + + + + + + + + + + +```python +nemo_rl.experience.rollouts.async_generate_response_for_sample_turn( + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + sample_message_log: list[dict], + sample_stop_strings: list[str] | None, + tokenizer: nemo_rl.experience.rollouts.TokenizerType, + max_seq_len: int, + greedy: bool = False +) -> tuple[list[dict], torch.Tensor, torch.Tensor, dict[str, float]] +``` + + + + + + +async + +Generate a response for a single sample's turn using async generation. + +**Parameters:** + + +The generation interface to use + + + +Message log for a single sample + + + +Stop strings for this sample + + + +Tokenizer to use + + + +Maximum sequence length + + + +Whether to use greedy decoding + + +**Returns:** `tuple[list[dict], torch.Tensor, torch.Tensor, dict[str, float]]` + +Tuple of (updated_message_log, generated_tokens, input_lengths, generation_metrics) + + + + + + + + +```python +nemo_rl.experience.rollouts.calculate_rewards( + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], + task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface] +) -> nemo_rl.environments.interfaces.EnvironmentReturn +``` + + + + + + +Calculate rewards for generated responses and get environment feedback. + +**Parameters:** + + +Batch containing message_log (LLMMessageLogType) with generated responses + + + +Dictionary mapping task names to their corresponding environments + + +**Returns:** `EnvironmentReturn` + +EnvironmentReturn namedtuple containing: +- observations: List of observations from the environment for the next turn. +- metadata: List of extracted metadata from the environment. +- next_stop_strings: List of stop strings for the next generation step. +- rewards: Tensor of rewards for the last turn. +- terminateds: Tensor of booleans indicating if an episode ended naturally. + + + + + + + + +```python +nemo_rl.experience.rollouts.generate_responses( + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + generation_input_data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], + tokenizer: nemo_rl.experience.rollouts.TokenizerType, + input_lengths: torch.Tensor, + include_logprobs: bool = True, + greedy: bool = False +) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], list[torch.Tensor], dict[str, float | int]] +``` + + + + + + +Generate responses from policy using synchronous generation. + + + + + + + + +```python +nemo_rl.experience.rollouts.generate_responses_async( + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + generation_input_data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], + tokenizer: nemo_rl.experience.rollouts.TokenizerType, + input_lengths: torch.Tensor, + include_logprobs: bool = True, + greedy: bool = False +) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], list[torch.Tensor], dict[str, float | int]] +``` + + + + + + +async + +Async version of generate_responses that properly calls generate_async. + + + + + + + + +```python +nemo_rl.experience.rollouts.run_async_multi_turn_rollout( + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + input_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], + tokenizer: nemo_rl.experience.rollouts.TokenizerType, + task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], + max_seq_len: int, + max_rollout_turns: int = 999999, + greedy: bool = False +) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], dict[str, typing.Any]] +``` + + + + + + +Run multi-turn rollouts with sample-level processing. + +Each sample in the batch proceeds through its interaction independently. +Async generation is used internally when available but the function is synchronous. + +**Parameters:** + + +The generation interface (policy) + + + +The starting batch containing initial message logs + + + +The tokenizer + + + +Dictionary mapping task names to environment instances + + + +Maximum sequence length allowed + + + +Maximum number of agent-environment interaction turns + + + +Whether to use greedy decoding + + +**Returns:** `tuple[BatchedDataDict[DatumSpec], dict[str, Any]]` + +Tuple containing: +- BatchedDataDict with the full interaction history and accumulated rewards +- Dictionary of rollout metrics + + + + + + + + +```python +nemo_rl.experience.rollouts.run_async_nemo_gym_rollout( + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + input_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], + tokenizer: nemo_rl.experience.rollouts.TokenizerType, + task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], + generation_config: nemo_rl.models.generation.interfaces.GenerationConfig, + max_seq_len: typing.Optional[int] = None, + max_rollout_turns: typing.Optional[int] = None, + greedy: bool = False +) -> nemo_rl.experience.rollouts.AsyncNemoGymRolloutResult +``` + + + + + + +Run multi-turn rollouts with NeMo-Gym. Please refer to the `run_async_multi_turn_rollout` docs for more information on the parameters. + + + + + + + + +```python +nemo_rl.experience.rollouts.run_multi_turn_rollout( + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + input_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], + tokenizer: nemo_rl.experience.rollouts.TokenizerType, + task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], + max_seq_len: int, + max_rollout_turns: int = 999999, + greedy: bool = False +) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], dict[str, typing.Any]] +``` + + + + + + +Runs a multi-turn rollout loop, interacting with the environment. + +**Parameters:** + + +The generation interface (policy). + + + +The starting batch containing initial message logs. + + + +The tokenizer. + + + +Dictionary mapping task names to environment instances. + + + +Maximum number of agent-environment interaction turns. + + + +Maximum sequence length allowed. + + + +Whether to use greedy decoding. + + +**Returns:** `tuple[BatchedDataDict[DatumSpec], dict[str, Any]]` + +Tuple containing: +- BatchedDataDict with the full interaction history and accumulated rewards +- Dictionary of rollout metrics + + + + + + + + +```python +nemo_rl.experience.rollouts.run_sample_multi_turn_rollout( + sample_idx: int, + initial_sample_state: dict, + policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface, + tokenizer: nemo_rl.experience.rollouts.TokenizerType, + task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface], + max_seq_len: int, + max_rollout_turns: int = 999999, + greedy: bool = False +) -> tuple[dict, dict[str, typing.Any]] +``` + + + + + + +async + +Run a multi-turn rollout for a single sample. + +This function manages the complete lifecycle of one sample's interaction. +Async generation is used internally when available. + +**Parameters:** + + +Index of this sample in the original batch + + + +Initial state containing message_log, extra_env_info, etc. + + + +The generation interface + + + +Tokenizer to use + + + +Environment mapping + + + +Maximum sequence length + + + +Maximum number of turns + + + +Whether to use greedy decoding + + +**Returns:** `tuple[dict, dict[str, Any]]` + +Tuple of (final_sample_state, sample_metrics) + + + + + + + + +```python +nemo_rl.experience.rollouts.TokenizerType = PreTrainedTokenizerBase +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models.mdx new file mode 100644 index 0000000..78aeea0 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models.mdx @@ -0,0 +1,14 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models +title: nemo_rl.models +--- + +## Subpackages + +- **[`nemo_rl.models.automodel`](/nemo-rl/nemo_rl/models/automodel)** +- **[`nemo_rl.models.dtensor`](/nemo-rl/nemo_rl/models/dtensor)** +- **[`nemo_rl.models.generation`](/nemo-rl/nemo_rl/models/generation)** +- **[`nemo_rl.models.huggingface`](/nemo-rl/nemo_rl/models/huggingface)** +- **[`nemo_rl.models.megatron`](/nemo-rl/nemo_rl/models/megatron)** +- **[`nemo_rl.models.policy`](/nemo-rl/nemo_rl/models/policy)** diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel.mdx new file mode 100644 index 0000000..a8d957d --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel.mdx @@ -0,0 +1,12 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/automodel +title: nemo_rl.models.automodel +--- + +## Submodules + +- **[`nemo_rl.models.automodel.config`](/nemo-rl/nemo_rl/models/automodel/config)** +- **[`nemo_rl.models.automodel.data`](/nemo-rl/nemo_rl/models/automodel/data)** +- **[`nemo_rl.models.automodel.setup`](/nemo-rl/nemo_rl/models/automodel/setup)** +- **[`nemo_rl.models.automodel.train`](/nemo-rl/nemo_rl/models/automodel/train)** diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/config.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/config.mdx new file mode 100644 index 0000000..c409b87 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/config.mdx @@ -0,0 +1,125 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/automodel/config +title: nemo_rl.models.automodel.config +--- + +Configuration classes for automodel-based training in NeMo RL. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ModelAndOptimizerState`](#nemo_rl-models-automodel-config-ModelAndOptimizerState) | Container for model and optimizer state. | +| [`RuntimeConfig`](#nemo_rl-models-automodel-config-RuntimeConfig) | Runtime configuration for model training and inference. | + +### API + + + + + +```python +class nemo_rl.models.automodel.config.ModelAndOptimizerState() +``` + + + + + + +**Bases:** `NamedTuple` + +Container for model and optimizer state. + +This named tuple holds all model-related state including the model itself, +optimizer, scheduler, and metadata about the model type and configuration. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.automodel.config.RuntimeConfig() +``` + + + + + + +**Bases:** `NamedTuple` + +Runtime configuration for model training and inference. + +This contains all validated runtime settings needed for model initialization, +parallelization, and training. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/data.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/data.mdx new file mode 100644 index 0000000..f777494 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/data.mdx @@ -0,0 +1,374 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/automodel/data +title: nemo_rl.models.automodel.data +--- + +Data processing utilities for automodel training and inference. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ProcessedInputs`](#nemo_rl-models-automodel-data-ProcessedInputs) | Processed microbatch inputs ready for model forward pass. | +| [`ProcessedMicrobatch`](#nemo_rl-models-automodel-data-ProcessedMicrobatch) | Container for a processed microbatch ready for model forward pass. | + +### Functions + +| Name | Description | +|------|-------------| +| [`check_sequence_dim`](#nemo_rl-models-automodel-data-check_sequence_dim) | Check and validate sequence dimension across all tensors. | +| [`get_microbatch_iterator`](#nemo_rl-models-automodel-data-get_microbatch_iterator) | Create processed microbatch iterator based on batching strategy. | +| [`make_processed_microbatch_iterator`](#nemo_rl-models-automodel-data-make_processed_microbatch_iterator) | Wrap a raw microbatch iterator to yield processed microbatches. | +| [`process_global_batch`](#nemo_rl-models-automodel-data-process_global_batch) | Process a global batch and compute normalization factors. | +| [`process_microbatch`](#nemo_rl-models-automodel-data-process_microbatch) | Process a microbatch and prepare inputs for model forward. | + +### API + + + + + +```python +class nemo_rl.models.automodel.data.ProcessedInputs( + input_ids: torch.Tensor, + seq_len: int, + attention_mask: typing.Optional[torch.Tensor] = None, + position_ids: typing.Optional[torch.Tensor] = None, + flash_attn_kwargs: dict[str, typing.Any] = dict(), + vlm_kwargs: dict[str, typing.Any] = dict(), + cp_buffers: list[torch.Tensor] = list(), + seq_index: typing.Optional[torch.Tensor] = None +) +``` + + + + + + +Dataclass + +Processed microbatch inputs ready for model forward pass. + +This structure contains all necessary tensors and metadata for a forward pass, +including context parallel buffers and flash attention configuration. + + + + + + + + + + + + +Check if context parallel is enabled. + + + +Check if flash attention is configured. + +Works for both empty dict {} and dataclass objects like FlashAttnKwargs. + + + + + + +Check if this is a multimodal input. + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.automodel.data.ProcessedMicrobatch( + data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs, + original_batch_size: int, + original_seq_len: int +) +``` + + + + + + +Dataclass + +Container for a processed microbatch ready for model forward pass. + +This dataclass holds both the original data dictionary and the processed +tensors needed for the automodel forward pass. It follows the same pattern +as nemo_rl/models/megatron/data.py ProcessedMicrobatch. + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.models.automodel.data.check_sequence_dim( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] +) -> typing.Tuple[int, int] +``` + + + + + + +Check and validate sequence dimension across all tensors. + +Verifies that dimension 1 is the sequence dimension for all tensors +in the data dictionary that have more than one dimension. + +**Parameters:** + + +BatchedDataDict to validate + + +**Returns:** `Tuple[int, int]` + +Tuple of (sequence_dim, seq_dim_size) + +**Raises:** + +- `AssertionError`: If any tensor has inconsistent sequence dimension + + + + + + + + +```python +nemo_rl.models.automodel.data.get_microbatch_iterator( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + cfg: dict[str, typing.Any], + mbs: int, + dp_mesh: typing.Any, + tokenizer: transformers.AutoTokenizer, + cp_size: int = 1 +) -> tuple[typing.Iterator[nemo_rl.models.automodel.data.ProcessedMicrobatch], int] +``` + + + + + + +Create processed microbatch iterator based on batching strategy. + +**Parameters:** + + +Full dataset to iterate over + + + +Configuration dictionary (enable_seq_packing is inferred from cfg["sequence_packing"]["enabled"]) + + + +Microbatch size + + + +Data parallel mesh + + + +Tokenizer for processing + + + +Context parallel size + + +**Returns:** `tuple[Iterator[ProcessedMicrobatch], int]` + +Tuple of (processed_microbatch_iterator, iterator_length) + + + + + + + + +```python +nemo_rl.models.automodel.data.make_processed_microbatch_iterator( + raw_iterator: typing.Iterator[nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any]], + tokenizer: transformers.AutoTokenizer, + cfg: dict[str, typing.Any], + cp_size: int +) -> typing.Iterator[nemo_rl.models.automodel.data.ProcessedMicrobatch] +``` + + + + + + +Wrap a raw microbatch iterator to yield processed microbatches. + +This function takes a raw iterator that yields BatchedDataDict objects and +wraps it to yield ProcessedMicrobatch objects that contain both the original +data and the processed tensors ready for model forward pass. + +**Parameters:** + + +Iterator yielding raw BatchedDataDict microbatches + + + +Tokenizer for processing + + + +Configuration dictionary (enable_seq_packing is inferred from cfg["sequence_packing"]["enabled"]) + + + +Context parallel size + + + + + + + + + +```python +nemo_rl.models.automodel.data.process_global_batch( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + dp_group: torch.distributed.ProcessGroup, + batch_idx: int, + batch_size: int +) -> dict[str, typing.Any] +``` + + + + + + +Process a global batch and compute normalization factors. + +**Parameters:** + + +Full dataset + + + +Loss function (used to check loss type) + + + +Data parallel process group (for consistency with Megatron naming) + + + +Index of batch to extract + + + +Size of batch to extract + + +**Returns:** `dict[str, Any]` + +Dictionary containing: + + + + + + + + +```python +nemo_rl.models.automodel.data.process_microbatch( + mb: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + tokenizer: transformers.AutoTokenizer, + enable_seq_packing: bool, + cfg: dict[str, typing.Any], + cp_size: int +) -> nemo_rl.models.automodel.data.ProcessedInputs +``` + + + + + + +Process a microbatch and prepare inputs for model forward. + +**Parameters:** + + +Microbatch data + + + +Tokenizer for padding value + + + +Whether sequence packing is enabled + + + +Configuration dictionary + + + +Context parallel size + + +**Returns:** `ProcessedInputs` + +ProcessedInputs containing all tensors and metadata for forward pass + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/setup.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/setup.mdx new file mode 100644 index 0000000..2b2a115 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/setup.mdx @@ -0,0 +1,229 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/automodel/setup +title: nemo_rl.models.automodel.setup +--- + +Setup utilities for automodel-based training in NeMo RL. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`setup_distributed`](#nemo_rl-models-automodel-setup-setup_distributed) | Set up distributed training environment and create FSDP2Manager. | +| [`setup_model_and_optimizer`](#nemo_rl-models-automodel-setup-setup_model_and_optimizer) | Set up model, parallelization, and optimizer. | +| [`setup_reference_model_state`](#nemo_rl-models-automodel-setup-setup_reference_model_state) | Set up reference model state dict by creating a CPU copy of the model's state dict. | +| [`validate_and_prepare_config`](#nemo_rl-models-automodel-setup-validate_and_prepare_config) | Validate configuration and prepare runtime settings. | + +### Data + +[`STRING_TO_DTYPE`](#nemo_rl-models-automodel-setup-STRING_TO_DTYPE) + +### API + + + + + +```python +nemo_rl.models.automodel.setup.setup_distributed( + config: nemo_rl.models.policy.PolicyConfig, + runtime_config: nemo_rl.models.automodel.config.RuntimeConfig +) -> nemo_automodel.components.distributed.fsdp2.FSDP2Manager +``` + + + + + + +Set up distributed training environment and create FSDP2Manager. + +Initializes torch.distributed process group and creates an FSDP2Manager +with the appropriate parallelization and precision settings. + +**Parameters:** + + +Policy configuration dictionary + + + +RuntimeConfig named tuple from validate_and_prepare_config + + +**Returns:** `FSDP2Manager` + +FSDP2Manager instance with all distributed configuration + + + + + + + + +```python +nemo_rl.models.automodel.setup.setup_model_and_optimizer( + config: nemo_rl.models.policy.PolicyConfig, + tokenizer: transformers.AutoTokenizer, + runtime_config: nemo_rl.models.automodel.config.RuntimeConfig, + distributed_manager: nemo_automodel.components.distributed.fsdp2.FSDP2Manager, + checkpoint_manager: typing.Any, + is_vlm: bool = False, + init_optimizer: bool = True, + weights_path: typing.Optional[str] = None, + optimizer_path: typing.Optional[str] = None +) -> nemo_rl.models.automodel.config.ModelAndOptimizerState +``` + + + + + + +Set up model, parallelization, and optimizer. + +Creates the model from config, applies parallelization strategies (FSDP2, TP, CP), +loads base weights, and optionally initializes optimizer and scheduler. + +**Parameters:** + + +Policy configuration dictionary + + + +Tokenizer for the model + + + +RuntimeConfig named tuple from validate_and_prepare_config + + + +FSDP2Manager from setup_distributed + + + +Checkpoint manager for loading/saving weights + + + +Whether this is a vision-language model + + + +Whether to initialize optimizer + + + +Optional path to checkpoint weights to load + + + +Optional path to optimizer state to load + + +**Returns:** `ModelAndOptimizerState` + +ModelAndOptimizerState containing model, optimizer, scheduler, and metadata + + + + + + + + +```python +nemo_rl.models.automodel.setup.setup_reference_model_state( + model: torch.nn.Module +) -> dict[str, torch.Tensor] +``` + + + + + + +Set up reference model state dict by creating a CPU copy of the model's state dict. + +This creates a reference copy of the model weights on CPU with pinned memory +for efficient CPU-GPU transfers. The reference model is typically used to +compute reference log probabilities during RL training. + +**Parameters:** + + +The model to create a reference copy from + + +**Returns:** `dict[str, torch.Tensor]` + +Dictionary mapping parameter names to CPU tensors with pinned memory + + + + + + + + +```python +nemo_rl.models.automodel.setup.validate_and_prepare_config( + config: nemo_rl.models.policy.PolicyConfig, + processor: typing.Optional[transformers.AutoProcessor], + rank: int +) -> nemo_rl.models.automodel.config.RuntimeConfig +``` + + + + + + +Validate configuration and prepare runtime settings. + +This function validates the policy configuration, sets environment variables, +determines model configuration, and returns runtime settings as a named tuple. + +**Parameters:** + + +Policy configuration dictionary + + + +Optional processor for multimodal models + + + +Current process rank + + +**Returns:** `RuntimeConfig` + +RuntimeConfig named tuple containing validated configuration values + +**Raises:** + +- `ValueError`: If configuration is invalid +- `RuntimeError`: If incompatible settings are detected + + + + + + + + +```python +nemo_rl.models.automodel.setup.STRING_TO_DTYPE = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16} +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/train.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/train.mdx new file mode 100644 index 0000000..7126780 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/automodel/train.mdx @@ -0,0 +1,841 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/automodel/train +title: nemo_rl.models.automodel.train +--- + +Training utilities for automodel (DTensor-based) policy workers. + +This module provides post-processor classes and forward/backward functions +that follow the same pattern as nemo_rl/models/megatron/train.py. + +Key differences from megatron approach: +- Post-processors compute results directly (no callable return pattern) +- forward_with_post_processing_fn calls post-processor directly +- automodel_forward_backward uses PyTorch autograd instead of Megatron's pipeline + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`LogprobsPostProcessor`](#nemo_rl-models-automodel-train-LogprobsPostProcessor) | Post-processor for computing log probabilities from model outputs. | +| [`LossPostProcessor`](#nemo_rl-models-automodel-train-LossPostProcessor) | Post-processor for computing training loss from model outputs. | +| [`ScorePostProcessor`](#nemo_rl-models-automodel-train-ScorePostProcessor) | Post-processor for computing reward model scores from model outputs. | +| [`TopkLogitsPostProcessor`](#nemo_rl-models-automodel-train-TopkLogitsPostProcessor) | Post-processor for computing top-k logits from model outputs. | + +### Functions + +| Name | Description | +|------|-------------| +| [`aggregate_training_statistics`](#nemo_rl-models-automodel-train-aggregate_training_statistics) | Aggregate training statistics across microbatches and ranks. | +| [`apply_temperature_scaling`](#nemo_rl-models-automodel-train-apply_temperature_scaling) | Apply temperature scaling to logits. | +| [`automodel_forward_backward`](#nemo_rl-models-automodel-train-automodel_forward_backward) | Execute forward and backward passes for automodel. | +| [`extract_logits`](#nemo_rl-models-automodel-train-extract_logits) | Extract logits from model outputs. | +| [`forward_with_post_processing_fn`](#nemo_rl-models-automodel-train-forward_with_post_processing_fn) | Perform forward pass with pre-processed microbatch and apply post-processing. | +| [`model_forward`](#nemo_rl-models-automodel-train-model_forward) | Perform a single forward pass through the model. | +| [`prepare_data_for_cp`](#nemo_rl-models-automodel-train-prepare_data_for_cp) | Prepare data for context parallel processing. | +| [`redistribute_logits_for_cp`](#nemo_rl-models-automodel-train-redistribute_logits_for_cp) | Redistribute logits for context parallel processing. | + +### Data + +[`PostProcessingFunction`](#nemo_rl-models-automodel-train-PostProcessingFunction) + +### API + + + + + +```python +class nemo_rl.models.automodel.train.LogprobsPostProcessor( + cfg: nemo_rl.models.policy.PolicyConfig, + device_mesh: typing.Any, + cp_mesh: typing.Any, + tp_mesh: typing.Any, + cp_size: int, + enable_seq_packing: bool = False +) +``` + + + + + + +Post-processor for computing log probabilities from model outputs. + + + + + + + + +```python +nemo_rl.models.automodel.train.LogprobsPostProcessor.__call__( + logits: torch.Tensor, + processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs, + input_lengths: torch.Tensor, + original_batch_size: int, + original_seq_len: int, + sequence_dim: int = 1 +) -> torch.Tensor +``` + + + + + + +Compute token log probabilities from logits. + +**Parameters:** + + +Model output logits + + + +Processed inputs + + + +Sequence lengths + + + +Original batch size before packing + + + +Original sequence length before packing + + + +Sequence dimension + + +**Returns:** `torch.Tensor` + +Token log probabilities tensor [batch_size, seq_length] + + + + + + + +```python +nemo_rl.models.automodel.train.LogprobsPostProcessor._compute_local_logprobs( + logits: torch.Tensor, + input_ids: torch.Tensor +) -> torch.Tensor +``` + + + + + + +Compute logprobs locally without distributed processing. + +**Parameters:** + + +Model output logits + + + +Input token IDs + + +**Returns:** `torch.Tensor` + +Token log probabilities + + + + + + + + + +```python +class nemo_rl.models.automodel.train.LossPostProcessor( + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + cfg: nemo_rl.models.policy.PolicyConfig, + device_mesh: typing.Any, + cp_mesh: typing.Any, + tp_mesh: typing.Any, + cp_size: int, + dp_size: int, + enable_seq_packing: bool = False +) +``` + + + + + + +Post-processor for computing training loss from model outputs. + + + + + + +```python +nemo_rl.models.automodel.train.LossPostProcessor.__call__( + logits: torch.Tensor, + mb: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs, + global_valid_seqs: torch.Tensor, + global_valid_toks: torch.Tensor, + sequence_dim: int = 1 +) -> tuple[torch.Tensor, dict[str, typing.Any]] +``` + + + + + + +Compute loss from logits. + +**Parameters:** + + +Model output logits + + + +Microbatch data + + + +Processed inputs + + + +Global valid sequence count + + + +Global valid token count + + + +Sequence dimension + + +**Returns:** `tuple[torch.Tensor, dict[str, Any]]` + +Tuple of (loss, metrics) + + + + + + + + + +```python +class nemo_rl.models.automodel.train.ScorePostProcessor( + cfg: nemo_rl.models.policy.PolicyConfig +) +``` + + + + + + +Post-processor for computing reward model scores from model outputs. + + + + + + +```python +nemo_rl.models.automodel.train.ScorePostProcessor.__call__( + logits: torch.Tensor +) -> torch.Tensor +``` + + + + + + +Extract scores from reward model outputs. + +**Parameters:** + + +Model output logits + + +**Returns:** `torch.Tensor` + +Scores tensor + + + + + + + + + +```python +class nemo_rl.models.automodel.train.TopkLogitsPostProcessor( + cfg: nemo_rl.models.policy.PolicyConfig, + device_mesh: typing.Any, + cp_mesh: typing.Any, + tp_mesh: typing.Any, + cp_size: int, + k: int, + enable_seq_packing: bool = False +) +``` + + + + + + +Post-processor for computing top-k logits from model outputs. + + + + + + +```python +nemo_rl.models.automodel.train.TopkLogitsPostProcessor.__call__( + logits: torch.Tensor, + processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs, + input_lengths: torch.Tensor, + original_batch_size: int, + original_seq_len: int, + sequence_dim: int = 1 +) -> tuple[torch.Tensor, torch.Tensor] +``` + + + + + + +Compute top-k logits and indices from model outputs. + +**Parameters:** + + +Model output logits + + + +Processed inputs + + + +Sequence lengths + + + +Original batch size before packing + + + +Original sequence length before packing + + + +Sequence dimension + + +**Returns:** `tuple[torch.Tensor, torch.Tensor]` + +Tuple of (top-k values, top-k indices) tensors + + + + + + + + + +```python +nemo_rl.models.automodel.train.aggregate_training_statistics( + losses: list[float], + all_mb_metrics: list[dict[str, typing.Any]], + grad_norm: typing.Optional[torch.Tensor], + dp_group: typing.Any, + dtype: torch.dtype +) -> dict[str, typing.Any] +``` + + + + + + +Aggregate training statistics across microbatches and ranks. + +**Parameters:** + + +List of loss values from each microbatch + + + +List of metrics dictionaries from each microbatch + + + +Gradient norm tensor (or None if eval mode) + + + +Data parallel process group for all-reduce + + + +Model dtype for metrics + + +**Returns:** `dict[str, Any]` + +Dictionary containing aggregated metrics including global_loss, grad_norm, etc. + + + + + + + + +```python +nemo_rl.models.automodel.train.apply_temperature_scaling( + logits: torch.Tensor, + cfg: nemo_rl.models.policy.PolicyConfig +) -> torch.Tensor +``` + + + + + + +Apply temperature scaling to logits. + +**Parameters:** + + +Logits tensor to scale + + + +Configuration dictionary containing generation settings + + +**Returns:** `torch.Tensor` + +torch.Tensor: Temperature-scaled logits + + + + + + + + +```python +nemo_rl.models.automodel.train.automodel_forward_backward( + model: torch.nn.Module, + cfg: nemo_rl.models.policy.PolicyConfig, + data_iterator: typing.Iterator[nemo_rl.models.automodel.data.ProcessedMicrobatch], + post_processing_fn: nemo_rl.models.automodel.train.PostProcessingFunction, + forward_only: bool = False, + is_reward_model: bool = False, + allow_flash_attn_args: bool = True, + global_valid_seqs: typing.Optional[torch.Tensor] = None, + global_valid_toks: typing.Optional[torch.Tensor] = None, + sequence_dim: int = 1, + dp_size: int = 1, + cp_size: int = 1, + num_global_batches: int = 1, + train_context_fn: typing.Optional[typing.Callable[[ProcessedInputs], typing.Any]] = None, + num_valid_microbatches: typing.Optional[int] = None, + on_microbatch_start: typing.Optional[typing.Callable[[int], None]] = None +) -> list[typing.Tuple[typing.Any, dict[str, typing.Any]]] +``` + + + + + + +Execute forward and backward passes for automodel. + +This is the main training loop function that coordinates forward and backward +passes across multiple microbatches using PyTorch autograd. + +Unlike megatron_forward_backward which uses Megatron's pipeline parallel +framework, this uses standard PyTorch operations. + +**Parameters:** + + +The model to train + + + +Configuration dictionary + + + +Iterator yielding ProcessedMicrobatch objects (already processed) + + + +Number of microbatches to process + + + +Post-processing function to apply to the logits + + + +If True, skip backward pass + + + +Whether this is a reward model + + + +Whether to pass flash_attn_kwargs to model + + + +Global valid sequence count for loss normalization + + + +Global valid token count for loss normalization + + + +Sequence dimension + + + +Data parallel size + + + +Context parallel size + + + +Number of global batches (for metric scaling) + + + +Optional callable that takes ProcessedInputs and returns +a context manager for the forward/backward pass. If None, no context is used. + + + +Number of valid (non-dummy) microbatches. If provided, +microbatches beyond this index are treated as dummy batches (loss *= 0). +If None, all microbatches are considered valid. + + + +Optional callback called at the start of each microbatch +with the microbatch index. Useful for cache clearing, etc. + + +**Returns:** `list[Tuple[Any, dict[str, Any]]]` + +List of (result, metrics) tuples from each microbatch + + + + + + + + +```python +nemo_rl.models.automodel.train.extract_logits( + model: torch.nn.Module, + outputs: typing.Any +) -> torch.Tensor +``` + + + + + + +Extract logits from model outputs. + +**Parameters:** + + +The model (used for lm_head if needed) + + + +Model outputs (can be tensor, DTensor, or object with logits attribute) + + +**Returns:** `torch.Tensor` + +torch.Tensor: Logits tensor + + + + + + + + +```python +nemo_rl.models.automodel.train.forward_with_post_processing_fn( + model: torch.nn.Module, + cfg: nemo_rl.models.policy.PolicyConfig, + post_processing_fn: nemo_rl.models.automodel.train.PostProcessingFunction, + processed_mb: nemo_rl.models.automodel.data.ProcessedMicrobatch, + is_reward_model: bool = False, + allow_flash_attn_args: bool = True, + global_valid_seqs: typing.Optional[torch.Tensor] = None, + global_valid_toks: typing.Optional[torch.Tensor] = None, + sequence_dim: int = 1 +) -> typing.Tuple[typing.Any, dict[str, typing.Any], nemo_rl.models.automodel.data.ProcessedMicrobatch] +``` + + + + + + +Perform forward pass with pre-processed microbatch and apply post-processing. + +This function takes a pre-processed microbatch (with sequence packing already handled), +runs the forward step through the model, and applies the post-processing function +to compute the result. + +Unlike the megatron approach which returns a callable, this directly computes +and returns the result since automodel uses PyTorch autograd. + +**Parameters:** + + +The model to run forward pass on + + + +Configuration dictionary + + + +Post-processing function to apply to the logits + + + +Pre-fetched ProcessedMicrobatch containing data and processed inputs + + + +Whether this is a reward model + + + +Whether to pass flash_attn_kwargs to model + + + +Global valid sequence count for loss normalization + + + +Global valid token count for loss normalization + + + +Sequence dimension + + +**Returns:** `Tuple[Any, dict[str, Any], ProcessedMicrobatch]` + +(result, metrics, processed_microbatch) +- result: Output from post-processing (loss, logprobs, topk, or scores) +- metrics: Dictionary of metrics from post-processing +- processed_microbatch: The ProcessedMicrobatch that was processed + + + + + + + + +```python +nemo_rl.models.automodel.train.model_forward( + model: torch.nn.Module, + processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs, + is_reward_model: bool = False, + allow_flash_attn_args: bool = True +) -> torch.Tensor +``` + + + + + + +Perform a single forward pass through the model. + +**Parameters:** + + +The model to run forward pass on + + + +ProcessedInputs containing all tensors for forward pass + + + +Whether this is a reward model + + + +Whether to pass flash_attn_kwargs to model + + +**Returns:** `torch.Tensor` + +torch.Tensor: Output tensor from the model (logits) + + + + + + + + +```python +nemo_rl.models.automodel.train.prepare_data_for_cp( + mb: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + processed_inputs: nemo_rl.models.automodel.data.ProcessedInputs, + cp_mesh: typing.Any, + sequence_dim: int = 1 +) -> tuple[torch.Tensor, nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any]] +``` + + + + + + +Prepare data for context parallel processing. + +Converts seq_index to full tensor and wraps CP-sharded tensors in DTensor. + +**Parameters:** + + +Microbatch data dictionary + + + +Processed inputs containing CP buffers + + + +Context parallel mesh + + + +Dimension for sequence sharding + + +**Returns:** `tuple[torch.Tensor, BatchedDataDict[Any]]` + +Tuple of (seq_index_dtensor, updated_mb) + + + + + + + + +```python +nemo_rl.models.automodel.train.redistribute_logits_for_cp( + logits: torch.Tensor, + device_mesh: typing.Any, + cp_mesh: typing.Any, + sequence_dim: int = 1 +) -> torch.distributed.tensor.DTensor +``` + + + + + + +Redistribute logits for context parallel processing. + +Handles the case where logits may be TP-sharded DTensor or regular tensor, +and converts them to CP+TP sharded DTensor. + +**Parameters:** + + +Logits tensor (may be DTensor or regular tensor) + + + +Full device mesh + + + +Context parallel mesh (kept for signature compatibility) + + + +Dimension for sequence sharding + + +**Returns:** `DTensor` + +DTensor sharded on both CP and TP dimensions + + + + + + + + +```python +nemo_rl.models.automodel.train.PostProcessingFunction = Union['LossPostProcessor', 'LogprobsPostProcessor', 'TopkLogitsPostProcessor', '... +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor.mdx new file mode 100644 index 0000000..b29df46 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor.mdx @@ -0,0 +1,9 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/dtensor +title: nemo_rl.models.dtensor +--- + +## Submodules + +- **[`nemo_rl.models.dtensor.parallelize`](/nemo-rl/nemo_rl/models/dtensor/parallelize)** diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor/parallelize.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor/parallelize.mdx new file mode 100644 index 0000000..72877d4 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/dtensor/parallelize.mdx @@ -0,0 +1,454 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/dtensor/parallelize +title: nemo_rl.models.dtensor.parallelize +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`RotaryEmbedParallel`](#nemo_rl-models-dtensor-parallelize-RotaryEmbedParallel) | Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings because the input is a tuple. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_parallelize_gemma3`](#nemo_rl-models-dtensor-parallelize-_parallelize_gemma3) | Parallelizes a Gemma3ForCausalLM model across data and tensor parallel dimensions. | +| [`_parallelize_llama`](#nemo_rl-models-dtensor-parallelize-_parallelize_llama) | Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions. | +| [`_parallelize_model`](#nemo_rl-models-dtensor-parallelize-_parallelize_model) | Parallelize a model using DTensor. | +| [`_parallelize_nm5_h`](#nemo_rl-models-dtensor-parallelize-_parallelize_nm5_h) | Parallelize a NemotronHForCausalLM model across data and tensor parallel dimensions. | +| [`_parallelize_qwen`](#nemo_rl-models-dtensor-parallelize-_parallelize_qwen) | Parallelizes a Qwen2ForCausalLM model across data and tensor parallel dimensions. | +| [`clip_grad_by_total_norm_`](#nemo_rl-models-dtensor-parallelize-clip_grad_by_total_norm_) | Clips gradient of an iterable of parameters by total norm. | +| [`get_grad_norm`](#nemo_rl-models-dtensor-parallelize-get_grad_norm) | Calculate the norm of gradients. | +| [`get_hf_tp_plan`](#nemo_rl-models-dtensor-parallelize-get_hf_tp_plan) | Get the Hugging Face tensor parallel plan from the model. | +| [`to_local_if_dtensor`](#nemo_rl-models-dtensor-parallelize-to_local_if_dtensor) | Returns the local shard of the given tensor if it is a DTensor. | +| [`translate_parallel_style`](#nemo_rl-models-dtensor-parallelize-translate_parallel_style) | Translate parallel style str to parallel type. | + +### Data + +[`PARALLIZE_FUNCTIONS`](#nemo_rl-models-dtensor-parallelize-PARALLIZE_FUNCTIONS) + +### API + + + + + +```python +class nemo_rl.models.dtensor.parallelize.RotaryEmbedParallel() +``` + + + + + + +**Bases:** `SequenceParallel` + +Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings because the input is a tuple. + + + + + + +```python +nemo_rl.models.dtensor.parallelize.RotaryEmbedParallel._prepare_input_fn( + sequence_sharding, + mod, + inputs, + device_mesh +) +``` + + + + + + +staticmethod + + + + + + + +```python +nemo_rl.models.dtensor.parallelize.RotaryEmbedParallel._prepare_output_fn( + use_local_output, + mod, + outputs, + device_mesh +) +``` + + + + + + +staticmethod + + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize._parallelize_gemma3( + model: typing.Union[transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM, transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration], + sequence_parallel: bool = False +) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle] +``` + + + + + + +Parallelizes a Gemma3ForCausalLM model across data and tensor parallel dimensions. + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize._parallelize_llama( + model: transformers.models.llama.modeling_llama.LlamaForCausalLM, + sequence_parallel: bool = False +) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle] +``` + + + + + + +Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions. + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize._parallelize_model( + model: typing.Union[transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM, transformers.models.qwen3.modeling_qwen3.Qwen3ForCausalLM, transformers.models.llama.modeling_llama.LlamaForCausalLM, transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM, transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration], + dp_mesh: torch.distributed.device_mesh.DeviceMesh, + tp_mesh: torch.distributed.device_mesh.DeviceMesh, + param_dtype: torch.dtype, + sequence_parallel: bool = False, + activation_checkpointing: bool = False, + cpu_offload: bool = False, + custom_parallel_plan: typing.Optional[typing.Union[dict, str]] = None +) +``` + + + + + + +Parallelize a model using DTensor. + +**Parameters:** + + +The model to parallelize. + + + +Device mesh for data parallelism. + + + +Device mesh for tensor parallelism. + + + +Data type for model parameters. + + + +Whether to use sequence parallelism. Defaults to False. + + + +Whether to use activation checkpointing. Defaults to False. + + + +Whether to enable cpu offloading for FSDP. Defaults to False. + + + +Custom parallel plan for the model. Defaults to None. +If it's a dict, it will be used as the parallel plan directly. +If it's a string, it must be a path that points to a dict or a function that returns a dict. +The usage example can refer to `docs/design-docs/fsdp2-parallel-plan.md`. + + +**Returns:** + +The parallelized model. + +**Raises:** + +- `ValueError`: If the model type is not supported for parallelization. + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize._parallelize_nm5_h( + model, + dp_mesh: torch.distributed.device_mesh.DeviceMesh, + tp_mesh: torch.distributed.device_mesh.DeviceMesh, + param_dtype: torch.dtype, + sequence_parallel: bool = False, + activation_checkpointing: bool = False, + cpu_offload: bool = False, + custom_parallel_plan: typing.Optional[typing.Union[dict, str]] = None +) -> torch.distributed.fsdp.FSDPModule +``` + + + + + + +Parallelize a NemotronHForCausalLM model across data and tensor parallel dimensions. + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize._parallelize_qwen( + model: typing.Union[transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM, transformers.models.qwen3.modeling_qwen3.Qwen3ForCausalLM], + sequence_parallel: bool = False +) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle] +``` + + + + + + +Parallelizes a Qwen2ForCausalLM model across data and tensor parallel dimensions. + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize.clip_grad_by_total_norm_( + parameters: typing.Union[list[typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]], typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]], + max_grad_norm: typing.Union[int, float], + total_norm: float +) +``` + + + + + + +Clips gradient of an iterable of parameters by total norm. + +Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/a695b2bd2a0ca9ca63385a48c41a1c5a033cdd1e/megatron/core/optimizer/clip_grads.py#L138 + +Note that the gradients are modified in place. + +**Parameters:** + + + +An iterable of Tensors or DTensors, or a single Tensor or DTensor +that will have gradients normalized. + + + +Maximum norm of the gradients. + + + +The pre-computed total norm of the gradients to use for scaling. + + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize.get_grad_norm( + parameters: typing.Union[list[typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]], typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]], + dp_cp_group: torch.distributed.ProcessGroup, + tp_group: torch.distributed.ProcessGroup, + norm_type: typing.Union[int, float] = 2, + dtype: torch.dtype = torch.float32 +) -> float +``` + + + + + + +Calculate the norm of gradients. + +Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/a695b2bd2a0ca9ca63385a48c41a1c5a033cdd1e/megatron/core/optimizer/clip_grads.py#L51 + +**Parameters:** + + + +An iterable of Tensors or DTensors, or a single Tensor or DTensor +that will have gradient norm calculated. + + + +Process group for data parallel communication. + + + +Process group for context parallel communication. + + + +Process group for tensor parallel communication. + + + +Type of the used p-norm. Can be ``'inf'`` for +infinity norm. + + +**Returns:** `float` + +Total norm of the gradients (viewed as a single vector) + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize.get_hf_tp_plan( + model: transformers.modeling_utils.PreTrainedModel +) +``` + + + + + + +Get the Hugging Face tensor parallel plan from the model. + +This function: +- Retrieves TP strategies from model class, instance, and inner model levels. +- Handles special cases for `embed_tokens` and `lm_head` for speed up. +- Converts string-based parallel styles to DTensor parallelization strategies. + +Taken and modified from: https://github.com/NVIDIA/NeMo/blob/6c6169db01bcca73ae8ad3ac35242fadbb9a78ba/nemo/lightning/pytorch/strategies/utils.py#L532 + +**Parameters:** + + +A Hugging Face model instance + + +**Returns:** + +A dictionary mapping model component paths to their parallelization strategies + +**Raises:** + +- `AssertionError`: If no TP plan is found + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize.to_local_if_dtensor( + tensor: typing.Union[torch.Tensor, torch.distributed.tensor.DTensor] +) -> torch.Tensor +``` + + + + + + +Returns the local shard of the given tensor if it is a DTensor. + +Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/605f618f237cda8fa80132bc2ccff933512d5a0d/megatron/core/utils.py#L746 + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize.translate_parallel_style( + style: str +) +``` + + + + + + +Translate parallel style str to parallel type. + +Taken and modified from: https://github.com/NVIDIA/NeMo/blob/6c6169db01bcca73ae8ad3ac35242fadbb9a78ba/nemo/lightning/pytorch/strategies/utils.py#L547 + + + + + + + + +```python +nemo_rl.models.dtensor.parallelize.PARALLIZE_FUNCTIONS: dict[type[Module], Callable[..., dict[str, ParallelStyle]]] = {Qwen2ForCausalLM: _parallelize_qwen, Qwen3ForCausalLM: _parallelize_qwen, Llama... +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation.mdx new file mode 100644 index 0000000..ff3114c --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation.mdx @@ -0,0 +1,62 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation +title: nemo_rl.models.generation +--- + +## Subpackages + +- **[`nemo_rl.models.generation.sglang`](/nemo-rl/nemo_rl/models/generation/sglang)** +- **[`nemo_rl.models.generation.vllm`](/nemo-rl/nemo_rl/models/generation/vllm)** + +## Submodules + +- **[`nemo_rl.models.generation.interfaces`](/nemo-rl/nemo_rl/models/generation/interfaces)** + +## Package Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`configure_generation_config`](#nemo_rl-models-generation-configure_generation_config) | Apply specific configurations to generation config. | + +### Data + +[`TokenizerType`](#nemo_rl-models-generation-TokenizerType) + +### API + + + + + +```python +nemo_rl.models.generation.configure_generation_config( + config: nemo_rl.models.generation.interfaces.GenerationConfig, + tokenizer: nemo_rl.models.generation.TokenizerType, + is_eval = False +) -> nemo_rl.models.generation.interfaces.GenerationConfig +``` + + + + + + +Apply specific configurations to generation config. + + + + + + + + +```python +nemo_rl.models.generation.TokenizerType = PreTrainedTokenizerBase +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/interfaces.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/interfaces.mdx new file mode 100644 index 0000000..886ccc8 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/interfaces.mdx @@ -0,0 +1,569 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/interfaces +title: nemo_rl.models.generation.interfaces +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ColocationConfig`](#nemo_rl-models-generation-interfaces-ColocationConfig) | - | +| [`GenerationConfig`](#nemo_rl-models-generation-interfaces-GenerationConfig) | Configuration for generation. | +| [`GenerationDatumSpec`](#nemo_rl-models-generation-interfaces-GenerationDatumSpec) | Specification for input data required by generation models. | +| [`GenerationInterface`](#nemo_rl-models-generation-interfaces-GenerationInterface) | Abstract base class defining the interface for RL policies. | +| [`GenerationOutputSpec`](#nemo_rl-models-generation-interfaces-GenerationOutputSpec) | Specification for output data returned by generation models. | +| [`OptionalResourcesConfig`](#nemo_rl-models-generation-interfaces-OptionalResourcesConfig) | - | +| [`ResourcesConfig`](#nemo_rl-models-generation-interfaces-ResourcesConfig) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`verify_right_padding`](#nemo_rl-models-generation-interfaces-verify_right_padding) | Verify that a tensor is right-padded according to the provided lengths. | + +### API + + + + + +```python +class nemo_rl.models.generation.interfaces.ColocationConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.models.generation.interfaces.GenerationConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Configuration for generation. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.generation.interfaces.GenerationDatumSpec +``` + + + + + + +**Bases:** `typing.TypedDict` + +Specification for input data required by generation models. + +- input_ids: Tensor of token IDs representing the input sequences (right padded) +- input_lengths: Tensor containing the actual length of each sequence (without padding) +- stop_strings: Optional list of strings to stop generation (per sample) +- __extra__: Additional model-specific data fields + +Example of a batch with 4 entries with different sequence lengths: + + +```python +# Batch of 4 sequences with lengths [3, 5, 2, 4] + +input_ids (padded): +[ + [101, 2054, 2003, 0, 0], # Length 3 + [101, 2054, 2003, 2001, 1996], # Length 5 + [101, 2054, 0, 0, 0], # Length 2 + [101, 2054, 2003, 2001, 0], # Length 4 +] + +input_lengths: +[3, 5, 2, 4] +``` + + + +All functions receiving or returning GenerationDatumSpec should ensure +right padding is maintained. Use verify_right_padding() to check. + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.generation.interfaces.GenerationInterface() +``` + + + + + + +Abstract + +Abstract base class defining the interface for RL policies. + + + +Whether the generation backend requires KV cache scales synchronization. + + + + + +```python +nemo_rl.models.generation.interfaces.GenerationInterface.clear_logger_metrics() -> None +``` + + + + + + +Clear logger metrics for performance reporting. + +This is an optional method that backends can implement to clear +telemetry metrics. Default implementation does nothing. + + + + + + + +```python +nemo_rl.models.generation.interfaces.GenerationInterface.finish_generation( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> bool +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.generation.interfaces.GenerationInterface.generate( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.generation.interfaces.GenerationInterface.get_logger_metrics() -> dict[str, typing.Any] +``` + + + + + + +Get logger metrics for performance reporting. + +This is an optional method that backends can implement to collect +telemetry metrics. Default implementation returns empty dict. + +**Returns:** `dict[str, Any]` + +Dictionary of metrics. Format may vary by backend. + + + + + + + +```python +nemo_rl.models.generation.interfaces.GenerationInterface.init_collective( + ip: str, + port: int, + world_size: int +) -> list[ray.ObjectRef] +``` + + + + + + +abstract + +Initialize the collective communication. + + + + + + + +```python +nemo_rl.models.generation.interfaces.GenerationInterface.invalidate_kv_cache() -> bool +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.interfaces.GenerationInterface.prepare_for_generation( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> bool +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.generation.interfaces.GenerationInterface.prepare_refit_info( + state_dict_info: dict[str, typing.Any] +) -> None +``` + + + + + + +Prepare the info for refit. + + + + + + + +```python +nemo_rl.models.generation.interfaces.GenerationInterface.update_weights_from_collective() -> list[ray.ObjectRef] +``` + + + + + + +Update the model weights from collective communication. + + + + + + + +```python +nemo_rl.models.generation.interfaces.GenerationInterface.update_weights_via_ipc_zmq() -> list[ray.ObjectRef] +``` + + + + + + +Update the model weights from the given IPC handles. + + + + + + + + + +```python +class nemo_rl.models.generation.interfaces.GenerationOutputSpec +``` + + + + + + +**Bases:** `typing.TypedDict` + +Specification for output data returned by generation models. + +- output_ids: Tensor of token IDs representing the generated sequences (right padded) +- generation_lengths: Tensor containing the actual length of each generated sequence +- unpadded_sequence_lengths: Tensor containing the actual length of each input + generated sequence (without padding) +- logprobs: Tensor of log probabilities for each generated token (right padded with zeros) +- truncated: Boolean tensor indicating if each sequence was truncated (hit max_tokens limit) +- __extra__: Additional model-specific data fields + +Example of a batch with 2 sequences: + + +```python +# Sample batch with 2 examples +# - Example 1: Input length 3, generated response length 4 +# - Example 2: Input length 5, generated response length 2 + +output_ids (right-padded): +[ + [101, 2054, 2003, 2023, 2003, 1037, 2200, 0], # 7 valid tokens (3 input + 4 output) + [101, 2054, 2003, 2001, 1996, 3014, 2005, 0], # 7 valid tokens (5 input + 2 output) +] + +generation_lengths: +[4, 2] # Length of just the generated response part + +unpadded_sequence_lengths: +[7, 7] # Length of full valid sequence (input + generated response) + +logprobs (right-padded with zeros): +[ + [0.0, 0.0, 0.0, -1.2, -0.8, -2.1, -1.5, 0.0], # First 3 are 0 (input tokens), next 4 are actual logprobs + [0.0, 0.0, 0.0, 0.0, 0.0, -0.9, -1.7, 0.0], # First 5 are 0 (input tokens), next 2 are actual logprobs +] + +truncated: +[False, True] # Example 2 was truncated (hit max_tokens limit without EOS) +``` + + + +All functions receiving or returning GenerationOutputSpec should ensure +right padding is maintained. Use verify_right_padding() to check. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.generation.interfaces.OptionalResourcesConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.models.generation.interfaces.ResourcesConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +nemo_rl.models.generation.interfaces.verify_right_padding( + data: typing.Union[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], + pad_value: int = 0, + raise_error: bool = True +) -> tuple[bool, typing.Union[str, None]] +``` + + + + + + +Verify that a tensor is right-padded according to the provided lengths. + +**Parameters:** + + +The BatchedDataDict to check, containing either: +- For GenerationDatumSpec: input_ids and input_lengths +- For GenerationOutputSpec: output_ids and unpadded_sequence_lengths + + + +The expected padding value (default: 0) + + + +Whether to raise an error if wrong padding is detected + + +**Returns:** `bool` + +Tuple of (is_right_padded, error_message) + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang.mdx new file mode 100644 index 0000000..78b60ba --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang.mdx @@ -0,0 +1,33 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/sglang +title: nemo_rl.models.generation.sglang +--- + +## Submodules + +- **[`nemo_rl.models.generation.sglang.config`](/nemo-rl/nemo_rl/models/generation/sglang/config)** +- **[`nemo_rl.models.generation.sglang.sglang_copied_utils`](/nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils)** +- **[`nemo_rl.models.generation.sglang.sglang_generation`](/nemo-rl/nemo_rl/models/generation/sglang/sglang_generation)** +- **[`nemo_rl.models.generation.sglang.sglang_worker`](/nemo-rl/nemo_rl/models/generation/sglang/sglang_worker)** +- **[`nemo_rl.models.generation.sglang.utils`](/nemo-rl/nemo_rl/models/generation/sglang/utils)** + +## Package Contents + +### Data + +[`__all__`](#nemo_rl-models-generation-sglang-__all__) + +### API + + + + + +```python +nemo_rl.models.generation.sglang.__all__ = ['SGLangConfig', 'SGLangGeneration'] +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/config.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/config.mdx new file mode 100644 index 0000000..f86bc80 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/config.mdx @@ -0,0 +1,299 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/sglang/config +title: nemo_rl.models.generation.sglang.config +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`SGLangConfig`](#nemo_rl-models-generation-sglang-config-SGLangConfig) | Configuration for SGLang runtime. | +| [`SglangSpecificArgs`](#nemo_rl-models-generation-sglang-config-SglangSpecificArgs) | SGLang-specific configuration arguments. | + +### API + + + + + +```python +class nemo_rl.models.generation.sglang.config.SGLangConfig() +``` + + + + + + +**Bases:** [GenerationConfig](/nemo-rl/nemo_rl/models/generation/interfaces#nemo_rl-models-generation-interfaces-GenerationConfig) + +Configuration for SGLang runtime. + + + + + + + + + + + + + +```python +class nemo_rl.models.generation.sglang.config.SglangSpecificArgs +``` + + + + + + +**Bases:** `typing.TypedDict` + +SGLang-specific configuration arguments. + +Most fields below map directly to SGLang's ServerArgs (see: +https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py). + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils.mdx new file mode 100644 index 0000000..c940dea --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils.mdx @@ -0,0 +1,307 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_copied_utils +title: nemo_rl.models.generation.sglang.sglang_copied_utils +--- + +Standalone utility functions copied from the SGLang project. + +This module contains utility functions that were originally part of the SGLang +repository (https://github.com/sgl-project/sglang). They have been copied here +to avoid requiring sglang as a runtime dependency for weight refitting functionality. + +IMPORTANT: This module should NOT contain any imports from the sglang package. +All functions are standalone and self-contained. + +Each function includes a permalink to its original source in the SGLang repository. +These functions were copied from sglang version 0.5.2. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MultiprocessingSerializer`](#nemo_rl-models-generation-sglang-sglang_copied_utils-MultiprocessingSerializer) | Serialize/deserialize Python objects using ForkingPickler for IPC. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_device_from_maybe_uuid`](#nemo_rl-models-generation-sglang-sglang_copied_utils-_device_from_maybe_uuid) | Convert a device UUID string or index to a device index. | +| [`_device_to_uuid`](#nemo_rl-models-generation-sglang-sglang_copied_utils-_device_to_uuid) | Convert a device index to its UUID string. | +| [`_modify_tuple`](#nemo_rl-models-generation-sglang-sglang_copied_utils-_modify_tuple) | Create a new tuple with one element modified by a function. | +| [`_rebuild_cuda_tensor_modified`](#nemo_rl-models-generation-sglang-sglang_copied_utils-_rebuild_cuda_tensor_modified) | Modified rebuild_cuda_tensor that accepts GPU UUID or device index. | +| [`_reduce_tensor_modified`](#nemo_rl-models-generation-sglang-sglang_copied_utils-_reduce_tensor_modified) | Modified reduce_tensor that stores GPU UUID instead of device index. | +| [`monkey_patch_torch_reductions`](#nemo_rl-models-generation-sglang-sglang_copied_utils-monkey_patch_torch_reductions) | Monkey patch torch multiprocessing reductions to use GPU UUIDs. | + +### Data + +[`_REDUCE_TENSOR_ARG_DEVICE_INDEX`](#nemo_rl-models-generation-sglang-sglang_copied_utils-_REDUCE_TENSOR_ARG_DEVICE_INDEX) + +### API + + + + + +```python +class nemo_rl.models.generation.sglang.sglang_copied_utils.MultiprocessingSerializer() +``` + + + + + + +Serialize/deserialize Python objects using ForkingPickler for IPC. + +This class enables serialization of objects (including CUDA tensors with IPC +handles) for transfer between processes via HTTP or other mechanisms. + +Original source (sglang v0.5.2): +https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/utils.py#L589-L623 + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_copied_utils.MultiprocessingSerializer.deserialize( + data +) +``` + + + + + + +staticmethod + +Deserialize a previously serialized object. + +**Parameters:** + + +The serialized data, optionally base64-encoded. + + +**Returns:** + +The deserialized Python object. + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_copied_utils.MultiprocessingSerializer.serialize( + obj, + output_str: bool = False +) +``` + + + + + + +staticmethod + +Serialize a Python object using ForkingPickler. + +**Parameters:** + + +The object to serialize. + + + +If True, return a base64-encoded string instead of raw bytes. + + +**Returns:** + +bytes or str: The serialized object. + + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_copied_utils._device_from_maybe_uuid( + device_maybe_uuid: typing.Union[int, str] +) -> int +``` + + + + + + +Convert a device UUID string or index to a device index. + +Original source (sglang v0.5.2): +https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L55-L65 + +**Parameters:** + + +Either an integer device index or a UUID string. + + +**Returns:** `int` + +The integer device index. + +**Raises:** + +- `Exception`: If the UUID doesn't match any available device. + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_copied_utils._device_to_uuid( + device: int +) -> str +``` + + + + + + +Convert a device index to its UUID string. + +Original source (sglang v0.5.2): +https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L51-L52 + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_copied_utils._modify_tuple( + t, + index: int, + modifier: typing.Callable +) +``` + + + + + + +Create a new tuple with one element modified by a function. + +Original source (sglang v0.5.2): +https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L68-L69 + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_copied_utils._rebuild_cuda_tensor_modified( + args = () +) +``` + + + + + + +Modified rebuild_cuda_tensor that accepts GPU UUID or device index. + +Original source (sglang v0.5.2): +https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L46-L48 + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_copied_utils._reduce_tensor_modified( + args = (), + kwargs = {} +) +``` + + + + + + +Modified reduce_tensor that stores GPU UUID instead of device index. + +Original source (sglang v0.5.2): +https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L39-L43 + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_copied_utils.monkey_patch_torch_reductions() +``` + + + + + + +Monkey patch torch multiprocessing reductions to use GPU UUIDs. + +This patch modifies PyTorch's CUDA tensor IPC mechanism to use GPU UUIDs +instead of device indices. This enables proper weight transfer between +processes that may have different CUDA_VISIBLE_DEVICES configurations. + +The patch is idempotent - calling it multiple times is safe. + +This is a workaround before PyTorch https://github.com/pytorch/pytorch/pull/149248 +is merged and released. + +Original source (sglang v0.5.2): +https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L20-L33 + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_copied_utils._REDUCE_TENSOR_ARG_DEVICE_INDEX = 6 +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_generation.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_generation.mdx new file mode 100644 index 0000000..c8393bd --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_generation.mdx @@ -0,0 +1,369 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_generation +title: nemo_rl.models.generation.sglang.sglang_generation +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`SGLangGeneration`](#nemo_rl-models-generation-sglang-sglang_generation-SGLangGeneration) | - | + +### Data + +[`TOP_K_THRESHOLD`](#nemo_rl-models-generation-sglang-sglang_generation-TOP_K_THRESHOLD) + +[`TOP_P_THRESHOLD`](#nemo_rl-models-generation-sglang-sglang_generation-TOP_P_THRESHOLD) + +[`logger`](#nemo_rl-models-generation-sglang-sglang_generation-logger) + +### API + + + + + +```python +class nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration( + cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster, + config: nemo_rl.models.generation.sglang.config.SGLangConfig, + name_prefix: str = 'sglang_policy', + workers_per_node: typing.Optional[typing.Union[int, list[int]]] = None +) +``` + + + + + + +**Bases:** [GenerationInterface](/nemo-rl/nemo_rl/models/generation/interfaces#nemo_rl-models-generation-interfaces-GenerationInterface) + + + + + + + + + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.__del__() -> None +``` + + + + + + +Shuts down the worker groups when the object is deleted or is garbage collected. + +This is an extra safety net in case the user forgets to call shutdown() and the pointer to +the object is lost due to leaving a function scope. It's always recommended that the +user calls shutdown(). + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration._allocate_bundles_for_servers( + cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster, + num_servers: int, + gpus_per_server: int +) -> list[tuple[int, list[int]]] +``` + + + + + + +Allocate GPU bundles to each SGLang server. + +Each server gets consecutive bundles within the same placement group (node). +Ray will automatically set CUDA_VISIBLE_DEVICES so each server sees logical GPUs 0, 1, 2, ..., gpus_per_server-1. + +**Parameters:** + + +The Ray virtual cluster + + + +Total number of SGLang servers to create + + + +Number of GPUs each server needs + + +**Returns:** `list[tuple[int, list[int]]]` + +List of (node_idx, [bundle_indices]) tuples for each server + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.finish_generation( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> bool +``` + + + + + + +Sleep workers and reset prefix cache. + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.generate( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] +``` + + + + + + +Generate a batch of data using SGLang. + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.get_sglang_server_urls() -> list[str] +``` + + + + + + +Get base URLs of all SGLang servers. + +**Returns:** `list[str]` + +List of base URLs (e.g., ["http://localhost:30000", "http://localhost:30001"]) + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.get_sglang_url_to_gpu_uuids() -> dict[str, list[str]] +``` + + + + + + +Get mapping from SGLang server URL to list of GPU UUIDs it uses. + +**Returns:** `dict[str, list[str]]` + +Dict mapping server URL to list of GPU UUIDs + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.init_collective( + ip: str, + port: int, + world_size: int, + train_world_size: int +) -> list[ray.ObjectRef] +``` + + + + + + +Initialize the collective communication. + +TODO: if weight updates via NCCL are needed in the future. + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.invalidate_kv_cache() -> bool +``` + + + + + + +Invalidate KV cache before weight updates (Megatron-style). + +This flushes the cache before weight updates to clear stale cache. +Only primary workers (TP rank 0, model owners) will flush their cache. + +**Returns:** `bool` + +True if all caches were flushed successfully, False otherwise + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.prepare_for_generation( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> bool +``` + + + + + + +Wake workers up for colocated inference. + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.prepare_refit_info( + state_dict_info: dict[str, typing.Any] +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.shutdown() -> bool +``` + + + + + + +Shut down all SGLang workers and clean up resources. + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.update_weights_from_collective() -> list[ray.ObjectRef] +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.SGLangGeneration.update_weights_via_ipc_zmq() -> list[ray.ObjectRef] +``` + + + + + + + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.TOP_K_THRESHOLD = 8000 +``` + + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.TOP_P_THRESHOLD = 0.99 +``` + + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_generation.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_worker.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_worker.mdx new file mode 100644 index 0000000..a74da3c --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/sglang_worker.mdx @@ -0,0 +1,529 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/sglang/sglang_worker +title: nemo_rl.models.generation.sglang.sglang_worker +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`SGLangGenerationWorker`](#nemo_rl-models-generation-sglang-sglang_worker-SGLangGenerationWorker) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_require_sglang`](#nemo_rl-models-generation-sglang-sglang_worker-_require_sglang) | Import `sglang` lazily so test collection works without the optional extra. | + +### Data + +[`logger`](#nemo_rl-models-generation-sglang-sglang_worker-logger) + +### API + + + + + +```python +class nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker( + config: nemo_rl.models.generation.sglang.config.SGLangConfig, + bundle_indices: typing.Optional[list[int]] = None, + fraction_of_gpus: float = 1.0, + seed: typing.Optional[int] = None +) +``` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.__repr__() -> str +``` + + + + + + +Customizes the actor's prefix in the Ray logs. + +This makes it easier to identify which worker is producing specific log messages. + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._build_sampling_params( + greedy: bool, + stop_strings, + max_new_tokens: typing.Optional[int] = None, + input_len: typing.Optional[int] = None, + context_length: typing.Optional[int] = None, + sample_index: typing.Optional[int] = None +) -> dict[str, typing.Any] +``` + + + + + + +Build sampling parameters dictionary for SGLang API. + +**Parameters:** + + +Whether to use greedy decoding (temperature=0.0) + + + +Merged stop strings (not used here, handled per sample) + + + +Override max_new_tokens from config if provided + + + +Input length for this sample (used for context_length adjustment) + + + +Maximum context length (if provided, adjusts max_new_tokens) + + + +Sample index (used for warning messages, 0-indexed) + + +**Returns:** `dict[str, Any]` + +Dictionary of sampling parameters compatible with SGLang API + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._ensure_session() +``` + + + + + + +async + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._generate_async( + tasks +) +``` + + + + + + +async + +Execute generation tasks with concurrency control. + +TEMP: Uses a semaphore to limit the number of concurrent requests per server, preventing server overload. +A router based solution is preffered in the future. + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._generate_single_sample( + input_ids: list[int], + sampling_params: dict[str, typing.Any], + stop_string: typing.Optional[str] = None +) -> tuple[list[int], list[float]] +``` + + + + + + +async + +Generate a single sample using SGLang API (async function). + +**Parameters:** + + +List of input token IDs (without padding) + + + +Dictionary of sampling parameters (temperature, top_p, max_new_tokens, etc.) + + + +Optional stop string for this sample + + +**Returns:** `tuple[list[int], list[float]]` + +Tuple of (generated_tokens, logprobs): +- generated_tokens: List of generated token IDs +- logprobs: List of log probabilities for generated tokens + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._launch_server_process( + server_args: typing.Any +) -> multiprocessing.Process +``` + + + + + + +Launch the SGLang server process and wait for it to be ready. + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._make_request( + endpoint: str, + payload: typing.Optional[dict] = None +) +``` + + + + + + +Make a POST request to the specified endpoint with the given payload. + +**Parameters:** + + +The API endpoint to call + + + +The JSON payload to send (default: empty dict) + + +**Returns:** + +The JSON response from the server + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker._merge_stop_strings( + batch_stop_strings +) +``` + + + + + + +Merge stop strings from config and batch. + +**Parameters:** + + +List of stop strings from batch (one per sample) + + +**Returns:** + +List of merged stop strings (one per sample) + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.configure_worker( + num_gpus: int | float, + bundle_indices: typing.Optional[tuple[int, list[int]]] = None +) -> tuple[dict[str, typing.Any], dict[str, str], dict[str, typing.Any]] +``` + + + + + + +staticmethod + +Provides complete worker configuration for SGLang server. + +This method configures the worker based on bundle_indices which tells us +how many GPUs this server should use. + +**Parameters:** + + +Original GPU allocation for this worker based on the placement group + + + +Tuple of (node_idx, local_bundle_indices) for this server + + +**Returns:** `tuple[dict[str, Any], dict[str, str], dict[str, Any]]` + +tuple with complete worker configuration: +- 'resources': Resource allocation (e.g., num_gpus) +- 'env_vars': Environment variables for this worker +- 'init_kwargs': Parameters to pass to __init__ of the worker + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.generate( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] +``` + + + + + + +Generate a batch of data using SGLang generation. + +**Parameters:** + + +BatchedDataDict containing input_ids and input_lengths tensors + + + +Whether to use greedy decoding instead of sampling + + +**Returns:** `BatchedDataDict[GenerationOutputSpec]` + +BatchedDataDict conforming to GenerationOutputSpec: +- output_ids: input + generated token IDs with proper padding +- logprobs: Log probabilities for tokens +- generation_lengths: Lengths of each response +- unpadded_sequence_lengths: Lengths of each input + generated sequence + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.get_base_url() -> str +``` + + + + + + +Get the base URL of this SGLang server. + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.get_gpu_uuids() -> list[str] +``` + + + + + + +Get list of GPU UUIDs used by this SGLang server. + +**Returns:** `list[str]` + +List of GPU UUIDs (e.g., ["GPU-xxxxx", "GPU-yyyyy"]) + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.invalidate_kv_cache() -> bool +``` + + + + + + +Invalidate KV cache before weight updates (Megatron-style). + +This flushes the cache before weight updates to clear stale cache. +Uses retry logic to handle cases where there are pending requests. + +**Returns:** `bool` + +True if flush was successful, False otherwise + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.shutdown() -> bool +``` + + + + + + +Shutdown the SGLang server process and cleanup async resources. + +**Returns:** `bool` + +True if shutdown was successful, False otherwise + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.sleep() +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker.wake_up( + kwargs = {} +) +``` + + + + + + + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker._require_sglang() +``` + + + + + + +Import `sglang` lazily so test collection works without the optional extra. + + + + + + + + +```python +nemo_rl.models.generation.sglang.sglang_worker.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/utils.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/utils.mdx new file mode 100644 index 0000000..ace8dcd --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/sglang/utils.mdx @@ -0,0 +1,109 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/sglang/utils +title: nemo_rl.models.generation.sglang.utils +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AsyncLoopThread`](#nemo_rl-models-generation-sglang-utils-AsyncLoopThread) | A background event loop thread for running async operations in Ray actors. | + +### API + + + + + +```python +class nemo_rl.models.generation.sglang.utils.AsyncLoopThread() +``` + + + + + + +A background event loop thread for running async operations in Ray actors. + +This class creates a dedicated thread with its own event loop, allowing +synchronous Ray actor methods to execute async coroutines without blocking +the main actor thread. This is necessary because run_coroutine_threadsafe +requires the event loop to be in a different thread. + + + + + + + + + + + + + + +```python +nemo_rl.models.generation.sglang.utils.AsyncLoopThread._start_loop() +``` + + + + + + +Run the event loop in the background thread. + + + + + + + +```python +nemo_rl.models.generation.sglang.utils.AsyncLoopThread.run( + coro +) +``` + + + + + + +Schedule a coroutine onto the loop and block until it's done. + +**Parameters:** + + +The coroutine to execute + + +**Returns:** + +The result of the coroutine + + + + + + + +```python +nemo_rl.models.generation.sglang.utils.AsyncLoopThread.shutdown() +``` + + + + + + +Shutdown the event loop and wait for the thread to finish. + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm.mdx new file mode 100644 index 0000000..c5f278e --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm.mdx @@ -0,0 +1,34 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/vllm +title: nemo_rl.models.generation.vllm +--- + +## Submodules + +- **[`nemo_rl.models.generation.vllm.config`](/nemo-rl/nemo_rl/models/generation/vllm/config)** +- **[`nemo_rl.models.generation.vllm.utils`](/nemo-rl/nemo_rl/models/generation/vllm/utils)** +- **[`nemo_rl.models.generation.vllm.vllm_backend`](/nemo-rl/nemo_rl/models/generation/vllm/vllm_backend)** +- **[`nemo_rl.models.generation.vllm.vllm_generation`](/nemo-rl/nemo_rl/models/generation/vllm/vllm_generation)** +- **[`nemo_rl.models.generation.vllm.vllm_worker`](/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker)** +- **[`nemo_rl.models.generation.vllm.vllm_worker_async`](/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async)** + +## Package Contents + +### Data + +[`__all__`](#nemo_rl-models-generation-vllm-__all__) + +### API + + + + + +```python +nemo_rl.models.generation.vllm.__all__ = ['VllmConfig', 'VllmGeneration', 'VllmGenerationWorker', 'VllmAsyncGenerationWor... +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/config.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/config.mdx new file mode 100644 index 0000000..6ca4574 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/config.mdx @@ -0,0 +1,111 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/vllm/config +title: nemo_rl.models.generation.vllm.config +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`VllmConfig`](#nemo_rl-models-generation-vllm-config-VllmConfig) | - | +| [`VllmSpecificArgs`](#nemo_rl-models-generation-vllm-config-VllmSpecificArgs) | - | + +### API + + + + + +```python +class nemo_rl.models.generation.vllm.config.VllmConfig() +``` + + + + + + +**Bases:** [GenerationConfig](/nemo-rl/nemo_rl/models/generation/interfaces#nemo_rl-models-generation-interfaces-GenerationConfig) + + + + + + + + + + + + +```python +class nemo_rl.models.generation.vllm.config.VllmSpecificArgs +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/utils.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/utils.mdx new file mode 100644 index 0000000..5bfcfad --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/utils.mdx @@ -0,0 +1,113 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/vllm/utils +title: nemo_rl.models.generation.vllm.utils +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`aggregate_spec_decode_counters`](#nemo_rl-models-generation-vllm-utils-aggregate_spec_decode_counters) | Aggregate speculative decoding counters from multiple workers. | +| [`compute_spec_decode_metrics`](#nemo_rl-models-generation-vllm-utils-compute_spec_decode_metrics) | Compute delta and derived metrics for speculative decoding. | +| [`format_prompt_for_vllm_generation`](#nemo_rl-models-generation-vllm-utils-format_prompt_for_vllm_generation) | Format a list of prompts for vllm generation (which requires a specific format for its own `generate` method). | + +### API + + + + + +```python +nemo_rl.models.generation.vllm.utils.aggregate_spec_decode_counters( + worker_metrics: list[dict[str, float | list[float]]] +) -> dict[str | tuple[str, int], float] +``` + + + + + + +Aggregate speculative decoding counters from multiple workers. + +Combines spec decode metrics collected from DP leader workers into +a single aggregated counter dictionary. + +**Parameters:** + + +List of metric dictionaries from each worker. +Each dict maps metric names to float values or lists of floats +(for per-position metrics). + + +**Returns:** `dict[str | tuple[str, int], float]` + +Dictionary mapping metric names to their aggregated float values. + + + + + + + + +```python +nemo_rl.models.generation.vllm.utils.compute_spec_decode_metrics( + start_counters: dict[str | tuple[str, int], float], + end_counters: dict[str | tuple[str, int], float] +) -> dict[str, float] +``` + + + + + + +Compute delta and derived metrics for speculative decoding. + +Calculates the difference between two counter snapshots and derives +acceptance rate and acceptance length metrics for logging. + +**Parameters:** + + +Counter snapshot taken before generation. + + + +Counter snapshot taken after generation. + + +**Returns:** `dict[str, float]` + +Dictionary of metrics suitable for logging to wandb/tensorboard. + + + + + + + + +```python +nemo_rl.models.generation.vllm.utils.format_prompt_for_vllm_generation( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + sample_idx: typing.Optional[int] = None +) -> list[dict[str, typing.Any]] +``` + + + + + + +Format a list of prompts for vllm generation (which requires a specific format for its own `generate` method). + +See https://docs.vllm.ai/en/v0.9.1/features/multimodal_inputs.html for prompt format for multimodal inputs. + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_backend.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_backend.mdx new file mode 100644 index 0000000..c06de54 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_backend.mdx @@ -0,0 +1,236 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_backend +title: nemo_rl.models.generation.vllm.vllm_backend +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`VllmInternalWorkerExtension`](#nemo_rl-models-generation-vllm-vllm_backend-VllmInternalWorkerExtension) | - | + +### API + + + + + +```python +class nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension() +``` + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension._maybe_process_fp8_kv_cache() -> None +``` + + + + + + +Process weights after loading for FP8 KV cache (static scales). + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.cleanup() -> None +``` + + + + + + +Shutdown and cleanup resources. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.get_zmq_address() +``` + + + + + + +Get the ZMQ address for the current device. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.init_collective( + rank_prefix: int, + ip: str, + port: int, + world_size: int, + train_world_size: int +) -> None +``` + + + + + + +Initialize the collective communication. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.maybe_init_zmq() +``` + + + + + + +Initialize the ZMQ socket if it doesn't exist. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.prepare_refit_info( + state_dict_info: dict[str, typing.Any] +) -> None +``` + + + + + + +Prepare state dict metadata for weight refitting and IPC streaming. + +**Parameters:** + + +A dictionary containing the info for refit. +e.g. {tensor_name: (shape, dtype)} + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.report_device_id() -> str +``` + + + + + + +Retrieve the UUID of the current CUDA device. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.start_gpu_profiling() -> None +``` + + + + + + +Start GPU profiling. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.stop_gpu_profiling() -> None +``` + + + + + + +Stop GPU profiling. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.update_weights_from_collective() -> bool +``` + + + + + + +Update the model weights from collective communication. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_backend.VllmInternalWorkerExtension.update_weights_via_ipc_zmq() -> bool +``` + + + + + + +Receive and update model weights via ZMQ IPC socket. + +**Returns:** `bool` + +True if weights were successfully updated. + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_generation.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_generation.mdx new file mode 100644 index 0000000..e4cdee2 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_generation.mdx @@ -0,0 +1,656 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_generation +title: nemo_rl.models.generation.vllm.vllm_generation +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`VllmGeneration`](#nemo_rl-models-generation-vllm-vllm_generation-VllmGeneration) | - | + +### Data + +[`TOP_K_THRESHOLD`](#nemo_rl-models-generation-vllm-vllm_generation-TOP_K_THRESHOLD) + +[`TOP_P_THRESHOLD`](#nemo_rl-models-generation-vllm-vllm_generation-TOP_P_THRESHOLD) + +### API + + + + + +```python +class nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration( + cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster, + config: nemo_rl.models.generation.vllm.config.VllmConfig, + name_prefix: str = 'vllm_policy', + workers_per_node: typing.Optional[typing.Union[int, list[int]]] = None +) +``` + + + + + + +**Bases:** [GenerationInterface](/nemo-rl/nemo_rl/models/generation/interfaces#nemo_rl-models-generation-interfaces-GenerationInterface) + + + + + + + + + + + + + + + + + + + + + + + + + + +Check if KV cache scales should be synchronized during refit. + +Returns True if kv_cache_dtype is fp8/fp8_e4m3. + + + + + + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.__del__() -> None +``` + + + + + + +Shuts down the worker groups when the object is deleted or is garbage collected. + +This is an extra safety net in case the user forgets to call shutdown() and the pointer to +the object is lost due to leaving a function scope. It's always recommended that the +user calls shutdown(). + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration._async_generate_base( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + method_name: str, + data_validation_fn, + greedy: bool = False +) -> typing.AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None] +``` + + + + + + +async + +Base async generation method that handles common worker management logic. + +**Parameters:** + + +Input data for generation + + + +Name of the worker method to call ('generate_async' or 'generate_text_async') + + + +Function to validate input data + + + +Whether to use greedy decoding + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration._get_raw_spec_counters() -> dict[str | tuple[str, int], float] +``` + + + + + + +Collect raw spec decode counters from workers. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration._get_tied_worker_bundle_indices( + cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster +) -> list[tuple[int, list[int]]] +``` + + + + + + +Calculate bundle indices for tensor and pipeline parallel workers. + +Handles both unified placement groups (for cross-node model parallelism) and +per-node placement groups (for node-local model parallelism). + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration._post_init() +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration._report_device_id() -> list[list[str]] +``` + + + + + + +Report the device ID of vllm workers. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration._report_dp_openai_server_base_urls() -> list[typing.Optional[str]] +``` + + + + + + +Report the data parallel OpenAI server base URLs of vLLM workers, only populated if it is async vLLM engine and the HTTP server is active. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.clear_logger_metrics() -> None +``` + + + + + + +Clear logger metrics for performance reporting. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.clear_vllm_logger_metrics() -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.finish_generation( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> bool +``` + + + + + + +Sleep workers and reset prefix cache. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.generate( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] +``` + + + + + + +Generate a batch of data using vLLM. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.generate_async( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> typing.AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None] +``` + + + + + + +async + +Generate responses asynchronously, yielding individual samples as they complete. + +This method provides per-sample streaming across all workers, yielding each +sample result as soon as it's ready, regardless of which worker processed it. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.generate_text( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] +``` + + + + + + +Generate text responses using vLLM. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.generate_text_async( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> typing.AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None] +``` + + + + + + +async + +Generate text responses asynchronously, yielding results as they are ready. + +**Parameters:** + + +BatchedDataDict containing prompts with text strings + + + +Whether to use greedy decoding instead of sampling + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.get_logger_metrics() -> dict[str, typing.Any] +``` + + + + + + +Get logger metrics for performance reporting. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.get_step_metrics() -> dict[str, float] +``` + + + + + + +Get speculative decoding metrics delta since snapshot_step_metrics(). + +**Returns:** `dict[str, float]` + +Dictionary of delta metrics with 'vllm/' prefix. + +**Raises:** + +- `RuntimeWarning`: If called without snapshot_step_metrics() first. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.get_vllm_logger_metrics() -> dict[str, typing.Any] +``` + + + + + + +Collect vLLM logger metrics from vLLM workers (model-owner actors only). + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.init_collective( + ip: str, + port: int, + world_size: int, + train_world_size: int +) -> list[ray.ObjectRef] +``` + + + + + + +Initialize the collective communication. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.invalidate_kv_cache() -> bool +``` + + + + + + +Invalidate reusable caches in vLLM (e.g., prefix/KV cache) after weight updates. + +For async_engine, calls reset_prefix_cache_async on workers. For sync, calls reset_prefix_cache. +Returns True if all workers report success. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.prepare_for_generation( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> bool +``` + + + + + + +Wake workers up for colocated inference. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.prepare_refit_info( + state_dict_info: dict[str, typing.Any] +) -> None +``` + + + + + + +Prepare the info for refit. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.shutdown() -> bool +``` + + + + + + +Shut down all vLLM workers and clean up resources. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.snapshot_step_metrics() -> None +``` + + + + + + +Snapshot current spec decode counters to begin tracking a training step. + +Call this before generation to establish a baseline for metrics delta. + +**Raises:** + +- `RuntimeWarning`: If called twice without get_step_metrics() in between. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.start_gpu_profiling() -> None +``` + + + + + + +Start GPU profiling. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.stop_gpu_profiling() -> None +``` + + + + + + +Stop GPU profiling. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.update_weights_from_collective() -> list[ray.ObjectRef] +``` + + + + + + +Update weights of the policy using collective communication. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.VllmGeneration.update_weights_via_ipc_zmq() -> list[ray.ObjectRef] +``` + + + + + + +Update weights of the policy using IPC handles via ZMQ socket. + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.TOP_K_THRESHOLD = 8000 +``` + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_generation.TOP_P_THRESHOLD = 0.99 +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker.mdx new file mode 100644 index 0000000..080c9c1 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker.mdx @@ -0,0 +1,545 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker +title: nemo_rl.models.generation.vllm.vllm_worker +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`BaseVllmGenerationWorker`](#nemo_rl-models-generation-vllm-vllm_worker-BaseVllmGenerationWorker) | - | +| [`VllmGenerationWorker`](#nemo_rl-models-generation-vllm-vllm_worker-VllmGenerationWorker) | - | + +### API + + + + + +```python +class nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker( + config: nemo_rl.models.generation.vllm.config.VllmConfig, + bundle_indices: typing.Optional[list[int]] = None, + fraction_of_gpus: float = 1.0, + seed: typing.Optional[int] = None +) +``` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker.__repr__() -> str +``` + + + + + + +Customizes the actor's prefix in the Ray logs. + +This makes it easier to identify which worker is producing specific log messages. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker._build_sampling_params( + greedy: bool, + stop_strings, + max_new_tokens: typing.Optional[int] = None +) +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker._get_raw_spec_counters() -> dict[str, float | list[float]] +``` + + + + + + +Get speculative decoding metrics from the vLLM engine. + +Collects spec decode counters including number of drafts, +draft tokens, and accepted tokens for monitoring acceptance rates. + +**Returns:** `dict[str, float | list[float]]` + +Dictionary mapping metric names to their values. + +**Raises:** + +- `AssertionError`: If called before vLLM engine is initialized. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker._merge_stop_strings( + batch_stop_strings +) +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker.configure_worker( + num_gpus: int | float, + bundle_indices: typing.Optional[tuple[int, list[int]]] = None +) -> tuple[dict[str, typing.Any], dict[str, str], dict[str, typing.Any]] +``` + + + + + + +staticmethod + +Provides complete worker configuration for vLLM tensor and pipeline parallelism. + +This method configures the worker based on its role in tensor and pipeline parallelism, +which is determined directly from the bundle_indices parameter. + +**Parameters:** + + +Original GPU allocation for this worker based on the placement group + + + +Tuple of (node_idx, local_bundle_indices) for parallelism (if applicable) + + +**Returns:** `tuple[dict[str, Any], dict[str, str], dict[str, Any]]` + +tuple with complete worker configuration: +- 'resources': Resource allocation (e.g., num_gpus) +- 'env_vars': Environment variables for this worker +- 'init_kwargs': Parameters to pass to __init__ of the worker + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker.is_alive() +``` + + + + + + +Check if the worker is alive. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker.llm() +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker.start_gpu_profiling() -> None +``` + + + + + + +Start GPU profiling. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.BaseVllmGenerationWorker.stop_gpu_profiling() -> None +``` + + + + + + +Stop GPU profiling. + + + + + + + + + +```python +class nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker() +``` + + + + + + +**Bases:** [BaseVllmGenerationWorker](#nemo_rl-models-generation-vllm-vllm_worker-BaseVllmGenerationWorker) + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker._create_engine( + llm_kwargs: dict[str, typing.Any] +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.generate( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] +``` + + + + + + +Generate a batch of data using vLLM generation. + +**Parameters:** + + +BatchedDataDict containing input_ids and input_lengths tensors + + + +Whether to use greedy decoding instead of sampling + + +**Returns:** `BatchedDataDict[GenerationOutputSpec]` + +BatchedDataDict conforming to GenerationOutputSpec: +- output_ids: input + generated token IDs with proper padding +- logprobs: Log probabilities for tokens +- generation_lengths: Lengths of each response +- unpadded_sequence_lengths: Lengths of each input + generated sequence + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.generate_text( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] +``` + + + + + + +Generate text responses using vLLM generation. + +**Parameters:** + + +BatchedDataDict containing prompts with text strings + + + +Whether to use greedy decoding instead of sampling + + +**Returns:** `BatchedDataDict[GenerationOutputSpec]` + +BatchedDataDict containing: +- texts: List of generated text responses + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.init_collective( + rank_prefix: int, + ip: str, + port: int, + world_size: int, + train_world_size: int +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.post_init() +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.prepare_refit_info( + state_dict_info: dict[str, typing.Any] +) -> None +``` + + + + + + +Prepare the info for refit. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.report_device_id() -> list[str] +``` + + + + + + +Report device ID from the vLLM worker. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.reset_prefix_cache() +``` + + + + + + +Reset the prefix cache of vLLM engine. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.shutdown() -> bool +``` + + + + + + +Clean up vLLM resources. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.sleep() +``` + + + + + + +Put the vLLM engine to sleep. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.update_weights_from_collective() -> bool +``` + + + + + + +Update the model weights from collective communication. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.update_weights_via_ipc_zmq() -> bool +``` + + + + + + +Update weights from IPC handles via ZMQ socket. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker.wake_up( + kwargs = {} +) +``` + + + + + + +Wake up the vLLM engine. + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async.mdx new file mode 100644 index 0000000..b9b731a --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async.mdx @@ -0,0 +1,485 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/generation/vllm/vllm_worker_async +title: nemo_rl.models.generation.vllm.vllm_worker_async +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`VllmAsyncGenerationWorker`](#nemo_rl-models-generation-vllm-vllm_worker_async-VllmAsyncGenerationWorker) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_replace_prefix_tokens`](#nemo_rl-models-generation-vllm-vllm_worker_async-_replace_prefix_tokens) | This is a subroutine used inside the vLLM Chat Completion server. | + +### API + + + + + +```python +class nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker() +``` + + + + + + +**Bases:** [BaseVllmGenerationWorker](/nemo-rl/nemo_rl/models/generation/vllm/vllm_worker#nemo_rl-models-generation-vllm-vllm_worker-BaseVllmGenerationWorker) + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker._create_engine( + llm_kwargs: dict[str, typing.Any] +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker._setup_vllm_openai_api_server( + app: fastapi.FastAPI +) -> fastapi.FastAPI +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker._setup_vllm_server() -> tuple[threading.Thread, str, uvicorn.Server] +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker._start_vllm_metrics_logger() -> None +``` + + + + + + +Start a background thread that periodically collects vLLM logger metrics. + +Controlled by vllm_metrics_logger_interval (default: 0.5) in vllm_cfg. +Runs only on the model-owner actor. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.clear_vllm_logger_metrics() -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.generate_async( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> typing.AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None] +``` + + + + + + +async + +Generate a batch of data using vLLM's AsyncLLMEngine, yielding results as they are ready. + +**Parameters:** + + +BatchedDataDict with input_ids and input_lengths + + + +Whether to use greedy decoding instead of sampling + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.generate_text_async( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> typing.AsyncGenerator[tuple[int, nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]], None] +``` + + + + + + +async + +Generate text responses asynchronously, yielding results as they are ready. + +**Parameters:** + + +BatchedDataDict containing prompts with text strings + + + +Whether to use greedy decoding instead of sampling + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.get_vllm_logger_metrics() -> dict[str, typing.Any] +``` + + + + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.init_collective_async( + rank_prefix: int, + ip: str, + port: int, + world_size: int, + train_world_size: int +) -> None +``` + + + + + + +async + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.post_init_async() +``` + + + + + + +async + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.prepare_refit_info_async( + state_dict_info: dict[str, typing.Any] +) -> None +``` + + + + + + +async + +Async version of prepare_refit_info. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.report_device_id_async() -> list[str] +``` + + + + + + +async + +Async version of report_device_id. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.report_dp_openai_server_base_url() -> typing.Optional[str] +``` + + + + + + +async + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.reset_prefix_cache_async() +``` + + + + + + +async + +Async version of reset_prefix_cache. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.shutdown() -> bool +``` + + + + + + +async + +Clean up vLLM resources. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.sleep_async() +``` + + + + + + +async + +Async version of sleep. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.update_weights_from_collective_async() -> bool +``` + + + + + + +async + +Async version of update_weights_from_collective. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.update_weights_via_ipc_zmq_async() -> bool +``` + + + + + + +async + +Async version of update_weights_via_ipc_zmq. + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker.wake_up_async( + kwargs = {} +) +``` + + + + + + +async + +Async version of wake_up. + + + + + + + + + +```python +nemo_rl.models.generation.vllm.vllm_worker_async._replace_prefix_tokens( + tokenizer, + model_prefix_token_ids: list[int], + template_prefix_token_ids: list[int], + template_token_ids: list[int] +) -> list[int] +``` + + + + + + +This is a subroutine used inside the vLLM Chat Completion server. + +This function is for fixing up the chat template-tokenized messages history +to match the model output tokenization up to the last assistant turn, +in order to preserve the monotonic tokens property for optimized multi-turn +training. + +Some environments (namely NeMo-Gym) require an OpenAI compatible server +endpoint rather than an inference engine handle. This is fine for the most +part, but it may cause issues when the environment is used as a part of +training. + +RL training frameworks train models on token IDs, but the OpenAI compatible +server communicates in what is basically de-tokenized text. When multiple +model calls are made to the OpenAI compatible server in a single trajectory, +model generations in previous model calls may be re-tokenized to something +that is different than what was generated. This is not too big of an issue +(that we know of) at inference time, but the log probs the model produces +are different enough for the differently re-tokenized generation result that +it causes the training to be off policy. Off policy isn't necessarily a bad +thing in isolation, but this source of off-policyness may cause unexpected +issues if not properly accounted for. It also mis-aligns the token ID +sequences across model calls, which feels very strange during training. + +There are real cases where the model output string _does not match_ the chat +template tokenization of the parsed model output. A concrete example is +inconsistent whitespace tokens around tool call special tokens. + +TODO When NeMo RL supports training image generation models, we want to +revisit and possibly update this function. This issue occurs when the model +generates tokens that are de-tokenized into text or images, and then +re-tokenized into tokens. So if there is a situation like that with images +and image tokenization is non-unique, then we will need to uppdate this +function. + +Example (turn-by-turn, concise; eos_token_id = 2): + Turn 1: + - prefill_T1 (template prefill) = [11,12,13,40,41] + - model output = [220,17,2] # decodes to " 4" + EOS + - model_prefix_token_ids = prefill_T1 + model output + => [11,12,13,40,41,220,17,2] + + Turn 2 (template retokenizes prior assistant text differently): + - template_prefix_token_ids = [11,12,13,40,41,1001,2] # 1001 decodes to " 4" + - template_token_ids = [11,12,13,40,41,1001,2,21,22,40,41] + + _replace_prefix_tokens keeps the exact prior model tokens up to EOS and + resumes from the template after that EOS: + output => [11,12,13,40,41,220,17,2,21,22,40,41] + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface.mdx new file mode 100644 index 0000000..6095398 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface.mdx @@ -0,0 +1,9 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/huggingface +title: nemo_rl.models.huggingface +--- + +## Submodules + +- **[`nemo_rl.models.huggingface.common`](/nemo-rl/nemo_rl/models/huggingface/common)** diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface/common.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface/common.mdx new file mode 100644 index 0000000..f626968 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/huggingface/common.mdx @@ -0,0 +1,303 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/huggingface/common +title: nemo_rl.models.huggingface.common +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`FlashAttentionKwargs`](#nemo_rl-models-huggingface-common-FlashAttentionKwargs) | Dataclass to hold FlashAttention v2 kwargs. | +| [`ModelFlag`](#nemo_rl-models-huggingface-common-ModelFlag) | Enum that defines special flags for model-specific behaviors. | + +### Functions + +| Name | Description | +|------|-------------| +| [`get_flash_attention_kwargs`](#nemo_rl-models-huggingface-common-get_flash_attention_kwargs) | Returns kwargs required for FlashAttention v2 forward functions. | +| [`group_and_cat_tensors`](#nemo_rl-models-huggingface-common-group_and_cat_tensors) | Groups and concatenates tensors according to group_sizes, then pads them to form a 2D tensor. | +| [`is_gemma_model`](#nemo_rl-models-huggingface-common-is_gemma_model) | - | +| [`pack_sequences`](#nemo_rl-models-huggingface-common-pack_sequences) | Packs sequences into rows where each row concatenates multiple sequences. | +| [`unpack_tensor`](#nemo_rl-models-huggingface-common-unpack_tensor) | Unpacks a packed tensor into individual sequences padded to the same length. | + +### Data + +[`Tensor`](#nemo_rl-models-huggingface-common-Tensor) + +### API + + + + + +```python +class nemo_rl.models.huggingface.common.FlashAttentionKwargs( + cu_seqlens_q: nemo_rl.models.huggingface.common.Tensor, + cu_seqlens_k: nemo_rl.models.huggingface.common.Tensor, + max_seqlen_q: int, + max_seqlen_k: int +) +``` + + + + + + +Dataclass + +Dataclass to hold FlashAttention v2 kwargs. + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.huggingface.common.ModelFlag +``` + + + + + + +**Bases:** `enum.Enum` + +Enum that defines special flags for model-specific behaviors. + +This enum provides a way to identify models that require special handling or +configuration in different parts of the NeMo RL codebase. + +Each flag has a `matches` method that determines if the flag applies to a given model_name. + + + + + + + + + + +```python +nemo_rl.models.huggingface.common.get_flash_attention_kwargs( + input_lengths: torch.Tensor +) -> nemo_rl.models.huggingface.common.FlashAttentionKwargs +``` + + + + + + +Returns kwargs required for FlashAttention v2 forward functions. + +**Parameters:** + + +[batch_size] containing lengths of each sequence + + +**Returns:** `FlashAttentionKwargs` + +Dict[str, torch.Tensor | int]: +{ + "cu_seqlens_q": Tensor[int32], + "cu_seqlens_k": Tensor[int32], + "max_seqlen_q": int, + "max_seqlen_k": int +} + + + + + + + + +```python +nemo_rl.models.huggingface.common.group_and_cat_tensors( + tensors: list[torch.Tensor], + group_sizes: list[int], + padding_value: int = 0, + min_seq_len: int = 0 +) -> torch.Tensor +``` + + + + + + +Groups and concatenates tensors according to group_sizes, then pads them to form a 2D tensor. + +Each group of 1D tensors is concatenated into a single 1D tensor, and all resulting +group tensors are padded to the same length and stacked into a 2D tensor. + +**Parameters:** + + +List of 1D tensors of varying lengths. + + + +List of integers. Each integer specifies how many tensors to group. + + + +Integer used to pad shorter sequences. + + + +Minimum sequence length. + + +**Returns:** `torch.Tensor` + +A 2D tensor where each row is a padded concatenation of the grouped tensors. + + + + + + + + +```python +nemo_rl.models.huggingface.common.is_gemma_model( + model_name: str +) -> bool +``` + + + + + + + + + + + + + +```python +nemo_rl.models.huggingface.common.pack_sequences( + input_ids: torch.Tensor, + input_lengths: torch.Tensor, + packed_sequence_size: list[int], + padding_value: int = 0, + return_attention_mask: bool = True, + min_seq_len: int = 0 +) -> typing.Tuple[torch.Tensor, torch.Tensor, typing.Optional[torch.Tensor]] +``` + + + + + + +Packs sequences into rows where each row concatenates multiple sequences. + +Useful for sequence packing in transformer models (e.g. for SFT training). Returns: +packed input_ids, packed position_ids, and optional attention_mask. + +**Parameters:** + + +Tensor of shape [num_sequences, max_seq_len] + + + +Tensor of shape [num_sequences], containing true lengths + + + +How many sequences to pack per row + + + +Pad value for input_ids + + + +Whether to return per-row causal attention mask + + + +Minimum sequence length. + + +**Returns:** `Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]` + + +input_ids_packed (torch.Tensor): [batch_size, max_packed_seq_len] +position_ids_packed (torch.Tensor): [batch_size, max_packed_seq_len] +attention_mask (Optional[torch.Tensor]): [batch_size, max_len, max_len] if requested + + + + + + + + +```python +nemo_rl.models.huggingface.common.unpack_tensor( + tensor, + input_lengths +) +``` + + + + + + +Unpacks a packed tensor into individual sequences padded to the same length. + +**Parameters:** + + +Packed tensor of shape [batch_size, packed_seq_len]. + + + +Original sequence lengths in the order they were packed. + + +**Returns:** + +torch.Tensor: [num_sequences, max_seq_len], each row is one unpacked and padded sequence. + + + + + + + + +```python +nemo_rl.models.huggingface.common.Tensor = TypeVar('Tensor', bound=(torch.Tensor)) +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron.mdx new file mode 100644 index 0000000..368ff06 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron.mdx @@ -0,0 +1,15 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/megatron +title: nemo_rl.models.megatron +--- + +## Submodules + +- **[`nemo_rl.models.megatron.common`](/nemo-rl/nemo_rl/models/megatron/common)** +- **[`nemo_rl.models.megatron.community_import`](/nemo-rl/nemo_rl/models/megatron/community_import)** +- **[`nemo_rl.models.megatron.config`](/nemo-rl/nemo_rl/models/megatron/config)** +- **[`nemo_rl.models.megatron.data`](/nemo-rl/nemo_rl/models/megatron/data)** +- **[`nemo_rl.models.megatron.pipeline_parallel`](/nemo-rl/nemo_rl/models/megatron/pipeline_parallel)** +- **[`nemo_rl.models.megatron.setup`](/nemo-rl/nemo_rl/models/megatron/setup)** +- **[`nemo_rl.models.megatron.train`](/nemo-rl/nemo_rl/models/megatron/train)** diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/common.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/common.mdx new file mode 100644 index 0000000..eece5c9 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/common.mdx @@ -0,0 +1,133 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/megatron/common +title: nemo_rl.models.megatron.common +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_round_up_to_multiple`](#nemo_rl-models-megatron-common-_round_up_to_multiple) | - | +| [`broadcast_tensor`](#nemo_rl-models-megatron-common-broadcast_tensor) | Broadcasts a tensor from src_rank to all ranks in the group using broadcast_object_list for metadata. | +| [`get_moe_metrics`](#nemo_rl-models-megatron-common-get_moe_metrics) | Returns Mixture of Experts (MoE) auxiliary-loss metrics. | + +### API + + + + + +```python +nemo_rl.models.megatron.common._round_up_to_multiple( + value: int, + multiple: int +) -> int +``` + + + + + + + + + + + + + +```python +nemo_rl.models.megatron.common.broadcast_tensor( + tensor: torch.Tensor | None, + src_rank: int, + group: torch.distributed.ProcessGroup +) -> torch.Tensor +``` + + + + + + +Broadcasts a tensor from src_rank to all ranks in the group using broadcast_object_list for metadata. + +Handles the case where the input tensor might be None on non-source ranks. +If the input tensor is provided on non-source ranks, it must have the +correct shape and dtype matching the tensor on the source rank. + +**Parameters:** + + +The tensor to broadcast on the source rank. Can be None on + non-source ranks (will be created with correct shape/dtype). + If not None on non-source ranks, it's used as the buffer + for the broadcast and must match the source tensor's metadata. + + + +The global rank of the source process. + + + +The process group for communication. + + +**Returns:** `torch.Tensor` + +torch.Tensor: The broadcasted tensor. On non-source ranks, this will + be the tensor received from the source. + +**Raises:** + +- `ValueError`: If the tensor is None on the source rank, or if a tensor + provided on a non-source rank has mismatched shape/dtype/device. +- `TypeError`: If broadcasting metadata fails (e.g., due to pickling issues). + + + + + + + + +```python +nemo_rl.models.megatron.common.get_moe_metrics( + loss_scale: float, + total_loss_dict: typing.Optional[dict] = None, + per_layer_logging: bool = False +) -> dict[str, typing.Any] +``` + + + + + + +Returns Mixture of Experts (MoE) auxiliary-loss metrics. + +This function reduces MoE auxiliary losses across ranks, aggregates them, and +returns a dictionary of metrics. + +**Parameters:** + + +Scale factor to apply to each auxiliary loss (e.g., 1/num_microbatches). + + + +If provided, accumulate means into this dict (by name). + + + +If True, include per-layer values in the returned dict. + + +**Returns:** `dict[str, Any]` + +dict[str, Any]: A flat dict of aggregated metrics. For each aux loss name, + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/community_import.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/community_import.mdx new file mode 100644 index 0000000..a0f53a4 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/community_import.mdx @@ -0,0 +1,76 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/megatron/community_import +title: nemo_rl.models.megatron.community_import +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`export_model_from_megatron`](#nemo_rl-models-megatron-community_import-export_model_from_megatron) | - | +| [`import_model_from_hf_name`](#nemo_rl-models-megatron-community_import-import_model_from_hf_name) | Import a Hugging Face model into Megatron checkpoint format and save the Megatron checkpoint to the output path. | + +### API + + + + + +```python +nemo_rl.models.megatron.community_import.export_model_from_megatron( + hf_model_name: str, + input_path: str, + output_path: str, + hf_tokenizer_path: str, + overwrite: bool = False, + hf_overrides: typing.Optional[dict[str, typing.Any]] = {} +) +``` + + + + + + + + + + + + + +```python +nemo_rl.models.megatron.community_import.import_model_from_hf_name( + hf_model_name: str, + output_path: str, + megatron_config: typing.Optional[nemo_rl.models.policy.MegatronConfig] = None, + config_overrides: typing.Any = {} +) +``` + + + + + + +Import a Hugging Face model into Megatron checkpoint format and save the Megatron checkpoint to the output path. + +**Parameters:** + + +Hugging Face model ID or local path (e.g., 'meta-llama/Llama-3.1-8B-Instruct'). + + + +Directory to write the Megatron checkpoint (e.g., /tmp/megatron_ckpt). + + + +Optional megatron config with paralellism settings for distributed megatron model import. + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/config.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/config.mdx new file mode 100644 index 0000000..71f5eba --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/config.mdx @@ -0,0 +1,142 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/megatron/config +title: nemo_rl.models.megatron.config +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MegatronGenerationConfig`](#nemo_rl-models-megatron-config-MegatronGenerationConfig) | - | +| [`ModelAndOptimizerState`](#nemo_rl-models-megatron-config-ModelAndOptimizerState) | Container for model and optimizer state. | +| [`RuntimeConfig`](#nemo_rl-models-megatron-config-RuntimeConfig) | Runtime configuration for model training and inference. | + +### API + + + + + +```python +class nemo_rl.models.megatron.config.MegatronGenerationConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.megatron.config.ModelAndOptimizerState() +``` + + + + + + +**Bases:** `NamedTuple` + +Container for model and optimizer state. + +This named tuple holds all model-related state including the model itself, +optimizer, scheduler, and metadata about the model type and configuration. + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.megatron.config.RuntimeConfig() +``` + + + + + + +**Bases:** `NamedTuple` + +Runtime configuration for model training and inference. + +This contains all validated runtime settings needed for model initialization, +parallelization, and training. + + + + + + + + + + + + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/data.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/data.mdx new file mode 100644 index 0000000..aa84bf8 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/data.mdx @@ -0,0 +1,471 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/megatron/data +title: nemo_rl.models.megatron.data +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ProcessedInputs`](#nemo_rl-models-megatron-data-ProcessedInputs) | Processed microbatch inputs used for model forward pass. | +| [`ProcessedMicrobatch`](#nemo_rl-models-megatron-data-ProcessedMicrobatch) | Container for a processed microbatch ready for model forward pass. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_pack_sequence_parameters_for_megatron`](#nemo_rl-models-megatron-data-_get_pack_sequence_parameters_for_megatron) | Get pack sequence parameters for Megatron model processing with optional context parallelism. | +| [`_pack_sequences_for_megatron`](#nemo_rl-models-megatron-data-_pack_sequences_for_megatron) | Pack sequences for Megatron model processing with optional context parallelism. | +| [`_unpack_sequences_from_megatron`](#nemo_rl-models-megatron-data-_unpack_sequences_from_megatron) | Unpack sequences from Megatron output format. | +| [`get_and_validate_seqlen`](#nemo_rl-models-megatron-data-get_and_validate_seqlen) | - | +| [`get_microbatch_iterator`](#nemo_rl-models-megatron-data-get_microbatch_iterator) | Create a processed microbatch iterator from a batch of data. | +| [`make_processed_microbatch_iterator`](#nemo_rl-models-megatron-data-make_processed_microbatch_iterator) | Wrap a raw microbatch iterator to yield processed microbatches. | +| [`process_global_batch`](#nemo_rl-models-megatron-data-process_global_batch) | Process a global batch and compute normalization factors. | +| [`process_microbatch`](#nemo_rl-models-megatron-data-process_microbatch) | Process a microbatch for Megatron model forward pass. | + +### API + + + + + +```python +class nemo_rl.models.megatron.data.ProcessedInputs( + input_ids: torch.Tensor, + input_ids_cp_sharded: torch.Tensor, + attention_mask: typing.Optional[torch.Tensor], + position_ids: typing.Optional[torch.Tensor], + packed_seq_params: typing.Optional[megatron.core.packed_seq_params.PackedSeqParams], + cu_seqlens_padded: typing.Optional[torch.Tensor] +) +``` + + + + + + +Dataclass + +Processed microbatch inputs used for model forward pass. + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.megatron.data.ProcessedMicrobatch( + data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + input_ids: torch.Tensor, + input_ids_cp_sharded: torch.Tensor, + attention_mask: typing.Optional[torch.Tensor], + position_ids: typing.Optional[torch.Tensor], + packed_seq_params: typing.Optional[megatron.core.packed_seq_params.PackedSeqParams], + cu_seqlens_padded: typing.Optional[torch.Tensor] +) +``` + + + + + + +Dataclass + +Container for a processed microbatch ready for model forward pass. + +This dataclass holds both the original data dictionary and the processed +tensors needed for the Megatron model forward pass. + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.models.megatron.data._get_pack_sequence_parameters_for_megatron( + megatron_cfg: dict, + max_seq_len_in_batch: int +) +``` + + + + + + +Get pack sequence parameters for Megatron model processing with optional context parallelism. + +**Parameters:** + + +Megatron configuration + + + +Maximum sequence length in batch + + +**Returns:** + +Tuple of: + + + + + + + + +```python +nemo_rl.models.megatron.data._pack_sequences_for_megatron( + input_ids: torch.Tensor, + seq_lengths: torch.Tensor, + pad_individual_seqs_to_multiple_of: int = 1, + pad_packed_seq_to_multiple_of: int = 1, + pad_packed_seq_to: typing.Optional[int] = None, + cp_rank: int = 0, + cp_size: int = 1 +) -> tuple[torch.Tensor, megatron.core.packed_seq_params.PackedSeqParams, torch.Tensor, typing.Optional[torch.Tensor]] +``` + + + + + + +Pack sequences for Megatron model processing with optional context parallelism. + +**Parameters:** + + +Input token IDs [batch_size, seq_length] + + + +Actual sequence lengths for each sample [batch_size] + + + +Pad individual sequences to a multiple of this value + + + +Pad packed sequences to a multiple of this value + + + +Pad packed sequences to this value (before CP) +- The three parameters above can be calculated using _get_pack_sequence_parameters_for_megatron, we do not recommend users to set these parameters manually. + + + +Context parallelism size + + +**Returns:** `torch.Tensor` + +Tuple of: + + + + + + + + +```python +nemo_rl.models.megatron.data._unpack_sequences_from_megatron( + output_tensor: torch.Tensor, + seq_lengths: torch.Tensor, + cu_seqlens: torch.Tensor, + cu_seqlens_padded: typing.Optional[torch.Tensor], + original_batch_size: int, + original_seq_length: int +) -> torch.Tensor +``` + + + + + + +Unpack sequences from Megatron output format. + +**Parameters:** + + +Packed output tensor [1, T, vocab_size] + + + +Actual sequence lengths for each sample + + + +Cumulative sequence lengths + + + +Padded cumulative sequence lengths (if CP was used) + + + +Original batch size + + + +Original maximum sequence length + + +**Returns:** `torch.Tensor` + +Unpacked output tensor [batch_size, seq_length, vocab_size] + + + + + + + + +```python +nemo_rl.models.megatron.data.get_and_validate_seqlen( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] +) +``` + + + + + + + + + + + + + +```python +nemo_rl.models.megatron.data.get_microbatch_iterator( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + cfg: dict[str, typing.Any], + mbs: int, + straggler_timer: megatron.core.utils.StragglerDetector, + seq_length_key: typing.Optional[str] = None +) -> typing.Tuple[typing.Iterator[nemo_rl.models.megatron.data.ProcessedMicrobatch], int, int, int, int] +``` + + + + + + +Create a processed microbatch iterator from a batch of data. + +This function creates an iterator that yields ProcessedMicrobatch objects, +which contain both the original data dictionary and the processed tensors +ready for model forward pass. + +**Parameters:** + + +The batch data to create microbatches from + + + +Configuration dictionary + + + +Microbatch size + + + +Key for sequence lengths in data dict (auto-detected if None) + + +**Returns:** `Iterator[ProcessedMicrobatch]` + +Tuple containing the iterator and metadata + + + + + + + + +```python +nemo_rl.models.megatron.data.make_processed_microbatch_iterator( + raw_iterator: typing.Iterator[nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any]], + cfg: dict[str, typing.Any], + seq_length_key: typing.Optional[str], + pad_individual_seqs_to_multiple_of: int, + pad_packed_seq_to_multiple_of: int, + straggler_timer: megatron.core.utils.StragglerDetector, + pad_full_seq_to: typing.Optional[int] +) -> typing.Iterator[nemo_rl.models.megatron.data.ProcessedMicrobatch] +``` + + + + + + +Wrap a raw microbatch iterator to yield processed microbatches. + +This function takes a raw iterator that yields BatchedDataDict objects and +wraps it to yield ProcessedMicrobatch objects that contain both the original +data and the processed tensors ready for model forward pass. + +**Parameters:** + + +Iterator yielding raw BatchedDataDict microbatches + + + +Configuration dictionary containing sequence_packing settings + + + +Key for sequence length in data dict (required for packing) + + + +Padding multiple for individual sequences + + + +Padding multiple for packed sequences + + + +Target length for full sequence padding (optional) + + + + + + + + + +```python +nemo_rl.models.megatron.data.process_global_batch( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + dp_group: torch.distributed.ProcessGroup, + batch_idx: int, + batch_size: int +) -> dict[str, typing.Any] +``` + + + + + + +Process a global batch and compute normalization factors. + +**Parameters:** + + +Full dataset to extract a batch from + + + +Loss function (used to check loss type for token-level validation) + + + +Data parallel process group for all-reduce + + + +Index of batch to extract + + + +Size of batch to extract + + +**Returns:** `dict[str, Any]` + +Dictionary containing: + + + + + + + + +```python +nemo_rl.models.megatron.data.process_microbatch( + data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + seq_length_key: typing.Optional[str] = None, + pad_individual_seqs_to_multiple_of: int = 1, + pad_packed_seq_to_multiple_of: int = 1, + pad_full_seq_to: typing.Optional[int] = None, + pack_sequences: bool = False, + straggler_timer: typing.Optional[megatron.core.utils.StragglerDetector] = None +) -> tuple[torch.Tensor, torch.Tensor, typing.Optional[torch.Tensor], typing.Optional[torch.Tensor], typing.Optional[megatron.core.packed_seq_params.PackedSeqParams], typing.Optional[torch.Tensor]] +``` + + + + + + +Process a microbatch for Megatron model forward pass. + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/pipeline_parallel.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/pipeline_parallel.mdx new file mode 100644 index 0000000..501960a --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/pipeline_parallel.mdx @@ -0,0 +1,124 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/megatron/pipeline_parallel +title: nemo_rl.models.megatron.pipeline_parallel +--- + +Pipeline parallel utilities for Megatron models. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`broadcast_loss_metrics_from_last_stage`](#nemo_rl-models-megatron-pipeline_parallel-broadcast_loss_metrics_from_last_stage) | Broadcast loss metrics from the last pipeline stage to all stages. | +| [`broadcast_obj_from_pp_rank`](#nemo_rl-models-megatron-pipeline_parallel-broadcast_obj_from_pp_rank) | Broadcast an object across pipeline parallel ranks. | +| [`broadcast_tensors_from_last_stage`](#nemo_rl-models-megatron-pipeline_parallel-broadcast_tensors_from_last_stage) | Broadcast multiple tensors from the last pipeline stage to all stages. | + +### API + + + + + +```python +nemo_rl.models.megatron.pipeline_parallel.broadcast_loss_metrics_from_last_stage( + loss_metrics: typing.Optional[list] = None +) -> list +``` + + + + + + +Broadcast loss metrics from the last pipeline stage to all stages. + +This utility handles the common pattern where loss computation happens on the last +pipeline stage and needs to be broadcast to all other stages. + +**Parameters:** + + +List of loss metrics if on last stage, None otherwise + + +**Returns:** `list` + +List of loss metrics on all ranks + + + + + + + + +```python +nemo_rl.models.megatron.pipeline_parallel.broadcast_obj_from_pp_rank( + obj: typing.Any +) -> typing.Any +``` + + + + + + +Broadcast an object across pipeline parallel ranks. + +This utility function handles broadcasting an object from the rank that owns it +to all other pipeline parallel ranks. If only one rank has the object (non-None), +it will be broadcast to all other ranks. + +**Parameters:** + + +The object to broadcast. Can be None on ranks that don't own it. + + +**Returns:** `Any` + +The object on all ranks (either the original or the broadcast copy). + +**Raises:** + +- `ValueError`: If the object doesn't exist on any pipeline parallel rank. + + + + + + + + +```python +nemo_rl.models.megatron.pipeline_parallel.broadcast_tensors_from_last_stage( + tensors: dict[str, typing.Optional[torch.Tensor]] +) -> dict[str, torch.Tensor] +``` + + + + + + +Broadcast multiple tensors from the last pipeline stage to all stages. + +**Parameters:** + + +Dictionary mapping tensor names to tensors (None on non-last stages) + + + +Pipeline parallel group (auto-detected if None) + + +**Returns:** `dict[str, torch.Tensor]` + +Dictionary of broadcasted tensors on all ranks + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/setup.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/setup.mdx new file mode 100644 index 0000000..485915a --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/setup.mdx @@ -0,0 +1,535 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/megatron/setup +title: nemo_rl.models.megatron.setup +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MoEFloat16Module`](#nemo_rl-models-megatron-setup-MoEFloat16Module) | Float 16 Module with the ability to keep the expert bias in float32. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_apply_moe_config`](#nemo_rl-models-megatron-setup-_apply_moe_config) | Apply Mixture of Experts configuration. | +| [`_apply_parallelism_config`](#nemo_rl-models-megatron-setup-_apply_parallelism_config) | Apply tensor/pipeline/context parallelism configuration. | +| [`_apply_performance_config`](#nemo_rl-models-megatron-setup-_apply_performance_config) | Apply performance optimization configuration. | +| [`_apply_precision_config`](#nemo_rl-models-megatron-setup-_apply_precision_config) | Apply precision and dtype configuration. | +| [`_create_checkpoint_config`](#nemo_rl-models-megatron-setup-_create_checkpoint_config) | Create checkpoint configurations. | +| [`_create_megatron_config`](#nemo_rl-models-megatron-setup-_create_megatron_config) | Create the final Megatron configuration container. | +| [`_validate_chunking_config`](#nemo_rl-models-megatron-setup-_validate_chunking_config) | Validate chunking configuration. | +| [`_validate_dtype_config`](#nemo_rl-models-megatron-setup-_validate_dtype_config) | - | +| [`_validate_optimizer_config`](#nemo_rl-models-megatron-setup-_validate_optimizer_config) | Validate optimizer configuration. | +| [`_validate_training_config`](#nemo_rl-models-megatron-setup-_validate_training_config) | Validate training configuration. | +| [`destroy_parallel_state`](#nemo_rl-models-megatron-setup-destroy_parallel_state) | Safely destroy parallel state and reset async call tracking. | +| [`finalize_megatron_setup`](#nemo_rl-models-megatron-setup-finalize_megatron_setup) | Finalize the setup with remaining configurations. | +| [`handle_model_import`](#nemo_rl-models-megatron-setup-handle_model_import) | Handle HF model import if checkpoint doesn't exist. | +| [`setup_distributed`](#nemo_rl-models-megatron-setup-setup_distributed) | Handle NCCL settings, dtype mapping, and basic config setup. | +| [`setup_model_and_optimizer`](#nemo_rl-models-megatron-setup-setup_model_and_optimizer) | - | +| [`setup_model_config`](#nemo_rl-models-megatron-setup-setup_model_config) | Handle all the model configuration logic. | +| [`setup_reference_model_state`](#nemo_rl-models-megatron-setup-setup_reference_model_state) | Setup the reference model for inference and return its state dict. | +| [`validate_and_set_config`](#nemo_rl-models-megatron-setup-validate_and_set_config) | - | +| [`validate_model_paths`](#nemo_rl-models-megatron-setup-validate_model_paths) | Validate and setup model paths. | + +### Data + +[`HAVE_FSDP2`](#nemo_rl-models-megatron-setup-HAVE_FSDP2) + +[`TokenizerType`](#nemo_rl-models-megatron-setup-TokenizerType) + +### API + + + + + +```python +class nemo_rl.models.megatron.setup.MoEFloat16Module( + config: megatron.core.transformer.transformer_config.TransformerConfig, + module: torch.nn.Module +) +``` + + + + + + +**Bases:** `Float16Module` + +Float 16 Module with the ability to keep the expert bias in float32. + +**Parameters:** + + +The transformer config used to initalize the model + + + + + + + +```python +nemo_rl.models.megatron.setup.MoEFloat16Module.re_enable_float32_expert_bias() -> None +``` + + + + + + +Ensure MoE router expert bias stays in float32 for numerical stability. + +Walks the wrapped module to find MoE routers and invokes the +`_maintain_float32_expert_bias()` helper which recreates or casts the +expert bias tensors to float32 as required by Megatron-LM. + + + + + + + + + +```python +nemo_rl.models.megatron.setup._apply_moe_config( + model_cfg: typing.Any, + config: nemo_rl.models.policy.PolicyConfig +) -> None +``` + + + + + + +Apply Mixture of Experts configuration. + + + + + + + + +```python +nemo_rl.models.megatron.setup._apply_parallelism_config( + model_cfg: typing.Any, + config: nemo_rl.models.policy.PolicyConfig +) -> None +``` + + + + + + +Apply tensor/pipeline/context parallelism configuration. + + + + + + + + +```python +nemo_rl.models.megatron.setup._apply_performance_config( + model_cfg: typing.Any, + config: nemo_rl.models.policy.PolicyConfig +) -> None +``` + + + + + + +Apply performance optimization configuration. + + + + + + + + +```python +nemo_rl.models.megatron.setup._apply_precision_config( + model_cfg: typing.Any, + config: nemo_rl.models.policy.PolicyConfig, + dtype: torch.dtype +) -> None +``` + + + + + + +Apply precision and dtype configuration. + + + + + + + + +```python +nemo_rl.models.megatron.setup._create_checkpoint_config( + pretrained_path: str, + weights_path: typing.Optional[str] +) -> megatron.bridge.training.config.CheckpointConfig +``` + + + + + + +Create checkpoint configurations. + + + + + + + + +```python +nemo_rl.models.megatron.setup._create_megatron_config( + model_cfg: typing.Any, + checkpoint_config: megatron.bridge.training.config.CheckpointConfig, + config: nemo_rl.models.policy.PolicyConfig, + hf_model_name: str, + dtype: torch.dtype +) -> megatron.bridge.training.config.ConfigContainer +``` + + + + + + +Create the final Megatron configuration container. + + + + + + + + +```python +nemo_rl.models.megatron.setup._validate_chunking_config( + config: nemo_rl.models.policy.PolicyConfig +) -> None +``` + + + + + + +Validate chunking configuration. + + + + + + + + +```python +nemo_rl.models.megatron.setup._validate_dtype_config( + dtype: torch.dtype, + model_cfg: typing.Any, + optimizer_cfg: typing.Any +) -> None +``` + + + + + + + + + + + + + +```python +nemo_rl.models.megatron.setup._validate_optimizer_config( + config: nemo_rl.models.policy.PolicyConfig +) -> None +``` + + + + + + +Validate optimizer configuration. + + + + + + + + +```python +nemo_rl.models.megatron.setup._validate_training_config( + config: nemo_rl.models.policy.PolicyConfig, + model_cfg: typing.Any +) -> None +``` + + + + + + +Validate training configuration. + + + + + + + + +```python +nemo_rl.models.megatron.setup.destroy_parallel_state() +``` + + + + + + +Safely destroy parallel state and reset async call tracking. + +This function is called during initialization to clean up temporary distributed +state from model import operations. Resetting async call tracking ensures that +when the main Megatron distributed context is created, all ranks start with +consistent call_idx values for async checkpointing. + + + + + + + + +```python +nemo_rl.models.megatron.setup.finalize_megatron_setup( + config: nemo_rl.models.policy.PolicyConfig, + megatron_cfg: megatron.bridge.training.config.ConfigContainer, + hf_model_name: str, + worker_sharding_annotations: nemo_rl.distributed.named_sharding.NamedSharding, + model, + optimizer +) -> tuple +``` + + + + + + +Finalize the setup with remaining configurations. + +**Returns:** `tuple` + +Tuple of (megatron_tokenizer, megatron_bridge, should_disable_forward_pre_hook, dp_size) + + + + + + + + +```python +nemo_rl.models.megatron.setup.handle_model_import( + config: nemo_rl.models.policy.PolicyConfig, + hf_model_name: str, + pretrained_path: str, + pt_checkpoint_exists: bool +) -> None +``` + + + + + + +Handle HF model import if checkpoint doesn't exist. + + + + + + + + +```python +nemo_rl.models.megatron.setup.setup_distributed() -> None +``` + + + + + + +Handle NCCL settings, dtype mapping, and basic config setup. + + + + + + + + +```python +nemo_rl.models.megatron.setup.setup_model_and_optimizer( + policy_cfg: nemo_rl.models.policy.PolicyConfig, + megatron_cfg: megatron.bridge.training.config.ConfigContainer, + load_optimizer: bool = True, + get_embedding_ranks = None, + get_position_embedding_ranks = None +) +``` + + + + + + + + + + + + + +```python +nemo_rl.models.megatron.setup.setup_model_config( + config: nemo_rl.models.policy.PolicyConfig, + rank, + dtype, + hf_model_name: str, + pretrained_path: str, + weights_path: typing.Optional[str] = None +) -> tuple[megatron.bridge.training.config.ConfigContainer, typing.Any] +``` + + + + + + +Handle all the model configuration logic. + + + + + + + + +```python +nemo_rl.models.megatron.setup.setup_reference_model_state( + config: nemo_rl.models.policy.PolicyConfig, + megatron_cfg: megatron.bridge.training.config.ConfigContainer, + pretrained_path: str +) -> dict +``` + + + + + + +Setup the reference model for inference and return its state dict. + + + + + + + + +```python +nemo_rl.models.megatron.setup.validate_and_set_config( + config, + rank, + hf_model_name, + pretrained_path, + weights_path, + tokenizer +) +``` + + + + + + + + + + + + + +```python +nemo_rl.models.megatron.setup.validate_model_paths( + config: nemo_rl.models.policy.PolicyConfig +) -> tuple[str, str, bool] +``` + + + + + + +Validate and setup model paths. + + + + + + + + +```python +nemo_rl.models.megatron.setup.HAVE_FSDP2 = True +``` + + + + + + + + + +```python +nemo_rl.models.megatron.setup.TokenizerType = TypeVar('TokenizerType', bound=PreTrainedTokenizerBase) +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/train.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/train.mdx new file mode 100644 index 0000000..e49ba20 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/megatron/train.mdx @@ -0,0 +1,538 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/megatron/train +title: nemo_rl.models.megatron.train +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`LogprobsPostProcessor`](#nemo_rl-models-megatron-train-LogprobsPostProcessor) | - | +| [`LossPostProcessor`](#nemo_rl-models-megatron-train-LossPostProcessor) | - | +| [`TopkLogitsPostProcessor`](#nemo_rl-models-megatron-train-TopkLogitsPostProcessor) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`aggregate_training_statistics`](#nemo_rl-models-megatron-train-aggregate_training_statistics) | Aggregate training statistics across microbatches and data-parallel ranks. | +| [`apply_temperature_scaling`](#nemo_rl-models-megatron-train-apply_temperature_scaling) | Apply temperature scaling to logits. | +| [`forward_with_post_processing_fn`](#nemo_rl-models-megatron-train-forward_with_post_processing_fn) | Perform forward pass with pre-processed microbatch and return output tensor and post-processing function. | +| [`megatron_forward_backward`](#nemo_rl-models-megatron-train-megatron_forward_backward) | Execute forward and backward passes using Megatron's utilities. | +| [`model_forward`](#nemo_rl-models-megatron-train-model_forward) | Perform a single forward pass through the model. | + +### Data + +[`PostProcessingFunction`](#nemo_rl-models-megatron-train-PostProcessingFunction) + +### API + + + + + +```python +class nemo_rl.models.megatron.train.LogprobsPostProcessor( + cfg: nemo_rl.models.policy.PolicyConfig +) +``` + + + + + + + + + + +```python +nemo_rl.models.megatron.train.LogprobsPostProcessor.__call__( + data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + input_ids: torch.Tensor, + cu_seqlens_padded: torch.Tensor +) -> typing.Callable[[torch.Tensor], typing.Tuple[torch.Tensor, typing.Dict[str, torch.Tensor]]] +``` + + + + + + +Create a post-processing function that computes token log probabilities. + +This function returns a processor that takes model logits and converts them +to token-level log probabilities, handling both packed and unpacked sequences. + +**Parameters:** + + +Batched data dictionary containing input sequences + + + +Processed input token IDs + + + +Cumulative sequence lengths for packed sequences + + +**Returns:** `Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]` + +Function that takes output tensor and returns (dummy_loss, {"logprobs": token_logprobs}) + + + + + + + + + +```python +class nemo_rl.models.megatron.train.LossPostProcessor( + loss_fn: nemo_rl.algorithms.loss_functions.LossFunction, + cfg: nemo_rl.models.policy.PolicyConfig, + num_microbatches: int = 1, + cp_normalize: bool = True +) +``` + + + + + + + + + + +```python +nemo_rl.models.megatron.train.LossPostProcessor.__call__( + data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + packed_seq_params: typing.Optional[megatron.core.packed_seq_params.PackedSeqParams] = None, + global_valid_seqs: typing.Optional[torch.Tensor] = None, + global_valid_toks: typing.Optional[torch.Tensor] = None +) -> typing.Callable[[torch.Tensor], typing.Tuple[torch.Tensor, typing.Dict[str, typing.Any]]] +``` + + + + + + +Create a loss post-processing function for training. + +This function wraps a loss function with the necessary context and parameters +to compute loss and metrics from model outputs. It handles sequence packing +and context parallelism normalization. + +**Parameters:** + + +Batched data dictionary for the current microbatch + + + +Parameters for packed sequences (optional) + + + +Global valid sequence count for loss normalization + + + +Global valid token count for loss normalization + + +**Returns:** `Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, Any]]]` + +Function that takes output tensor and returns (loss, metrics) tuple + + + + + + + + + +```python +class nemo_rl.models.megatron.train.TopkLogitsPostProcessor( + cfg: nemo_rl.models.policy.PolicyConfig, + k: int +) +``` + + + + + + + + + + +```python +nemo_rl.models.megatron.train.TopkLogitsPostProcessor.__call__( + data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + cu_seqlens_padded: torch.Tensor +) -> typing.Callable[[torch.Tensor], typing.Tuple[torch.Tensor, typing.Dict[str, torch.Tensor]]] +``` + + + + + + +Create a post-processing function that computes top-k logits and indices. + +This function returns a processor that extracts the top-k highest logits +and their corresponding vocabulary indices from model outputs. It handles +tensor parallelism, context parallelism, and sequence packing. + +**Parameters:** + + +Batched data dictionary + + + +Cumulative sequence lengths for packed sequences + + +**Returns:** `Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]` + +Function that takes output tensor and returns + (dummy_loss, {"topk_logits": values, "topk_indices": indices}) + + + + + + + + + +```python +nemo_rl.models.megatron.train.aggregate_training_statistics( + all_mb_metrics: typing.List[typing.Dict[str, typing.Any]], + losses: typing.List[float], + data_parallel_group: torch.distributed.ProcessGroup +) -> typing.Tuple[typing.Dict[str, typing.List[typing.Any]], torch.Tensor] +``` + + + + + + +Aggregate training statistics across microbatches and data-parallel ranks. + +Computes a global loss by all-reducing per-gradient-buffer losses across the +data-parallel group, then collects per-microbatch metrics into lists keyed by +metric name. + +**Parameters:** + + +List of metric dicts from each microbatch. + + + +List of per-gradient-buffer scalar losses on this rank. + + + +The data-parallel process group for all-reduce. + + +**Returns:** `Tuple[Dict[str, List[Any]], torch.Tensor]` + +Tuple of: +- mb_metrics: Dict mapping metric names to lists of values across microbatches. +- global_loss: Tensor of losses summed across all data-parallel ranks. + + + + + + + + +```python +nemo_rl.models.megatron.train.apply_temperature_scaling( + logits: torch.Tensor, + cfg: nemo_rl.models.policy.PolicyConfig +) -> torch.Tensor +``` + + + + + + +Apply temperature scaling to logits. + +**Parameters:** + + +Logits tensor to scale + + + +Policy configuration containing generation settings + + +**Returns:** `torch.Tensor` + +torch.Tensor: Temperature-scaled logits + + + + + + + + +```python +nemo_rl.models.megatron.train.forward_with_post_processing_fn( + data_iterator: typing.Iterator[nemo_rl.models.megatron.data.ProcessedMicrobatch], + model: megatron.core.models.gpt.GPTModel, + cfg: nemo_rl.models.policy.PolicyConfig, + post_processing_fn: nemo_rl.models.megatron.train.PostProcessingFunction, + defer_fp32_logits: typing.Optional[bool] = False, + global_valid_seqs: typing.Optional[torch.Tensor] = None, + global_valid_toks: typing.Optional[torch.Tensor] = None, + straggler_timer: typing.Optional[megatron.core.utils.StragglerDetector] = None +) -> typing.Tuple[torch.Tensor, typing.Callable] +``` + + + + + + +Perform forward pass with pre-processed microbatch and return output tensor and post-processing function. + +This function takes a pre-processed microbatch (with sequence packing already handled), +runs the forward step through the model, and prepares a post-processing function for +post-processing the outputs. + +**Parameters:** + + +Iterator yielding ProcessedMicrobatch objects (already processed) + + + +The model to run forward pass on + + + +Policy configuration dictionary + + + +Post-processing function to post-process the logits + + + +Whether to defer FP32 conversion of logits + + + +Global valid sequence count for loss normalization + + + +Global valid token count for loss normalization + + + +Straggler detector for profiling the forward pass + + +**Returns:** `Tuple[torch.Tensor, Callable]` + +(output_tensor, post_processing_fn_wrapped) +- output_tensor: Raw model outputs (logits) +- post_processing_fn_wrapped: Function to create output post-processing function when called + + + + + + + + +```python +nemo_rl.models.megatron.train.megatron_forward_backward( + model: megatron.core.models.gpt.GPTModel, + cfg: nemo_rl.models.policy.PolicyConfig, + data_iterator: typing.Iterator[nemo_rl.models.megatron.data.ProcessedMicrobatch], + num_microbatches: int, + seq_length: int, + mbs: int, + post_processing_fn: nemo_rl.models.megatron.train.PostProcessingFunction, + forward_only: bool = False, + defer_fp32_logits: typing.Optional[bool] = False, + global_valid_seqs: typing.Optional[torch.Tensor] = None, + global_valid_toks: typing.Optional[torch.Tensor] = None, + straggler_timer: typing.Optional[megatron.core.utils.StragglerDetector] = None +) -> typing.Any +``` + + + + + + +Execute forward and backward passes using Megatron's utilities. + +This is the main training loop function that coordinates forward and backward +passes across multiple microbatches using Megatron's pipeline parallel +execution framework. + +**Parameters:** + + +The model to train + + + +Policy configuration dictionary + + + +Iterator yielding ProcessedMicrobatch objects (already processed) + + + +Number of microbatches to process + + + +Sequence length + + + +Micro batch size + + + +Post-processing function to post-process the logits + + + +If True, skip backward pass + + + +Whether to skip the conversion of logits to fp32 + + + +Global valid sequence count for loss normalization + + + +Global valid token count for loss normalization + + + +Straggler detector for profiling the forward pass + + +**Returns:** `Any` + +Results from the forward/backward execution + + + + + + + + +```python +nemo_rl.models.megatron.train.model_forward( + model: megatron.core.models.gpt.GPTModel, + data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + cfg: nemo_rl.models.policy.PolicyConfig, + input_ids_cp_sharded: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + packed_seq_params: typing.Optional[megatron.core.packed_seq_params.PackedSeqParams] = None, + defer_fp32_logits: typing.Optional[bool] = False, + straggler_timer: typing.Optional[megatron.core.utils.StragglerDetector] = None +) -> torch.Tensor +``` + + + + + + +Perform a single forward pass through the model. + +**Parameters:** + + +The model to run forward pass on + + + +Dictionary containing batch data + + + +Policy configuration dictionary + + + +Context-parallel sharded input token IDs + + + +Position IDs for tokens + + + +Attention mask for the sequence + + + +Parameters for packed sequences (optional) + + + +Whether to skip the conversion of logits to fp32 + + + +Straggler detector for profiling the forward pass + + +**Returns:** `torch.Tensor` + +torch.Tensor: Output tensor from the model (logits) + + + + + + + + +```python +nemo_rl.models.megatron.train.PostProcessingFunction = Union['LossPostProcessor', 'LogprobsPostProcessor', 'TopkLogitsPostProcessor'] +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy.mdx new file mode 100644 index 0000000..1b3ebde --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy.mdx @@ -0,0 +1,948 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/policy +title: nemo_rl.models.policy +--- + +## Subpackages + +- **[`nemo_rl.models.policy.workers`](/nemo-rl/nemo_rl/models/policy/workers)** + +## Submodules + +- **[`nemo_rl.models.policy.interfaces`](/nemo-rl/nemo_rl/models/policy/interfaces)** +- **[`nemo_rl.models.policy.lm_policy`](/nemo-rl/nemo_rl/models/policy/lm_policy)** +- **[`nemo_rl.models.policy.utils`](/nemo-rl/nemo_rl/models/policy/utils)** + +## Package Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AutomodelBackendConfig`](#nemo_rl-models-policy-AutomodelBackendConfig) | Configuration for custom MoE implementation backend in Automodel. | +| [`AutomodelKwargs`](#nemo_rl-models-policy-AutomodelKwargs) | - | +| [`DTensorConfig`](#nemo_rl-models-policy-DTensorConfig) | - | +| [`DTensorConfigDisabled`](#nemo_rl-models-policy-DTensorConfigDisabled) | - | +| [`DynamicBatchingConfig`](#nemo_rl-models-policy-DynamicBatchingConfig) | - | +| [`DynamicBatchingConfigDisabled`](#nemo_rl-models-policy-DynamicBatchingConfigDisabled) | - | +| [`LoRAConfig`](#nemo_rl-models-policy-LoRAConfig) | - | +| [`LoRAConfigDisabled`](#nemo_rl-models-policy-LoRAConfigDisabled) | - | +| [`MegatronConfig`](#nemo_rl-models-policy-MegatronConfig) | - | +| [`MegatronConfigDisabled`](#nemo_rl-models-policy-MegatronConfigDisabled) | - | +| [`MegatronDDPConfig`](#nemo_rl-models-policy-MegatronDDPConfig) | - | +| [`MegatronOptimizerConfig`](#nemo_rl-models-policy-MegatronOptimizerConfig) | - | +| [`MegatronSchedulerConfig`](#nemo_rl-models-policy-MegatronSchedulerConfig) | - | +| [`PolicyConfig`](#nemo_rl-models-policy-PolicyConfig) | - | +| [`PytorchOptimizerConfig`](#nemo_rl-models-policy-PytorchOptimizerConfig) | - | +| [`RewardModelConfig`](#nemo_rl-models-policy-RewardModelConfig) | - | +| [`SequencePackingConfig`](#nemo_rl-models-policy-SequencePackingConfig) | - | +| [`SequencePackingConfigDisabled`](#nemo_rl-models-policy-SequencePackingConfigDisabled) | - | +| [`SinglePytorchMilestonesConfig`](#nemo_rl-models-policy-SinglePytorchMilestonesConfig) | - | +| [`SinglePytorchSchedulerConfig`](#nemo_rl-models-policy-SinglePytorchSchedulerConfig) | - | +| [`TokenizerConfig`](#nemo_rl-models-policy-TokenizerConfig) | - | + +### Data + +[`SchedulerMilestones`](#nemo_rl-models-policy-SchedulerMilestones) + +### API + + + + + +```python +class nemo_rl.models.policy.AutomodelBackendConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Configuration for custom MoE implementation backend in Automodel. + +Used when setting the backend in automodel_kwargs in your config. +Alternatively, pass `force_hf: true` in automodel_kwargs to fall back +to the HuggingFace implementation. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.AutomodelKwargs +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.DTensorConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.DTensorConfigDisabled +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +class nemo_rl.models.policy.DynamicBatchingConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.DynamicBatchingConfigDisabled +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +class nemo_rl.models.policy.LoRAConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.LoRAConfigDisabled +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +class nemo_rl.models.policy.MegatronConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.MegatronConfigDisabled +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +class nemo_rl.models.policy.MegatronDDPConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.MegatronOptimizerConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.MegatronSchedulerConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.PolicyConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.PytorchOptimizerConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.RewardModelConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.SequencePackingConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.SequencePackingConfigDisabled +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +class nemo_rl.models.policy.SinglePytorchMilestonesConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +class nemo_rl.models.policy.SinglePytorchSchedulerConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.models.policy.TokenizerConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.models.policy.SchedulerMilestones = dict[str, list[int]] +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/interfaces.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/interfaces.mdx new file mode 100644 index 0000000..8cbb649 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/interfaces.mdx @@ -0,0 +1,574 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/policy/interfaces +title: nemo_rl.models.policy.interfaces +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ColocatablePolicyInterface`](#nemo_rl-models-policy-interfaces-ColocatablePolicyInterface) | - | +| [`LogprobOutputSpec`](#nemo_rl-models-policy-interfaces-LogprobOutputSpec) | logprobs: Tensor of log probabilities. | +| [`PolicyInterface`](#nemo_rl-models-policy-interfaces-PolicyInterface) | Abstract base class defining the interface for RL policies. | +| [`ReferenceLogprobOutputSpec`](#nemo_rl-models-policy-interfaces-ReferenceLogprobOutputSpec) | logprobs: Tensor of log probabilities. | +| [`ScoreOutputSpec`](#nemo_rl-models-policy-interfaces-ScoreOutputSpec) | scores: Tensor of scores. | +| [`TopkLogitsOutputSpec`](#nemo_rl-models-policy-interfaces-TopkLogitsOutputSpec) | Per-position top-k logits and corresponding global token indices. | + +### API + + + + + +```python +class nemo_rl.models.policy.interfaces.ColocatablePolicyInterface() +``` + + + + + + +**Bases:** [PolicyInterface](#nemo_rl-models-policy-interfaces-PolicyInterface) + + + + + +```python +nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.broadcast_weights_for_collective( + kv_scales: typing.Optional[dict[str, float]] = None +) -> list[ray.ObjectRef] +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.init_collective( + ip: str, + port: int, + world_size: int, + train_world_size: int +) -> list[ray.ObjectRef] +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.offload_after_refit() -> None +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.offload_before_refit() -> None +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.prepare_for_lp_inference() -> None +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.prepare_refit_info() -> typing.Optional[dict[str, typing.Any]] +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.stream_weights_via_http( + sglang_url_to_gpu_uuids: dict[str, list[str]] +) -> list[ray.ObjectRef] +``` + + + + + + +Stream model weights to SGLang servers via HTTP API. + +**Parameters:** + + +Dict mapping SGLang server URL to list of GPU UUIDs it uses + + + + + + + + +```python +nemo_rl.models.policy.interfaces.ColocatablePolicyInterface.stream_weights_via_ipc_zmq( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> list[ray.ObjectRef] +``` + + + + + + +abstract + + + + + + + + + +```python +class nemo_rl.models.policy.interfaces.LogprobOutputSpec +``` + + + + + + +**Bases:** `typing.TypedDict` + +logprobs: Tensor of log probabilities. + + + + + + + + + + + +```python +class nemo_rl.models.policy.interfaces.PolicyInterface() +``` + + + + + + +Abstract + +Abstract base class defining the interface for RL policies. + + + + + + +```python +nemo_rl.models.policy.interfaces.PolicyInterface.calibrate_qkv_fp8_scales( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + micro_batch_size: typing.Optional[int] = None, + percentile: float = 99.9, + margin: float = 1.05, + include_q: bool = False +) -> dict[str, typing.Any] +``` + + + + + + +abstract + +Calibrate FP8 scales for Q/K/V activations used by KV cache. + +**Parameters:** + + +BatchedDataDict containing input_ids and input_lengths. + + + +Optional override for micro batch size during calibration. + + + +Percentile for per-tensor amax estimation. + + + +Safety margin multiplier applied to amax. + + + +Whether to also compute scale for Q in addition to K/V. + + +**Returns:** `dict[str, Any]` + +Dict with overall configuration and per-layer scales. + + + + + + + +```python +nemo_rl.models.policy.interfaces.PolicyInterface.finish_training( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> None +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.policy.interfaces.PolicyInterface.get_logprobs( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + timer: typing.Optional[nemo_rl.utils.timer.Timer] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec] +``` + + + + + + +abstract + +Get logprobs of actions from observations. + +**Parameters:** + + +BatchedDataDict containing rollouts (tokens) + + +**Returns:** `BatchedDataDict[LogprobOutputSpec]` + +BatchedDataDict containing: +- logprobs: Tensor of logprobs of actions + + + + + + + +```python +nemo_rl.models.policy.interfaces.PolicyInterface.get_reference_policy_logprobs( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + micro_batch_size: typing.Optional[int] = None, + timer: typing.Optional[nemo_rl.utils.timer.Timer] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec] +``` + + + + + + +abstract + +Get logprobs of actions from observations. + +**Parameters:** + + +BatchedDataDict containing rollouts (tokens) + + +**Returns:** `BatchedDataDict[ReferenceLogprobOutputSpec]` + +BatchedDataDict containing: +- logprobs: Tensor of logprobs of actions + + + + + + + +```python +nemo_rl.models.policy.interfaces.PolicyInterface.get_topk_logits( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + k: int, + micro_batch_size: typing.Optional[int] = None, + timer: typing.Optional[nemo_rl.utils.timer.Timer] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.TopkLogitsOutputSpec] +``` + + + + + + +abstract + +Get per-position top-k logits and global indices for a batch of inputs. + + + + + + + +```python +nemo_rl.models.policy.interfaces.PolicyInterface.prepare_for_training( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> None +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.policy.interfaces.PolicyInterface.save_checkpoint( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> None +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.policy.interfaces.PolicyInterface.shutdown() -> bool +``` + + + + + + +abstract + + + + + + + +```python +nemo_rl.models.policy.interfaces.PolicyInterface.train( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict, + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + eval_mode: bool = False, + gbs: typing.Optional[int] = None, + mbs: typing.Optional[int] = None, + timer: typing.Optional[nemo_rl.utils.timer.Timer] = None +) -> dict[str, typing.Any] +``` + + + + + + +abstract + +Train the policy on a global batch of data. + +**Parameters:** + + +BatchedDataDict containing rollouts (tokens) + + + +Loss function to use for training + + + +Whether to run in evaluation mode (no gradient updates) + + + +Global batch size override (if None, uses config default) + + + +Micro batch size override (if None, uses config default) + + + + + + + + + + +```python +class nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec +``` + + + + + + +**Bases:** `typing.TypedDict` + +logprobs: Tensor of log probabilities. + + + + + + + + + + + +```python +class nemo_rl.models.policy.interfaces.ScoreOutputSpec +``` + + + + + + +**Bases:** `typing.TypedDict` + +scores: Tensor of scores. + + + + + + + + + + + +```python +class nemo_rl.models.policy.interfaces.TopkLogitsOutputSpec +``` + + + + + + +**Bases:** `typing.TypedDict` + +Per-position top-k logits and corresponding global token indices. + + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/lm_policy.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/lm_policy.mdx new file mode 100644 index 0000000..7636f66 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/lm_policy.mdx @@ -0,0 +1,609 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/policy/lm_policy +title: nemo_rl.models.policy.lm_policy +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`Policy`](#nemo_rl-models-policy-lm_policy-Policy) | - | + +### Data + +[`PathLike`](#nemo_rl-models-policy-lm_policy-PathLike) + +### API + + + + + +```python +class nemo_rl.models.policy.lm_policy.Policy( + cluster: nemo_rl.distributed.virtual_cluster.RayVirtualCluster, + config: nemo_rl.models.policy.PolicyConfig, + tokenizer: transformers.PreTrainedTokenizerBase, + name_prefix: str = 'lm_policy', + workers_per_node: typing.Optional[typing.Union[int, list[int]]] = None, + init_optimizer: bool = True, + weights_path: typing.Optional[nemo_rl.models.policy.lm_policy.PathLike] = None, + optimizer_path: typing.Optional[nemo_rl.models.policy.lm_policy.PathLike] = None, + init_reference_model: bool = True, + processor: typing.Optional[transformers.AutoProcessor] = None +) +``` + + + + + + +**Bases:** [ColocatablePolicyInterface](/nemo-rl/nemo_rl/models/policy/interfaces#nemo_rl-models-policy-interfaces-ColocatablePolicyInterface), [GenerationInterface](/nemo-rl/nemo_rl/models/generation/interfaces#nemo_rl-models-generation-interfaces-GenerationInterface) + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.__del__() -> None +``` + + + + + + +Shuts down the worker groups when the object is deleted or is garbage collected. + +This is an extra safety net in case the user forgets to call worker_group.shutdown() and the pointer to +the object is lost due to leaving a function scope. It's always recommended that the +user calls worker_group.shutdown(). + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.broadcast_weights_for_collective( + kv_scales: typing.Optional[dict[str, float]] = None +) -> list[ray.ObjectRef] +``` + + + + + + +Broadcast the weights for collective communication. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.calibrate_qkv_fp8_scales( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + micro_batch_size: typing.Optional[int] = None, + percentile: float = 99.9, + margin: float = 1.05, + include_q: bool = False +) -> dict[str, typing.Any] +``` + + + + + + +Trigger KV-cache FP8 scale calibration across Megatron workers and return results. + +Note: The backend `MegatronPolicyWorker.calibrate_qkv_fp8_scales` already implements +distributed reduction, returning results merged across ranks. Therefore, we shard the +input by DP and call in parallel, then take the result from the first worker. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.finish_generation( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> bool +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.finish_training( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.generate( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] +``` + + + + + + +Generate a batch of data using the policy. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.get_free_memory_bytes() -> int +``` + + + + + + +Get the available free memory. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.get_logprobs( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + timer: typing.Optional[nemo_rl.utils.timer.Timer] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec] +``` + + + + + + +Get the logprobs of the model for a data dict. + +**Returns:** `BatchedDataDict[LogprobOutputSpec]` + +a BatchedDataDict with key "logprobs" and shape [batch_size, sequence_length]. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.get_reference_policy_logprobs( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + micro_batch_size: typing.Optional[int] = None, + timer: typing.Optional[nemo_rl.utils.timer.Timer] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec] +``` + + + + + + +Get the logprobs of the reference policy for a data dict. + +Returns: Identical to get_logprobs. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.get_topk_logits( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + k: int, + micro_batch_size: typing.Optional[int] = None, + timer: typing.Optional[nemo_rl.utils.timer.Timer] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.TopkLogitsOutputSpec] +``` + + + + + + +Dispatch get_topk_logits to workers (no CP/packed support initially). + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.init_collective( + ip: str, + port: int, + world_size: int, + train_world_size: int +) -> list[ray.ObjectRef] +``` + + + + + + +Initialize the collective communication. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.invalidate_kv_cache( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> bool +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.offload_after_refit() -> None +``` + + + + + + +Offload the optimizer and buffers to the CPU. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.offload_before_refit() -> None +``` + + + + + + +Offload the optimizer and buffers to the CPU. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.prepare_for_generation( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> bool +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.prepare_for_lp_inference( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.prepare_for_training( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.prepare_refit_info() -> typing.Optional[dict[str, typing.Any]] +``` + + + + + + +Prepare the info for refit. + +**Returns:** `Optional[dict[str, Any]]` + +A dictionary containing the info for refit. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.print_node_ip_and_gpu_id() -> list[tuple[str, int]] +``` + + + + + + +Print the node IP and GPU ID of the current worker. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.save_checkpoint( + weights_path: str, + optimizer_path: typing.Optional[str] = None, + tokenizer_path: typing.Optional[str] = None, + checkpointing_cfg: typing.Optional[nemo_rl.utils.checkpoint.CheckpointingConfig] = None +) -> None +``` + + + + + + +Save a checkpoint of the model. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.score( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec] +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ScoreOutputSpec] +``` + + + + + + +Score a batch of data using the policy. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.shutdown() -> bool +``` + + + + + + +Shut down all HF workers and clean up resources. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.start_gpu_profiling() -> None +``` + + + + + + +Start GPU profiling. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.stop_gpu_profiling() -> None +``` + + + + + + +Stop GPU profiling. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.stream_weights_via_http( + sglang_url_to_gpu_uuids: dict[str, list[str]] +) -> list[ray.ObjectRef] +``` + + + + + + +Send the weights to SGLang servers via HTTP API. + +**Parameters:** + + +Dict mapping SGLang server URL to list of GPU UUIDs it uses + + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.stream_weights_via_ipc_zmq( + buffer_size_bytes: int, + kv_scales: typing.Optional[dict[str, float]] = None +) -> list[ray.ObjectRef] +``` + + + + + + +Send the weights for IPC handles via ZMQ socket. + + + + + + + +```python +nemo_rl.models.policy.lm_policy.Policy.train( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + eval_mode: bool = False, + gbs: typing.Optional[int] = None, + mbs: typing.Optional[int] = None, + timer: typing.Optional[nemo_rl.utils.timer.Timer] = None +) -> dict[str, typing.Any] +``` + + + + + + +Train the policy on a batch of data with a given loss function. + + + + + + + + + +```python +nemo_rl.models.policy.lm_policy.PathLike = Union[str, 'os.PathLike[Any]'] +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/utils.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/utils.mdx new file mode 100644 index 0000000..63b8675 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/utils.mdx @@ -0,0 +1,624 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/policy/utils +title: nemo_rl.models.policy.utils +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`IPCProtocol`](#nemo_rl-models-policy-utils-IPCProtocol) | IPC protocol constants for ZMQ weight streaming. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_gather_ipc_handlers`](#nemo_rl-models-policy-utils-_gather_ipc_handlers) | Gather IPC handlers from all ranks in the default FSDP group, then filter by server. | +| [`_send_tensor_to_sglang`](#nemo_rl-models-policy-utils-_send_tensor_to_sglang) | Send gathered IPC handlers to SGLang server via HTTP. | +| [`_setup_ipc_gather_group`](#nemo_rl-models-policy-utils-_setup_ipc_gather_group) | Setup gather configuration for IPC handlers. | +| [`apply_top_k_only`](#nemo_rl-models-policy-utils-apply_top_k_only) | Apply top-k mask to the logits. | +| [`apply_top_k_top_p`](#nemo_rl-models-policy-utils-apply_top_k_top_p) | Apply top-k and top-p masks to the logits. | +| [`calculate_aligned_size`](#nemo_rl-models-policy-utils-calculate_aligned_size) | Calculate aligned size for memory alignment. | +| [`configure_dynamo_cache`](#nemo_rl-models-policy-utils-configure_dynamo_cache) | Disable dynamo autotune_local_cache. | +| [`get_gpu_info`](#nemo_rl-models-policy-utils-get_gpu_info) | Return information about the GPU being used by this worker. | +| [`get_handle_from_tensor`](#nemo_rl-models-policy-utils-get_handle_from_tensor) | Get IPC handle from a tensor. | +| [`get_megatron_checkpoint_dir`](#nemo_rl-models-policy-utils-get_megatron_checkpoint_dir) | Gets the default megatron checkpoint directory for initial HF -> Mcore conversion. | +| [`get_runtime_env_for_policy_worker`](#nemo_rl-models-policy-utils-get_runtime_env_for_policy_worker) | Get runtime environment configuration for policy workers. | +| [`is_vllm_v1_engine_enabled`](#nemo_rl-models-policy-utils-is_vllm_v1_engine_enabled) | Check if vLLM V1 engine is enabled. | +| [`rebuild_cuda_tensor_from_ipc`](#nemo_rl-models-policy-utils-rebuild_cuda_tensor_from_ipc) | Rebuild a CUDA tensor from an IPC handle. | +| [`resolve_model_class`](#nemo_rl-models-policy-utils-resolve_model_class) | Resolve the appropriate model class for a given model name. | +| [`stream_weights_via_http_impl`](#nemo_rl-models-policy-utils-stream_weights_via_http_impl) | Stream weights to SGLang servers via HTTP API (update_weights_from_tensor). | +| [`stream_weights_via_ipc_zmq_impl`](#nemo_rl-models-policy-utils-stream_weights_via_ipc_zmq_impl) | Shared implementation for streaming weights via IPC ZMQ with improved memory management. | + +### Data + +[`AUTOMODEL_FACTORY`](#nemo_rl-models-policy-utils-AUTOMODEL_FACTORY) + +[`NEMO_AUTOMODEL_AVAILABLE`](#nemo_rl-models-policy-utils-NEMO_AUTOMODEL_AVAILABLE) + +### API + + + + + +```python +class nemo_rl.models.policy.utils.IPCProtocol +``` + + + + + + +**Bases:** `enum.Enum` + +IPC protocol constants for ZMQ weight streaming. + + + + + + + + + + + + + +```python +nemo_rl.models.policy.utils._gather_ipc_handlers( + serialized_handler: str, + gather_group: typing.Optional[torch.distributed.ProcessGroup], + gather_src: typing.Optional[int], + rank: int, + matching_ranks: typing.Optional[list[int]] = None +) -> typing.Optional[list[str]] +``` + + + + + + +Gather IPC handlers from all ranks in the default FSDP group, then filter by server. + +**Parameters:** + + +Serialized IPC handler from this rank + + + +Process group (None means use default FSDP group) + + + +Rank that will collect and filter handlers + + + +Current rank + + + +List of ranks that belong to the same SGLang server + + +**Returns:** `Optional[list[str]]` + +List of serialized handlers in rank order (only on gather_src rank), None otherwise + + + + + + + + +```python +nemo_rl.models.policy.utils._send_tensor_to_sglang( + url: str, + tensor_name: str, + gathered_handlers: list[str], + shape: torch.Size, + dtype: str, + flush_cache: bool = False +) -> None +``` + + + + + + +Send gathered IPC handlers to SGLang server via HTTP. + +Key: gathered_handlers are in rank order [rank0, rank1, ...] +SGLang will automatically match: handler = serialized_handlers[tp_rank] + +**Parameters:** + + +SGLang server URL + + + +Name of the tensor + + + +List of serialized IPC handlers in rank order + + + +Tensor shape + + + +Tensor dtype + + + +Whether to flush cache after this tensor (for last tensor) + + + + + + + + + +```python +nemo_rl.models.policy.utils._setup_ipc_gather_group( + rank: int, + current_device_uuid: str, + sglang_gpu_uuids: list[str], + sglang_url_to_gpu_uuids: dict[str, list[str]] +) -> tuple[typing.Optional[torch.distributed.ProcessGroup], typing.Optional[int], typing.Optional[list[int]]] +``` + + + + + + +Setup gather configuration for IPC handlers. + +**Returns:** `Optional[dist.ProcessGroup]` + +Tuple of (gather_group, gather_src_rank, matching_ranks) + + + + + + + + +```python +nemo_rl.models.policy.utils.apply_top_k_only( + logits: torch.Tensor, + top_k: int +) -> torch.Tensor +``` + + + + + + +Apply top-k mask to the logits. + +Simplified version of VLLM's implementation for scalar parameters. +This implementation doesn't involve sorting the entire vocab. + +Based on VLLM's implementation: +https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py +SPDX-License-Identifier: Apache-2.0 +Copyright contributors to the vLLM project + +**Parameters:** + + +Input logits tensor of shape [batch_size, seq_len, vocab_size] + + + +Top-k sampling parameter. + + +**Returns:** `torch.Tensor` + +Filtered logits with top-k applied + + + + + + + + +```python +nemo_rl.models.policy.utils.apply_top_k_top_p( + logits: torch.Tensor, + top_k: typing.Optional[int] = None, + top_p: typing.Optional[float] = None +) -> torch.Tensor +``` + + + + + + +Apply top-k and top-p masks to the logits. + +Simplified version of VLLM's implementation for scalar parameters. + +Based on VLLM's implementation: +https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py +SPDX-License-Identifier: Apache-2.0 +Copyright contributors to the vLLM project + +**Parameters:** + + +Input logits tensor of shape [batch_size, seq_len, vocab_size] + + + +Top-k sampling parameter. Set to -1 to consider all tokens. + + + +Top-p (nucleus) sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens. + + +**Returns:** `torch.Tensor` + +Filtered logits with sampling parameters applied + + + + + + + + +```python +nemo_rl.models.policy.utils.calculate_aligned_size( + size_bytes: int, + alignment: int = 512 +) -> int +``` + + + + + + +Calculate aligned size for memory alignment. + +**Parameters:** + + +Size in bytes to align + + + +Alignment boundary in bytes (default 512) + + +**Returns:** `int` + +Aligned size in bytes(int). + + + + + + + + +```python +nemo_rl.models.policy.utils.configure_dynamo_cache() -> None +``` + + + + + + +Disable dynamo autotune_local_cache. + +Dynamo may fail at cached_autotune when there's already a cache with different order of node_bundles. +Disable autotune_local_cache as a workaround. +See https://github.com/pytorch/pytorch/issues/153791 for more details. + + + + + + + + +```python +nemo_rl.models.policy.utils.get_gpu_info( + model: torch.nn.Module +) -> dict[str, typing.Any] +``` + + + + + + +Return information about the GPU being used by this worker. + + + + + + + + +```python +nemo_rl.models.policy.utils.get_handle_from_tensor( + tensor: torch.Tensor +) -> tuple[typing.Any] +``` + + + + + + +Get IPC handle from a tensor. + + + + + + + + +```python +nemo_rl.models.policy.utils.get_megatron_checkpoint_dir() -> str +``` + + + + + + +Gets the default megatron checkpoint directory for initial HF -> Mcore conversion. + +Megatron initial checkpoint should be saved to a path available on all nodes. The directory used will take this order of precendence: +1. $NRL_MEGATRON_CHECKPOINT_DIR (if set) +2. $HF_HOME/nemo_rl (if HF_HOME is set) +3. ~/.cache/huggingface/nemo_rl + +HF_HOME is preferred since many users will also have that path mounted and it means one less directory +to mount into your runtime environment. + + + + + + + + +```python +nemo_rl.models.policy.utils.get_runtime_env_for_policy_worker( + policy_worker_name: str +) -> dict[str, typing.Any] +``` + + + + + + +Get runtime environment configuration for policy workers. + +Note: expandable_segments configuration is handled directly in the worker init methods +to ensure proper GPU detection after CUDA initialization. + + + + + + + + +```python +nemo_rl.models.policy.utils.is_vllm_v1_engine_enabled() -> bool +``` + + + + + + +Check if vLLM V1 engine is enabled. + +**Returns:** `bool` + +True if V1 engine is enabled, False otherwise (defaults to True if not set) + + + + + + + + +```python +nemo_rl.models.policy.utils.rebuild_cuda_tensor_from_ipc( + cuda_ipc_handle: tuple, + device_id: int +) -> torch.Tensor +``` + + + + + + +Rebuild a CUDA tensor from an IPC handle. + + + + + + + + +```python +nemo_rl.models.policy.utils.resolve_model_class( + model_name: str +) -> typing.Any +``` + + + + + + +Resolve the appropriate model class for a given model name. + + + + + + + + +```python +nemo_rl.models.policy.utils.stream_weights_via_http_impl( + params_generator, + sglang_url_to_gpu_uuids: dict[str, list[str]], + rank: int, + worker_name: str, + current_device_uuid: str +) -> None +``` + + + + + + +Stream weights to SGLang servers via HTTP API (update_weights_from_tensor). + +Flow: Each rank creates IPC handler → gather handlers in rank order → send list → SGLang matches by tp_rank index + +Key points: +- Each rank creates handler on its own GPU +- Handlers are gathered in rank order: [rank0_handler, rank1_handler, ...] +- List index = rank = GPU ID +- SGLang automatically matches: handler = serialized_handlers[tp_rank] + +**Parameters:** + + +Generator yielding (name, tensor) pairs + + + +Dict mapping SGLang server URL to list of GPU UUIDs it uses + + + +Worker rank for logging + + + +Name of the worker for logging + + + +UUID of the current training worker's GPU + + + + + + + + + +```python +nemo_rl.models.policy.utils.stream_weights_via_ipc_zmq_impl( + params_generator, + buffer_size_bytes: int, + zmq_socket, + rank: int, + worker_name: str +) -> None +``` + + + + + + +Shared implementation for streaming weights via IPC ZMQ with improved memory management. + +Uses ping-pong double buffering to enable overlapping communication while reusing buffers +to reduce memory allocation overhead and improve stability. + +**Parameters:** + + +Generator yielding (name, tensor) pairs + + + +total size of buffer in bytes for batching parameters + + + +ZMQ socket for communication + + + +Worker rank for logging + + + +Name of the worker for logging + + + + + + + + + +```python +nemo_rl.models.policy.utils.AUTOMODEL_FACTORY: Dict[str, Any] = {'qwen2_5_vl': AutoModelForImageTextToText, 'qwen2_vl': AutoModelForImageTextToT... +``` + + + + + + + + + +```python +nemo_rl.models.policy.utils.NEMO_AUTOMODEL_AVAILABLE = True +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers.mdx new file mode 100644 index 0000000..3becc39 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers.mdx @@ -0,0 +1,13 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/policy/workers +title: nemo_rl.models.policy.workers +--- + +## Submodules + +- **[`nemo_rl.models.policy.workers.base_policy_worker`](/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker)** +- **[`nemo_rl.models.policy.workers.dtensor_policy_worker`](/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker)** +- **[`nemo_rl.models.policy.workers.dtensor_policy_worker_v2`](/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2)** +- **[`nemo_rl.models.policy.workers.megatron_policy_worker`](/nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker)** +- **[`nemo_rl.models.policy.workers.patches`](/nemo-rl/nemo_rl/models/policy/workers/patches)** diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker.mdx new file mode 100644 index 0000000..0983a40 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker.mdx @@ -0,0 +1,309 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/policy/workers/base_policy_worker +title: nemo_rl.models.policy.workers.base_policy_worker +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AbstractPolicyWorker`](#nemo_rl-models-policy-workers-base_policy_worker-AbstractPolicyWorker) | Base class for policy workers with shared functionality. | + +### API + + + + + +```python +class nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker() +``` + + + + + + +Base class for policy workers with shared functionality. + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.finish_training( + args: typing.Any = (), + kwargs: typing.Any = {} +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.get_free_memory_bytes() -> int +``` + + + + + + +Get the available free memory. + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.get_gpu_info() -> dict[str, typing.Any] +``` + + + + + + +Return information about the GPU being used by this worker. + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.get_reference_policy_logprobs( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + micro_batch_size: typing.Optional[int] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ReferenceLogprobOutputSpec] +``` + + + + + + +Get the logprobs from the reference policy for a batch of data. + +If micro_batch_size is provided, it will be used instead of the configured +logprob_batch_size. + +**Returns:** `BatchedDataDict[ReferenceLogprobOutputSpec]` + +a BatchedDataDict with key "reference_logprobs" and shape [batch_size, sequence_length]. + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.get_zmq_address() -> str +``` + + + + + + +Get the ZMQ address for the current device. + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.init_collective( + ip: str, + port: int, + world_size: int, + train_world_size: int +) -> None +``` + + + + + + +Initialize the collective communication. + +**Parameters:** + + +IP address for the process group + + + +Port for the process group + + + +Total world size (train_world_size + inference_world_size) + + + +Number of training workers (used in inference cluster) + + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.is_alive() -> bool +``` + + + + + + +Check if the worker is alive. + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.maybe_init_zmq() -> None +``` + + + + + + +Initialize the ZMQ socket if it doesn't exist. + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.report_device_id() -> str +``` + + + + + + +Report the UUID of the current CUDA device using NVML. + +**Returns:** `str` + +UUID of the device in the format "GPU-xxxxx" + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.report_node_ip_and_gpu_id() -> tuple[str, int] +``` + + + + + + +Report the node IP and GPU ID of the current worker. + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.reset_peak_memory_stats() -> None +``` + + + + + + +Reset peak memory statistics. + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.shutdown() -> bool +``` + + + + + + +Shutdown the policy. + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.start_gpu_profiling() -> None +``` + + + + + + +Start GPU profiling. + + + + + + + +```python +nemo_rl.models.policy.workers.base_policy_worker.AbstractPolicyWorker.stop_gpu_profiling() -> None +``` + + + + + + +Stop GPU profiling. + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker.mdx new file mode 100644 index 0000000..6fb84a0 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker.mdx @@ -0,0 +1,693 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker +title: nemo_rl.models.policy.workers.dtensor_policy_worker +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`DTensorPolicyWorker`](#nemo_rl-models-policy-workers-dtensor_policy_worker-DTensorPolicyWorker) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`get_cpu_state_dict`](#nemo_rl-models-policy-workers-dtensor_policy_worker-get_cpu_state_dict) | Copy the state dict generator to CPU memory. | +| [`unshard_fsdp2_model`](#nemo_rl-models-policy-workers-dtensor_policy_worker-unshard_fsdp2_model) | Explicitly unshard and then reshard the FSDP2 modules. Useful for logprob inference. | + +### API + + + + + +```python +class nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker( + config: nemo_rl.models.policy.PolicyConfig, + tokenizer: transformers.AutoTokenizer, + processor: typing.Optional[transformers.AutoProcessor] = None, + weights_path: typing.Optional[str] = None, + optimizer_path: typing.Optional[str] = None, + init_optimizer: bool = True, + init_reference_model: bool = True, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [AbstractPolicyWorker](/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker#nemo_rl-models-policy-workers-base_policy_worker-AbstractPolicyWorker), [ColocatablePolicyInterface](/nemo-rl/nemo_rl/models/policy/interfaces#nemo_rl-models-policy-interfaces-ColocatablePolicyInterface) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.__repr__() -> str +``` + + + + + + +Customizes the actor's prefix in the Ray logs. + +This makes it easier to identify which worker is producing specific log messages. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker._add_noise_to_weights() -> None +``` + + + + + + +Add small Gaussian noise to the weights of the model. Note that this is used for testing purposes only. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker._apply_temperature_scaling( + logits: torch.Tensor +) -> torch.Tensor +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.broadcast_weights_for_collective( + kv_scales: typing.Optional[dict[str, float]] = None +) -> None +``` + + + + + + +Broadcast the weights for collective communication. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.calibrate_qkv_fp8_scales( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + micro_batch_size: typing.Optional[int] = None, + percentile: float = 99.9, + margin: float = 1.05, + include_q: bool = False +) -> dict[str, typing.Any] +``` + + + + + + +Placeholder for FP8 Q/K/V scale calibration, not implemented for DTensorPolicyWorker. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.create_context_parallel_ctx( + cp_mesh: torch.distributed.device_mesh.DeviceMesh, + cp_buffers: list[torch.Tensor], + cp_seq_dims: list[int], + cp_no_restore_buffers: typing.Set[torch.Tensor], + cp_rotate_method: typing.Optional[str] = None +) +``` + + + + + + +staticmethod + +Create a context parallel context. + +**Parameters:** + + +The device mesh for context parallel. + + + +The buffers for context parallel. + + + +The sequence dimensions for context parallel. + + + +The no restore buffers for context parallel. + + + +The rotation method for context parallel, such as "allgather" or "addtoall". + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.get_logprobs( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + micro_batch_size: typing.Optional[int] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec] +``` + + + + + + +Get the logprobs of the model for a batch of data. + +Uses the configured logprob_batch_size to do microbatching. + +Input data is assumed to be right-padded. The method internally converts to +left-padded format for computation, and returns outputs in right-padded format. + +**Returns:** `BatchedDataDict[LogprobOutputSpec]` + +a BatchedDataDict with key "logprobs" and shape [batch_size, sequence_length]. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.get_topk_logits( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + k: int, + micro_batch_size: typing.Optional[int] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] +``` + + + + + + +Return per-position top-k logits and corresponding global indices. + +Notes: +- Return shapes are [B, S, k]. +- Computes top-k over the full sequence (no trimming of the last position). +- If alignment with next-token targets is required, the caller should handle it. +- If logits are TP-sharded DTensor, performs distributed global top-k across TP. +- Supports context parallelism with proper CP gather. +- Otherwise, computes local top-k on full-vocab tensor. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.load_checkpoint( + weights_path: str, + optimizer_path: typing.Optional[str] = None +) -> None +``` + + + + + + +Load a checkpoint into the model. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.move_buffer_to_device( + model: torch.nn.Module, + device: str | torch.device +) -> torch.nn.Module +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.move_optimizer_to_device( + device: str | torch.device +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.move_to_cpu( + model: torch.nn.Module +) -> torch.nn.Module +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.move_to_cuda( + model: torch.nn.Module +) -> torch.nn.Module +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.move_to_device( + model: torch.nn.Module, + device: str | torch.device +) -> torch.nn.Module +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.offload_after_refit() -> None +``` + + + + + + +Offload as much as possible on the CPU. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.offload_before_refit() -> None +``` + + + + + + +Offload the optimizer to the CPU. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.prepare_for_lp_inference() -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.prepare_for_training( + args = (), + kwargs = {} +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.prepare_refit_info() -> typing.Optional[dict[str, typing.Any]] +``` + + + + + + +Prepare state dict metadata for weight refitting and IPC streaming. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.return_model_config() -> dict[str, typing.Any] +``` + + + + + + +Return the model configuration as a dictionary. + +**Returns:** `dict[str, Any]` + +Model configuration dictionary + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.return_state_dict() +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.save_checkpoint( + weights_path: str, + optimizer_path: typing.Optional[str] = None, + tokenizer_path: typing.Optional[str] = None +) -> None +``` + + + + + + +Save a checkpoint of the model. + +the optimizer states are saved only if `optimizer` and `optimizer_path` are provided. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.score( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ScoreOutputSpec] +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.stream_weights_via_ipc_zmq( + buffer_size_bytes: int = 0, + kv_scales: typing.Optional[dict[str, float]] = None +) -> None +``` + + + + + + +Stream model weights to peer process via ZMQ IPC socket. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.train( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + eval_mode: bool = False, + gbs: typing.Optional[int] = None, + mbs: typing.Optional[int] = None +) -> dict[str, typing.Any] +``` + + + + + + +Train the policy on a batch of data with a given loss function. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.train_context( + cp_context: typing.Optional[typing.Generator[None, None, None]] = None +) +``` + + + + + + +staticmethod + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker.use_reference_model() -> typing.Generator[None, None, None] +``` + + + + + + +Context manager that temporarily swaps the reference model and active model. + +On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references +On exit: Restores original references and re-flips cuda/cpu + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.get_cpu_state_dict( + state_generator: typing.Iterable[tuple[str, typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]]], + pin_memory: bool = False +) -> dict[str, torch.Tensor] +``` + + + + + + +Copy the state dict generator to CPU memory. + +**Parameters:** + + + +An iterable that yields (key, tensor) pairs from a model state. + + + + +Whether to allocate the CPU tensors in pinned memory for faster GPU transfer. +Defaults to False. + + +**Returns:** `dict[str, torch.Tensor]` + +dict[str, torch.Tensor]: A dictionary mapping parameter names to CPU tensors. + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker.unshard_fsdp2_model( + model: torch.nn.Module +) -> typing.Generator[None, None, None] +``` + + + + + + +Explicitly unshard and then reshard the FSDP2 modules. Useful for logprob inference. + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.mdx new file mode 100644 index 0000000..b6ff0e4 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.mdx @@ -0,0 +1,714 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/policy/workers/dtensor_policy_worker_v2 +title: nemo_rl.models.policy.workers.dtensor_policy_worker_v2 +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`DTensorPolicyWorkerV2`](#nemo_rl-models-policy-workers-dtensor_policy_worker_v2-DTensorPolicyWorkerV2) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`_maybe_adapt_tensor_to_hf`](#nemo_rl-models-policy-workers-dtensor_policy_worker_v2-_maybe_adapt_tensor_to_hf) | - | +| [`_maybe_merge_lora_weight`](#nemo_rl-models-policy-workers-dtensor_policy_worker_v2-_maybe_merge_lora_weight) | - | +| [`dtensor_params_generator`](#nemo_rl-models-policy-workers-dtensor_policy_worker_v2-dtensor_params_generator) | Generator that yields (name, tensor) pairs, converting DTensors to local tensors and adapting to HF format. | +| [`get_train_context`](#nemo_rl-models-policy-workers-dtensor_policy_worker_v2-get_train_context) | Create combined context manager for training with context parallel and autocast. | + +### API + + + + + +```python +class nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2( + config: nemo_rl.models.policy.PolicyConfig, + tokenizer: transformers.AutoTokenizer, + processor: typing.Optional[transformers.AutoProcessor] = None, + weights_path: typing.Optional[str] = None, + optimizer_path: typing.Optional[str] = None, + init_optimizer: bool = True, + init_reference_model: bool = True, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [AbstractPolicyWorker](/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker#nemo_rl-models-policy-workers-base_policy_worker-AbstractPolicyWorker), [ColocatablePolicyInterface](/nemo-rl/nemo_rl/models/policy/interfaces#nemo_rl-models-policy-interfaces-ColocatablePolicyInterface) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.__repr__() -> str +``` + + + + + + +Customizes the actor's prefix in the Ray logs. + +This makes it easier to identify which worker is producing specific log messages. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2._add_noise_to_weights() -> None +``` + + + + + + +Add small Gaussian noise to the weights of the model. Note that this is used for testing purposes only. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2._init_checkpoint_manager( + config_updates: typing.Optional[dict[str, typing.Any]] = None, + checkpoint_root: typing.Optional[str] = None +) -> None +``` + + + + + + +Initialize the AutomodelCheckpointManager for this worker. + +This creates the checkpoint manager bound to this worker's device meshes +and initializes its underlying checkpointer. + +**Parameters:** + + +Dict of CheckpointingConfig fields to set during initialization. + + + +Optional root directory for checkpoints. + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.broadcast_weights_for_collective( + kv_scales: typing.Optional[dict[str, float]] = None +) -> None +``` + + + + + + +Broadcast the weights for collective communication. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.calibrate_qkv_fp8_scales( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + micro_batch_size: typing.Optional[int] = None, + percentile: float = 99.9, + margin: float = 1.05, + include_q: bool = False +) -> dict[str, typing.Any] +``` + + + + + + +Placeholder for FP8 Q/K/V scale calibration, not implemented for DTensorPolicyWorkerV2. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.get_logprobs( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + micro_batch_size: typing.Optional[int] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec] +``` + + + + + + +Get the logprobs of the model for a batch of data. + +Uses the configured logprob_batch_size to do microbatching. + +Input data is assumed to be right-padded. The method internally converts to +left-padded format for computation, and returns outputs in right-padded format. + +**Returns:** `BatchedDataDict[LogprobOutputSpec]` + +a BatchedDataDict with key "logprobs" and shape [batch_size, sequence_length]. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.get_topk_logits( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + k: int, + micro_batch_size: typing.Optional[int] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] +``` + + + + + + +Return per-position top-k logits and corresponding global indices. + +Notes: +- Return shapes are [B, S, k]. +- Computes top-k over the full sequence (no trimming of the last position). +- If alignment with next-token targets is required, the caller should handle it. +- If logits are TP-sharded DTensor, performs distributed global top-k across TP. +- Supports context parallelism with proper CP gather. +- Otherwise, computes local top-k on full-vocab tensor. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.load_checkpoint( + weights_path: str, + optimizer_path: typing.Optional[str] = None +) -> None +``` + + + + + + +Load a checkpoint into the model using Automodel Checkpointer. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.move_buffer_to_device( + model: torch.nn.Module, + device: str | torch.device +) -> torch.nn.Module +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.move_optimizer_to_device( + device: str | torch.device +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.move_to_cpu( + model: torch.nn.Module +) -> torch.nn.Module +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.move_to_cuda( + model: torch.nn.Module +) -> torch.nn.Module +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.move_to_device( + model: torch.nn.Module, + device: str | torch.device +) -> torch.nn.Module +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.offload_after_refit() -> None +``` + + + + + + +Offload as much as possible on the CPU. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.offload_before_refit() -> None +``` + + + + + + +Offload the optimizer to the CPU. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.prepare_for_lp_inference() -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.prepare_for_training( + args = (), + kwargs = {} +) -> None +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.prepare_refit_info() -> typing.Optional[dict[str, typing.Any]] +``` + + + + + + +Prepare state dict metadata for weight refitting and IPC streaming. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.return_model_config() -> dict[str, typing.Any] +``` + + + + + + +Return the model configuration as a dictionary. + +**Returns:** `dict[str, Any]` + +Model configuration dictionary + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.return_state_dict() +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.save_checkpoint( + weights_path: str, + optimizer_path: typing.Optional[str] = None, + tokenizer_path: typing.Optional[str] = None, + checkpointing_cfg: typing.Optional[nemo_rl.utils.checkpoint.CheckpointingConfig] = None +) -> None +``` + + + + + + +Save a checkpoint of the model. + +the optimizer states are saved only if `optimizer` and `optimizer_path` are provided. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.score( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.ScoreOutputSpec] +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.stream_weights_via_http( + sglang_url_to_gpu_uuids: dict[str, list[str]] +) -> None +``` + + + + + + +Stream model weights to SGLang servers via HTTP API. + +**Parameters:** + + +Dict mapping SGLang server URL to list of GPU UUIDs it uses + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.stream_weights_via_ipc_zmq( + buffer_size_bytes: int = 0, + kv_scales: typing.Optional[dict[str, float]] = None +) -> None +``` + + + + + + +Stream model weights to peer process via ZMQ IPC socket. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.train( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + eval_mode: bool = False, + gbs: typing.Optional[int] = None, + mbs: typing.Optional[int] = None +) -> dict[str, typing.Any] +``` + + + + + + +Train the policy on a batch of data with a given loss function. + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2.use_reference_model() -> typing.Generator[None, None, None] +``` + + + + + + +Context manager that temporarily swaps the reference model and active model. + +On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references +On exit: Restores original references and re-flips cuda/cpu + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2._maybe_adapt_tensor_to_hf( + model_part: torch.nn.Module, + fqn: str, + tensor: torch.Tensor, + quantization: bool = False +) -> list[tuple[str, torch.Tensor]] +``` + + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2._maybe_merge_lora_weight( + module_map: dict[str, torch.nn.Module], + fqn: str, + tensor: torch.Tensor +) -> torch.Tensor +``` + + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.dtensor_params_generator( + model: torch.nn.Module, + target_dtype: torch.dtype +) -> typing.Generator[tuple[str, torch.Tensor], None, None] +``` + + + + + + +Generator that yields (name, tensor) pairs, converting DTensors to local tensors and adapting to HF format. + +**Parameters:** + + +The model whose parameters to generate. + + + +The dtype to convert tensors to. + + + +Optional LoRA config for filtering which layers to merge. + + + + + + + + + +```python +nemo_rl.models.policy.workers.dtensor_policy_worker_v2.get_train_context( + cp_size: int, + cp_mesh: typing.Any, + cp_buffers: list, + sequence_dim: int, + dtype: torch.dtype, + autocast_enabled: bool = True +) -> typing.Generator[None, None, None] +``` + + + + + + +Create combined context manager for training with context parallel and autocast. + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker.mdx new file mode 100644 index 0000000..729a45d --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker.mdx @@ -0,0 +1,638 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker +title: nemo_rl.models.policy.workers.megatron_policy_worker +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MegatronPolicyWorker`](#nemo_rl-models-policy-workers-megatron_policy_worker-MegatronPolicyWorker) | - | + +### Data + +[`TokenizerType`](#nemo_rl-models-policy-workers-megatron_policy_worker-TokenizerType) + +### API + + + + + +```python +class nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker( + config: nemo_rl.models.policy.PolicyConfig, + tokenizer: nemo_rl.models.policy.workers.megatron_policy_worker.TokenizerType, + weights_path: typing.Optional[str] = None, + optimizer_path: typing.Optional[str] = None, + init_optimizer: bool = True, + init_reference_model: bool = True, + worker_sharding_annotations: nemo_rl.distributed.named_sharding.NamedSharding, + kwargs: typing.Any = {} +) +``` + + + + + + +**Bases:** [AbstractPolicyWorker](/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker#nemo_rl-models-policy-workers-base_policy_worker-AbstractPolicyWorker), [ColocatablePolicyInterface](/nemo-rl/nemo_rl/models/policy/interfaces#nemo_rl-models-policy-interfaces-ColocatablePolicyInterface) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.__repr__() +``` + + + + + + +Customizes the actor's prefix in the Ray logs. + +This makes it easier to identify which worker is producing specific log messages. + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker._calculate_refit_param_info() -> list[tuple[str, int]] +``` + + + + + + +Calculate parameter information for refit. + +Each task contains: +- param_name: Local parameter name without module prefixes +- mapping: MegatronParamMapping instance for weight transformation +- pp_rank: Pipeline-parallel rank owning the parameter +- vp_stage: Virtual-pipeline stage index +- megatron_module: Reference to Megatron model/submodule +- param_weight: Target parameter tensor for converted weight + +**Returns:** `list[tuple[str, int]]` + +List of (parameter_name, size_in_bytes) tuples. + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker._iter_params_with_optional_kv_scales( + kv_scales: typing.Optional[dict[str, float]] = None +) -> typing.Iterator[tuple[str, torch.Tensor]] +``` + + + + + + +Yield exported HF parameters and optionally append FP8 KV/Q scale tensors. + +This helper is used by both IPC-based streaming and collective broadcast +so that the logic for adding KV scales stays consistent in one place. + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.broadcast_weights_for_collective( + kv_scales: typing.Optional[dict[str, float]] = None +) -> None +``` + + + + + + +Broadcast the weights for collective communication. + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.calibrate_qkv_fp8_scales( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + micro_batch_size: typing.Optional[int] = None, + percentile: float = 99.9, + margin: float = 1.05, + include_q: bool = False +) -> dict[str, typing.Any] +``` + + + + + + +One-shot calibration of Q/K/V activation scales (for FP8 KV cache). + +- Captures each layer's `query_key_value` output through forward hooks, splits Q/K/V, and computes percentile amax. +- In parallel (DP/TP/PP) environments, first computes local percentiles, then takes max across all ranks for conservativeness. +- By default only returns and saves K/V scales, optionally returns Q. + +**Parameters:** + + +Representative sample batch for calibration, following get_logprobs input conventions. + + + +Micro batch size during calibration; if None, reuses logprob_batch_size. + + + +Percentile for amax (e.g. 99.9). + + + +Margin factor, e.g. 1.05. + + + +If provided, rank0 will save results as JSON. + + + +Whether to also return Q scale (usually only K/V needed). + + +**Returns:** `dict[str, Any]` + +{ "format": "fp8", "percentile": float, "margin": float, +"layers": { layer_name: {"k_scale": float, "v_scale": float[, "q_scale": float] } } } + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.check_tensor_parallel_attributes() -> dict[str, typing.Any] +``` + + + + + + +Check tensor parallel attributes on model parameters. + +**Returns:** `dict[str, Any]` + +Dictionary containing information about tensor parallel parameters: + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.disable_forward_pre_hook( + param_sync = True +) +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.enable_forward_pre_hook() +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.generate( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + greedy: bool = False +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec] +``` + + + + + + +Generate a batch of data using huggingface framework generation. + +Returns: + BatchedDataDict conforming to GenerationOutputSpec: + - output_ids: input + generated token IDs + - logprobs: Log probabilities for each token + - generation_lengths: Lengths of each response + +**Parameters:** + + +BatchedDataDict containing input_ids and input_lengths tensors + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.get_logprobs( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any], + micro_batch_size: typing.Optional[int] = None +) -> nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.policy.interfaces.LogprobOutputSpec] +``` + + + + + + +Get the logprobs of the model for a batch of data. + +Uses the configured logprob_batch_size to do microbatching. +Input data is assumed to be right-padded. The method internally converts to +left-padded format for computation, and returns outputs in right-padded format. +If micro_batch_size is provided, it will be used instead of the configured +logprob_batch_size. + +**Returns:** `BatchedDataDict[LogprobOutputSpec]` + +a BatchedDataDict with key "logprobs" and shape [batch_size, sequence_length]. + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.get_topk_logits( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec], + k: int, + micro_batch_size: typing.Optional[int] = None +) +``` + + + + + + +Get the top-k logits and indices for a batch of data. + +The major difference from get_logprobs is that we compute top-k logits and indices for each position in the sequence. + +**Returns:** + +BatchedDataDict containing: +- topk_logits: Tensor of top-k logits for each position in the sequence +- topk_indices: Tensor of top-k indices for each position in the sequence + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.load_checkpoint( + weights_path: str, + optimizer_path: typing.Optional[str] = None +) +``` + + + + + + +Load a training checkpoint. + +**Parameters:** + + +The exact directory path from which to load the checkpoint. + + + +If not None, attempts to load optimizer and scheduler states + if self.optimizer and self.scheduler are initialized. + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.move_model( + model: torch.nn.Module, + device: str, + move_params: bool = True, + move_grads: bool = True +) -> torch.nn.Module +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.move_optimizer( + device: str +) +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.offload_after_refit() +``` + + + + + + +Offload as much as possible on the CPU. + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.offload_before_refit() +``` + + + + + + +Offload the optimizer and buffers to the CPU. + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.prepare_for_lp_inference() +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.prepare_for_training( + args = (), + kwargs = {} +) +``` + + + + + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.prepare_refit_info() -> None +``` + + + + + + +Prepare state dict metadata for weight refitting and IPC streaming. + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.save_checkpoint( + weights_path: str, + optimizer_path: typing.Optional[str] = None, + kwargs = {} +) +``` + + + + + + +Save a training checkpoint. + +**Parameters:** + + +The specific directory path where the checkpoint will be saved. + + + +If not None, optimizer and scheduler states are saved if they exist. + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.stream_weights_via_ipc_zmq( + buffer_size_bytes: int = 0, + kv_scales: typing.Optional[dict[str, float]] = None +) -> None +``` + + + + + + +Stream model weights to peer process via ZMQ IPC socket. + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.train( + data: nemo_rl.distributed.batched_data_dict.BatchedDataDict, + loss_fn: nemo_rl.algorithms.interfaces.LossFunction, + eval_mode: bool = False, + gbs: typing.Optional[int] = None, + mbs: typing.Optional[int] = None +) -> dict[str, typing.Any] +``` + + + + + + +Train the policy on a batch of data with a given loss function. + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker.use_reference_model() +``` + + + + + + +Context manager that temporarily swaps the reference model and active model. + +On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references +On exit: Restores original references and re-flips cuda/cpu + + + + + + + + + +```python +nemo_rl.models.policy.workers.megatron_policy_worker.TokenizerType = TypeVar('TokenizerType', bound=PreTrainedTokenizerBase) +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/patches.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/patches.mdx new file mode 100644 index 0000000..4250027 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/models/policy/workers/patches.mdx @@ -0,0 +1,85 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/models/policy/workers/patches +title: nemo_rl.models.policy.workers.patches +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_transformer_engine_file`](#nemo_rl-models-policy-workers-patches-_get_transformer_engine_file) | Return absolute path to a Transformer Engine file or raise if it cannot be found. | +| [`apply_torch_aten_alias_tensor_patch`](#nemo_rl-models-policy-workers-patches-apply_torch_aten_alias_tensor_patch) | Register a sharding rule for `torch.ops.aten.alias.default`. | +| [`apply_transformer_engine_patch`](#nemo_rl-models-policy-workers-patches-apply_transformer_engine_patch) | Apply patch from https://github.com/NVIDIA/TransformerEngine/pull/2286/files. | + +### API + + + + + +```python +nemo_rl.models.policy.workers.patches._get_transformer_engine_file( + relative_path: str +) -> str +``` + + + + + + +Return absolute path to a Transformer Engine file or raise if it cannot be found. + +The relative_path should be a POSIX-style path under the transformer_engine +package root, e.g. "pytorch/triton/permutation.py". + + + + + + + + +```python +nemo_rl.models.policy.workers.patches.apply_torch_aten_alias_tensor_patch() +``` + + + + + + +Register a sharding rule for `torch.ops.aten.alias.default`. + +Work around 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered' +in PyTorch 2.9. See https://github.com/pytorch/pytorch/pull/166867 for the upstream fix. +We can remove this patch when we upgrade torch to include this fix. + + + + + + + + +```python +nemo_rl.models.policy.workers.patches.apply_transformer_engine_patch() +``` + + + + + + +Apply patch from https://github.com/NVIDIA/TransformerEngine/pull/2286/files. + +This locates the target file via importlib metadata instead of importing +`transformer_engine`, to avoid side effects during initialization. If the +permutation module has already been imported, it will be reloaded so that +the patched source takes effect. + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/package_info.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/package_info.mdx new file mode 100644 index 0000000..2dc77ed --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/package_info.mdx @@ -0,0 +1,235 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/package_info +title: nemo_rl.package_info +--- + +## Module Contents + +### Data + +[`MAJOR`](#nemo_rl-package_info-MAJOR) + +[`MINOR`](#nemo_rl-package_info-MINOR) + +[`PATCH`](#nemo_rl-package_info-PATCH) + +[`PRE_RELEASE`](#nemo_rl-package_info-PRE_RELEASE) + +[`VERSION`](#nemo_rl-package_info-VERSION) + +[`__contact_emails__`](#nemo_rl-package_info-__contact_emails__) + +[`__contact_names__`](#nemo_rl-package_info-__contact_names__) + +[`__description__`](#nemo_rl-package_info-__description__) + +[`__download_url__`](#nemo_rl-package_info-__download_url__) + +[`__homepage__`](#nemo_rl-package_info-__homepage__) + +[`__keywords__`](#nemo_rl-package_info-__keywords__) + +[`__license__`](#nemo_rl-package_info-__license__) + +[`__package_name__`](#nemo_rl-package_info-__package_name__) + +[`__repository_url__`](#nemo_rl-package_info-__repository_url__) + +[`__shortversion__`](#nemo_rl-package_info-__shortversion__) + +[`__version__`](#nemo_rl-package_info-__version__) + +### API + + + + + +```python +nemo_rl.package_info.MAJOR = 0 +``` + + + + + + + + + +```python +nemo_rl.package_info.MINOR = 5 +``` + + + + + + + + + +```python +nemo_rl.package_info.PATCH = 0 +``` + + + + + + + + + +```python +nemo_rl.package_info.PRE_RELEASE = 'rc0' +``` + + + + + + + + + +```python +nemo_rl.package_info.VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE) +``` + + + + + + + + + +```python +nemo_rl.package_info.__contact_emails__ = 'nemo-tookit@nvidia.com' +``` + + + + + + + + + +```python +nemo_rl.package_info.__contact_names__ = 'NVIDIA' +``` + + + + + + + + + +```python +nemo_rl.package_info.__description__ = 'NeMo-RL - a toolkit for model alignment' +``` + + + + + + + + + +```python +nemo_rl.package_info.__download_url__ = 'https://github.com/NVIDIA-NeMo/RL/releases' +``` + + + + + + + + + +```python +nemo_rl.package_info.__homepage__ = 'https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/' +``` + + + + + + + + + +```python +nemo_rl.package_info.__keywords__ = 'deep learning, machine learning, gpu, NLP, NeMo, nvidia, pytorch, torch, langua... +``` + + + + + + + + + +```python +nemo_rl.package_info.__license__ = 'Apache2' +``` + + + + + + + + + +```python +nemo_rl.package_info.__package_name__ = 'nemo_rl' +``` + + + + + + + + + +```python +nemo_rl.package_info.__repository_url__ = 'https://github.com/NVIDIA-NeMo/RL' +``` + + + + + + + + + +```python +nemo_rl.package_info.__shortversion__ = '.'.join(map(str, VERSION[:3])) +``` + + + + + + + + + +```python +nemo_rl.package_info.__version__ = '.'.join(map(str, VERSION[:3])) + ''.join(VERSION[3:]) +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils.mdx new file mode 100644 index 0000000..b7dfc66 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils.mdx @@ -0,0 +1,22 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils +title: nemo_rl.utils +--- + +## Submodules + +- **[`nemo_rl.utils.automodel_checkpoint`](/nemo-rl/nemo_rl/utils/automodel_checkpoint)** +- **[`nemo_rl.utils.checkpoint`](/nemo-rl/nemo_rl/utils/checkpoint)** +- **[`nemo_rl.utils.config`](/nemo-rl/nemo_rl/utils/config)** +- **[`nemo_rl.utils.flops_formulas`](/nemo-rl/nemo_rl/utils/flops_formulas)** +- **[`nemo_rl.utils.flops_tracker`](/nemo-rl/nemo_rl/utils/flops_tracker)** +- **[`nemo_rl.utils.logger`](/nemo-rl/nemo_rl/utils/logger)** +- **[`nemo_rl.utils.memory_tracker`](/nemo-rl/nemo_rl/utils/memory_tracker)** +- **[`nemo_rl.utils.native_checkpoint`](/nemo-rl/nemo_rl/utils/native_checkpoint)** +- **[`nemo_rl.utils.nsys`](/nemo-rl/nemo_rl/utils/nsys)** +- **[`nemo_rl.utils.nvml`](/nemo-rl/nemo_rl/utils/nvml)** +- **[`nemo_rl.utils.packed_tensor`](/nemo-rl/nemo_rl/utils/packed_tensor)** +- **[`nemo_rl.utils.prefetch_venvs`](/nemo-rl/nemo_rl/utils/prefetch_venvs)** +- **[`nemo_rl.utils.timer`](/nemo-rl/nemo_rl/utils/timer)** +- **[`nemo_rl.utils.venvs`](/nemo-rl/nemo_rl/utils/venvs)** diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/automodel_checkpoint.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/automodel_checkpoint.mdx new file mode 100644 index 0000000..2afdec6 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/automodel_checkpoint.mdx @@ -0,0 +1,436 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/automodel_checkpoint +title: nemo_rl.utils.automodel_checkpoint +--- + +Automodel checkpoint utilities for DTensor policy workers. + +This module provides a wrapper class around the nemo_automodel Checkpointer +for saving and loading model checkpoints in DTensor-based policy workers. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`AutomodelCheckpointManager`](#nemo_rl-utils-automodel_checkpoint-AutomodelCheckpointManager) | Manages checkpointing for DTensor-based models using nemo_automodel's Checkpointer. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_infer_checkpoint_root`](#nemo_rl-utils-automodel_checkpoint-_infer_checkpoint_root) | Infer checkpoint root directory from weights path. | +| [`detect_checkpoint_format`](#nemo_rl-utils-automodel_checkpoint-detect_checkpoint_format) | Detect model save format and PEFT status from checkpoint directory. | + +### API + + + + + +```python +class nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager( + dp_mesh: torch.distributed.device_mesh.DeviceMesh, + tp_mesh: torch.distributed.device_mesh.DeviceMesh, + model_state_dict_keys: typing.Optional[list[str]] = None, + moe_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None +) +``` + + + + + + +Manages checkpointing for DTensor-based models using nemo_automodel's Checkpointer. + +This class provides a clean interface for saving and loading model checkpoints, +wrapping the nemo_automodel Checkpointer with configuration management. + + + + + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager._get_dp_rank() -> int +``` + + + + + + +Get the data parallel rank. + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager._get_tp_rank() -> int +``` + + + + + + +Get the tensor parallel rank. + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager._rebuild_checkpointer_addons() -> None +``` + + + + + + +Rebuild the checkpointer's _addons list based on current config. + +The Checkpointer's _addons list is populated during __init__ based on config. +When config changes (e.g., model_save_format or is_peft), we need to rebuild +the addons list to match the new config. + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager.init_checkpointer( + config_updates: typing.Optional[dict[str, typing.Any]] = None, + checkpoint_root: typing.Optional[str] = None +) -> None +``` + + + + + + +Initialize the Automodel Checkpointer if not already created. + +This method creates a new Checkpointer instance with the provided configuration. +If a checkpointer already exists, this method does nothing. + +**Parameters:** + + +Dict of CheckpointingConfig fields to set during initialization. + + + +Optional root directory for checkpoints. + + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager.load_base_model( + model: torch.nn.Module, + model_name: str, + hf_cache_dir: typing.Optional[str] = None, + dequantize_base_checkpoint: bool = False, + peft_init_method: typing.Optional[str] = None +) -> None +``` + + + + + + +Load base model weights using the Automodel Checkpointer. + +This method loads the initial HuggingFace model weights into the parallelized model. + +**Parameters:** + + +The model to load weights into. + + + +Name or path of the model. + + + +Optional HuggingFace cache directory. + + + +Whether to dequantize the base checkpoint. + + +**Raises:** + +- `AssertionError`: If checkpointer has not been initialized. + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager.load_checkpoint( + model: torch.nn.Module, + weights_path: str, + optimizer: typing.Optional[torch.optim.Optimizer] = None, + optimizer_path: typing.Optional[str] = None, + scheduler: typing.Optional[torch.optim.lr_scheduler.LRScheduler] = None +) -> None +``` + + + + + + +Load a checkpoint into the model using Automodel Checkpointer. + +**Parameters:** + + +The model to load weights into. + + + +Path to the checkpoint weights. + + + +Optional optimizer to load state into. + + + +Optional path to optimizer checkpoint. + + + +Optional learning rate scheduler. + + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager.save_checkpoint( + model: torch.nn.Module, + weights_path: str, + optimizer: typing.Optional[torch.optim.Optimizer] = None, + optimizer_path: typing.Optional[str] = None, + scheduler: typing.Optional[torch.optim.lr_scheduler.LRScheduler] = None, + tokenizer: typing.Optional[transformers.AutoTokenizer] = None, + tokenizer_path: typing.Optional[str] = None, + checkpointing_cfg: typing.Optional[nemo_rl.utils.checkpoint.CheckpointingConfig] = None, + lora_enabled: bool = False, + peft_config: typing.Optional[nemo_automodel.components._peft.lora.PeftConfig] = None +) -> None +``` + + + + + + +Save a checkpoint of the model. + +The optimizer states are saved only if `optimizer` and `optimizer_path` are provided. + +**Parameters:** + + +The model to save. + + + +Path to save model weights. + + + +Optional optimizer to save. + + + +Optional path to save optimizer state. + + + +Optional learning rate scheduler. + + + +Optional tokenizer to save with the checkpoint. + + + +Optional path to save tokenizer separately. + + + +Checkpointing configuration. + + + +Whether LoRA is enabled. + + + +Optional PEFT configuration. + + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager.set_model_state_dict_keys( + keys: list[str] +) -> None +``` + + + + + + +Set the model state dict keys for checkpoint validation. + +**Parameters:** + + +List of model state dict keys. + + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager.update_checkpointer_config( + config_updates: typing.Optional[dict[str, typing.Any]] = None, + checkpoint_root: typing.Optional[str] = None +) -> None +``` + + + + + + +Update the configuration of an existing Checkpointer. + +This method updates the mutable config fields on the existing Checkpointer instance. +If no checkpointer exists, this method does nothing. + +Note: Some config changes (like model_save_format) require rebuilding the +checkpointer's internal addons list. This method handles that automatically. + +**Parameters:** + + +Dict of CheckpointingConfig fields to update. + + + +Optional root directory for checkpoints. + + + + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint._infer_checkpoint_root( + weights_path: str +) -> str +``` + + + + + + +Infer checkpoint root directory from weights path. + +When weights_path ends with "…/weights/model", we need the parent of +the weights directory (the checkpoint root), not the weights directory itself. + +**Parameters:** + + +Path to model weights (e.g., "/path/to/policy/weights/model") + + +**Returns:** `str` + +Checkpoint root directory (e.g., "/path/to/policy") + + + + + + + + +```python +nemo_rl.utils.automodel_checkpoint.detect_checkpoint_format( + weights_path: str +) -> tuple[str, bool] +``` + + + + + + +Detect model save format and PEFT status from checkpoint directory. + +**Parameters:** + + +Path to the checkpoint directory (e.g., weights/model) + + +**Returns:** `tuple[str, bool]` + +(model_save_format, is_peft) where: + model_save_format is "torch_save" for DCP or "safetensors" for safetensors + is_peft is True if PEFT/adapter patterns are detected + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/checkpoint.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/checkpoint.mdx new file mode 100644 index 0000000..8c380d8 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/checkpoint.mdx @@ -0,0 +1,411 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/checkpoint +title: nemo_rl.utils.checkpoint +--- + +Checkpoint management utilities for the rl algorithm loop. + +It handles logic at the algorithm level. Each RL Actor is expected to have its +own checkpoint saving function (called by the algorithm loop). + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`CheckpointManager`](#nemo_rl-utils-checkpoint-CheckpointManager) | Manages model checkpoints during training. | +| [`CheckpointingConfig`](#nemo_rl-utils-checkpoint-CheckpointingConfig) | Configuration for checkpoint management. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_load_checkpoint_history`](#nemo_rl-utils-checkpoint-_load_checkpoint_history) | Load the history of checkpoints and their metrics. | + +### Data + +[`PathLike`](#nemo_rl-utils-checkpoint-PathLike) + +### API + + + + + +```python +class nemo_rl.utils.checkpoint.CheckpointManager( + config: nemo_rl.utils.checkpoint.CheckpointingConfig +) +``` + + + + + + +Manages model checkpoints during training. + +This class handles creating checkpoint dirs, saving training info, and +configurations. It also provides utilities for keeping just the top-k checkpoints. +The checkpointing structure looks like this: + + +```python +checkpoint_dir/ + step_0/ + training_info.json + config.yaml + policy.py (up to the algorithm loop to save here) + policy_optimizer.py (up to the algorithm loop to save here) + ... + step_1/ + ... +``` + + + +Attributes: Derived from the CheckpointingConfig. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.utils.checkpoint.CheckpointManager.finalize_checkpoint( + checkpoint_path: nemo_rl.utils.checkpoint.PathLike +) -> None +``` + + + + + + +Complete a checkpoint by moving it from temporary to permanent location. + +If a checkpoint at the target location already exists (i.e when resuming training), +we override the old one. +Also triggers cleanup of old checkpoints based on the keep_top_k setting. + +**Parameters:** + + +Path to the temporary checkpoint directory. + + + + + + + + +```python +nemo_rl.utils.checkpoint.CheckpointManager.get_best_checkpoint_path() -> typing.Optional[str] +``` + + + + + + +Get the path to the best checkpoint based on the metric. + +Returns the path to the checkpoint with the best metric value. If no checkpoints +exist, returns None. If some checkpoints are missing the metric, they are filtered +out with a warning. If no checkpoints have the metric, returns the latest checkpoint. + +**Returns:** `Optional[str]` + +Optional[str]: Path to the best checkpoint, or None if no checkpoints exist. + + + + + + + +```python +nemo_rl.utils.checkpoint.CheckpointManager.get_latest_checkpoint_path() -> typing.Optional[str] +``` + + + + + + +Get the path to the latest checkpoint. + +Returns the path to the checkpoint with the highest step number. + +**Returns:** `Optional[str]` + +Optional[str]: Path to the latest checkpoint, or None if no checkpoints exist. + + + + + + + +```python +nemo_rl.utils.checkpoint.CheckpointManager.init_tmp_checkpoint( + step: int, + training_info: typing.Mapping[str, typing.Any], + run_config: typing.Optional[typing.Mapping[str, typing.Any]] = None +) -> nemo_rl.utils.checkpoint.PathLike +``` + + + + + + +Initialize a temporary checkpoint directory. + +Creates a temporary directory for a new checkpoint and saves training info +and configuration. The directory is named 'tmp_step_{step}' and will be renamed +to 'step_{step}' when the checkpoint is completed. +We do it this way to allow the algorithm loop to save any files it wants to save +in a safe, temporary directory. + +**Parameters:** + + +The training step number. + + + +Dictionary containing training metrics and info. + + + +Optional configuration for the training run. + + +**Returns:** `PathLike` + +Path to the temporary checkpoint directory. + + + + + + + +```python +nemo_rl.utils.checkpoint.CheckpointManager.load_training_info( + checkpoint_path: typing.Optional[nemo_rl.utils.checkpoint.PathLike] = None +) -> typing.Optional[dict[str, typing.Any]] +``` + + + + + + +Load the training info from a checkpoint. + +**Parameters:** + + +Path to the checkpoint. If None, +returns None. + + +**Returns:** `Optional[dict[str, Any]]` + +Optional[dict[str, Any]]: Dictionary containing the training info, or None if +checkpoint_path is None. + + + + + + + +```python +nemo_rl.utils.checkpoint.CheckpointManager.remove_old_checkpoints( + exclude_latest: bool = True +) -> None +``` + + + + + + +Remove checkpoints that are not in the top-k or latest based on the (optional) metric. + +If keep_top_k is set, this method removes all checkpoints except the top-k +best ones. The "best" checkpoints are determined by: +- If a metric is provided: the given metric value and the higher_is_better setting. + When multiple checkpoints have the same metric value, more recent checkpoints + (higher step numbers) are preferred. +- If no metric is provided: the step number. The most recent k checkpoints are kept. + +**Parameters:** + + +Whether to exclude the latest checkpoint from deletion. (may result in K+1 checkpoints) + + + + + + + + + + +```python +class nemo_rl.utils.checkpoint.CheckpointingConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + +Configuration for checkpoint management. + +Attributes: +enabled (bool): Whether checkpointing is enabled. +checkpoint_dir (PathLike): Directory where checkpoints will be saved. +metric_name (str | None): Name of the metric to use for determining best checkpoints. + Must be of the form "val:<metric_name>" or "train:<metric_name>" to indicate whether + the metric should be taken from the validation or training metrics. +higher_is_better (bool): Whether higher values of the metric indicate better performance. +keep_top_k (Optional[int]): Number of best checkpoints to keep. If None, all checkpoints are kept. +model_save_format (str | None): Format for saving model (v2 allowed values: "torch_save" or "safetensors", v1 allowed values: None). +save_consolidated (bool): Whether to save consolidated checkpoints (for HF compatibility). +model_cache_dir (str): Directory for model cache (for safetensors format). +model_repo_id (str): Repository ID for the model (for safetensors format). +is_peft (bool): Whether the model uses PEFT. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.utils.checkpoint._load_checkpoint_history( + checkpoint_dir: pathlib.Path +) -> list[tuple[int, nemo_rl.utils.checkpoint.PathLike, dict[str, typing.Any]]] +``` + + + + + + +Load the history of checkpoints and their metrics. + +**Parameters:** + + +Directory containing the checkpoints. + + +**Returns:** `list[tuple[int, PathLike, dict[str, Any]]]` + +list[tuple[int, PathLike, dict[str, Any]]]: List of tuples containing +(step_number, checkpoint_path, info) for each checkpoint. + + + + + + + + +```python +nemo_rl.utils.checkpoint.PathLike = Union[str, 'os.PathLike[Any]'] +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/config.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/config.mdx new file mode 100644 index 0000000..ba6ab36 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/config.mdx @@ -0,0 +1,266 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/config +title: nemo_rl.utils.config +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`OverridesError`](#nemo_rl-utils-config-OverridesError) | Custom exception for Hydra override parsing errors. | + +### Functions + +| Name | Description | +|------|-------------| +| [`load_config`](#nemo_rl-utils-config-load_config) | Load a config file with inheritance support and convert it to an OmegaConf object. | +| [`load_config_with_inheritance`](#nemo_rl-utils-config-load_config_with_inheritance) | Load a config file with inheritance support. | +| [`merge_with_override`](#nemo_rl-utils-config-merge_with_override) | Merge configs with support for _override_ marker to completely override sections. | +| [`parse_hydra_overrides`](#nemo_rl-utils-config-parse_hydra_overrides) | Parse and apply Hydra overrides to an OmegaConf config. | +| [`register_omegaconf_resolvers`](#nemo_rl-utils-config-register_omegaconf_resolvers) | Register shared OmegaConf resolvers used in configs. | +| [`resolve_path`](#nemo_rl-utils-config-resolve_path) | Resolve a path relative to the base path. | + +### API + + + + + +```python +class nemo_rl.utils.config.OverridesError() +``` + + + + + + +Exception + +**Bases:** `Exception` + +Custom exception for Hydra override parsing errors. + + + + + + + + +```python +nemo_rl.utils.config.load_config( + config_path: typing.Union[str, pathlib.Path] +) -> omegaconf.DictConfig +``` + + + + + + +Load a config file with inheritance support and convert it to an OmegaConf object. + +The config inheritance system supports: + +1. Single inheritance: + ```python + # child.yaml + defaults: parent.yaml + common: + value: 43 + ``` + +2. Multiple inheritance: + ```python + # child.yaml + defaults: + - parent1.yaml + - parent2.yaml + common: + value: 44 + ``` + +3. Nested inheritance: + ```python + # parent.yaml + defaults: grandparent.yaml + common: + value: 43 + + # child.yaml + defaults: parent.yaml + common: + value: 44 + ``` + +4. Variable interpolation: + ```python + # parent.yaml + base_value: 42 + derived: + value: ${base_value} + + # child.yaml + defaults: parent.yaml + base_value: 43 # This will update both base_value and derived.value + ``` + +The system handles: +- Relative and absolute paths +- Multiple inheritance +- Nested inheritance +- Variable interpolation + +The inheritance is resolved depth-first, with later configs overriding earlier ones. +This means in multiple inheritance, the last config in the list takes precedence. + +**Parameters:** + + +Path to the config file + + +**Returns:** `DictConfig` + +Merged config dictionary + + + + + + + + +```python +nemo_rl.utils.config.load_config_with_inheritance( + config_path: typing.Union[str, pathlib.Path], + base_dir: typing.Optional[typing.Union[str, pathlib.Path]] = None +) -> omegaconf.DictConfig +``` + + + + + + +Load a config file with inheritance support. + +**Parameters:** + + +Path to the config file + + + +Base directory for resolving relative paths. If None, uses config_path's directory + + +**Returns:** `DictConfig` + +Merged config dictionary + + + + + + + + +```python +nemo_rl.utils.config.merge_with_override( + base_config: omegaconf.DictConfig, + override_config: omegaconf.DictConfig +) -> omegaconf.DictConfig +``` + + + + + + +Merge configs with support for _override_ marker to completely override sections. + + + + + + + + +```python +nemo_rl.utils.config.parse_hydra_overrides( + cfg: omegaconf.DictConfig, + overrides: list[str] +) -> omegaconf.DictConfig +``` + + + + + + +Parse and apply Hydra overrides to an OmegaConf config. + +**Parameters:** + + +OmegaConf config to apply overrides to + + + +List of Hydra override strings + + +**Returns:** `DictConfig` + +Updated config with overrides applied + +**Raises:** + +- `OverridesError`: If there's an error parsing or applying overrides + + + + + + + + +```python +nemo_rl.utils.config.register_omegaconf_resolvers() -> None +``` + + + + + + +Register shared OmegaConf resolvers used in configs. + + + + + + + + +```python +nemo_rl.utils.config.resolve_path( + base_path: pathlib.Path, + path: str +) -> pathlib.Path +``` + + + + + + +Resolve a path relative to the base path. + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_formulas.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_formulas.mdx new file mode 100644 index 0000000..a0ee8f2 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_formulas.mdx @@ -0,0 +1,501 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/flops_formulas +title: nemo_rl.utils.flops_formulas +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`FLOPSConfig`](#nemo_rl-utils-flops_formulas-FLOPSConfig) | Contains the model hparams needed for FLOPS computations. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_hybrid_model_flops`](#nemo_rl-utils-flops_formulas-_hybrid_model_flops) | Model FLOPs for hybrid model. | +| [`_mamba_layer_flops`](#nemo_rl-utils-flops_formulas-_mamba_layer_flops) | Model FLOPs for Mamba layer. We ignore part of the flops of scan because the chunk size is not known from model config. | +| [`_mlp_layer_flops`](#nemo_rl-utils-flops_formulas-_mlp_layer_flops) | Model FLOPs for MLP layer. | +| [`_non_mla_attn_layer_flops`](#nemo_rl-utils-flops_formulas-_non_mla_attn_layer_flops) | Model FLOPs for attention layer. | +| [`bert`](#nemo_rl-utils-flops_formulas-bert) | Model FLOPs for BERT family. | +| [`deepseekv3`](#nemo_rl-utils-flops_formulas-deepseekv3) | Model FLOPs for DeepSeek V3. | +| [`flux`](#nemo_rl-utils-flops_formulas-flux) | Model FLOPs for FLUX. | +| [`gpt3`](#nemo_rl-utils-flops_formulas-gpt3) | Model FLOPs for GPT3 family. | +| [`llama`](#nemo_rl-utils-flops_formulas-llama) | Model FLOPs for llama3 family. | +| [`mixtral`](#nemo_rl-utils-flops_formulas-mixtral) | Model FLOPs for mixtral family. | +| [`nemotron`](#nemo_rl-utils-flops_formulas-nemotron) | Model FLOPs for nemotron family. | +| [`nemotronh`](#nemo_rl-utils-flops_formulas-nemotronh) | Model FLOPs for NemotronH. | +| [`qwen2`](#nemo_rl-utils-flops_formulas-qwen2) | Model FLOPs for Qwen2 family. | +| [`qwen3`](#nemo_rl-utils-flops_formulas-qwen3) | Model FLOPs for Qwen3 family. | +| [`transformer`](#nemo_rl-utils-flops_formulas-transformer) | Calculate FLOPs for a standard Transformer model. | + +### API + + + + + +```python +class nemo_rl.utils.flops_formulas.FLOPSConfig( + gbs: int, + enc_seq_len: typing.Optional[int] = None, + hs: typing.Optional[int] = None, + layers: typing.Optional[int] = None, + ffn_hs: typing.Optional[int] = None, + attention_heads: typing.Optional[int] = None, + moe_router_topk: typing.Optional[int] = None, + query_groups: typing.Optional[int] = None, + img_seq_len: typing.Optional[int] = None, + img_h: typing.Optional[int] = None, + img_w: typing.Optional[int] = None, + in_channels: typing.Optional[int] = None, + patch_dim: typing.Optional[int] = None, + class_token_len: typing.Optional[int] = None, + projector_type: typing.Optional[str] = None, + inp_s: typing.Optional[int] = None, + model_pattern: typing.Optional[str] = None, + vocab_size: typing.Optional[int] = None, + model_channels: typing.Optional[int] = None, + vec_in_dim: typing.Optional[int] = None, + q_lora_rank: typing.Optional[int] = None, + kv_lora_rank: typing.Optional[int] = None, + qk_head_dim: typing.Optional[int] = None, + qk_pos_emb_head_dim: typing.Optional[int] = None, + v_head_dim: typing.Optional[int] = None, + moe_layer_freq: typing.Optional[typing.Union[int, typing.List[int]]] = None, + moe_shared_expert_intermediate_size: typing.Optional[int] = None, + moe_ffn_hidden_size: typing.Optional[int] = None, + mtp_num_layers: typing.Optional[int] = None, + causal_self_attn: typing.Optional[bool] = None, + is_hybrid_model: bool = False, + hybrid_override_pattern: typing.Optional[str] = None, + mamba_state_dim: typing.Optional[int] = None, + mamba_head_dim: typing.Optional[int] = None, + mamba_num_groups: typing.Optional[int] = None, + mamba_num_heads: typing.Optional[int] = None +) +``` + + + + + + +Dataclass + +Contains the model hparams needed for FLOPS computations. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.utils.flops_formulas._hybrid_model_flops( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for hybrid model. + + + + + + + + +```python +nemo_rl.utils.flops_formulas._mamba_layer_flops( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for Mamba layer. We ignore part of the flops of scan because the chunk size is not known from model config. + + + + + + + + +```python +nemo_rl.utils.flops_formulas._mlp_layer_flops( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for MLP layer. + + + + + + + + +```python +nemo_rl.utils.flops_formulas._non_mla_attn_layer_flops( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for attention layer. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.bert( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for BERT family. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.deepseekv3( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for DeepSeek V3. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.flux( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for FLUX. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.gpt3( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for GPT3 family. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.llama( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for llama3 family. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.mixtral( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for mixtral family. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.nemotron( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for nemotron family. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.nemotronh( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for NemotronH. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.qwen2( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for Qwen2 family. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.qwen3( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Model FLOPs for Qwen3 family. + + + + + + + + +```python +nemo_rl.utils.flops_formulas.transformer( + config: nemo_rl.utils.flops_formulas.FLOPSConfig +) +``` + + + + + + +Calculate FLOPs for a standard Transformer model. + +Note: This does not cover encoder-decoder models. + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_tracker.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_tracker.mdx new file mode 100644 index 0000000..1965cd5 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/flops_tracker.mdx @@ -0,0 +1,215 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/flops_tracker +title: nemo_rl.utils.flops_tracker +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`FLOPTracker`](#nemo_rl-utils-flops_tracker-FLOPTracker) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`convert_config_to_flops_config`](#nemo_rl-utils-flops_tracker-convert_config_to_flops_config) | Convert a pretrained config to a tuple containing a FLOPSConfig and a flops formula. | +| [`get_default_hf_config`](#nemo_rl-utils-flops_tracker-get_default_hf_config) | Get the default Hugging Face config for a model. | +| [`get_theoretical_tflops`](#nemo_rl-utils-flops_tracker-get_theoretical_tflops) | Get the theoretical total flops for a device name. | +| [`is_using_tf32`](#nemo_rl-utils-flops_tracker-is_using_tf32) | Check if the current device is using TF32. | + +### Data + +[`THEORETICAL_TFLOPS`](#nemo_rl-utils-flops_tracker-THEORETICAL_TFLOPS) + +### API + + + + + +```python +class nemo_rl.utils.flops_tracker.FLOPTracker( + model_name: str, + base_config: nemo_rl.utils.flops_formulas.FLOPSConfig | None = None, + flops_formula: typing.Callable[[FLOPSConfig], float] | None = None +) +``` + + + + + + + + + + + + +```python +nemo_rl.utils.flops_tracker.FLOPTracker.from_config( + model_name: str, + config: transformers.configuration_utils.PretrainedConfig +) -> nemo_rl.utils.flops_tracker.FLOPTracker +``` + + + + + + +classmethod + + + + + + + +```python +nemo_rl.utils.flops_tracker.FLOPTracker.reset() +``` + + + + + + + + + + + + +```python +nemo_rl.utils.flops_tracker.FLOPTracker.track( + n_samples: int, + padded_seq_len: int +) +``` + + + + + + + + + + + + +```python +nemo_rl.utils.flops_tracker.FLOPTracker.track_batch( + sequence_lengths: list[int] +) +``` + + + + + + +Track the flops for a batch of sequences. + + + + + + + + + +```python +nemo_rl.utils.flops_tracker.convert_config_to_flops_config( + config: transformers.configuration_utils.PretrainedConfig +) -> tuple[nemo_rl.utils.flops_formulas.FLOPSConfig, typing.Callable] +``` + + + + + + +Convert a pretrained config to a tuple containing a FLOPSConfig and a flops formula. + + + + + + + + +```python +nemo_rl.utils.flops_tracker.get_default_hf_config( + model_name: str +) -> transformers.configuration_utils.PretrainedConfig +``` + + + + + + +Get the default Hugging Face config for a model. + +Both the DTensor and MCore paths use the same default config, we initialize the model config +here to allow computation of theoretical flops which is agnostic to the backend. + + + + + + + + +```python +nemo_rl.utils.flops_tracker.get_theoretical_tflops( + device_name: str, + model_dtype: torch.dtype +) -> float +``` + + + + + + +Get the theoretical total flops for a device name. + + + + + + + + +```python +nemo_rl.utils.flops_tracker.is_using_tf32() -> bool +``` + + + + + + +Check if the current device is using TF32. + + + + + + + + +```python +nemo_rl.utils.flops_tracker.THEORETICAL_TFLOPS = {('NVIDIA A100 80GB PCIe', torch.bfloat16): 624 / 2, ('NVIDIA A100 80GB PCIe', t... +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/logger.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/logger.mdx new file mode 100644 index 0000000..b78a4b3 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/logger.mdx @@ -0,0 +1,1856 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/logger +title: nemo_rl.utils.logger +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`GPUMonitoringConfig`](#nemo_rl-utils-logger-GPUMonitoringConfig) | - | +| [`GpuMetricSnapshot`](#nemo_rl-utils-logger-GpuMetricSnapshot) | - | +| [`Logger`](#nemo_rl-utils-logger-Logger) | Main logger class that delegates to multiple backend loggers. | +| [`LoggerConfig`](#nemo_rl-utils-logger-LoggerConfig) | - | +| [`LoggerInterface`](#nemo_rl-utils-logger-LoggerInterface) | Abstract base class for logger backends. | +| [`MLflowConfig`](#nemo_rl-utils-logger-MLflowConfig) | - | +| [`MLflowLogger`](#nemo_rl-utils-logger-MLflowLogger) | MLflow logger backend. | +| [`RayGpuMonitorLogger`](#nemo_rl-utils-logger-RayGpuMonitorLogger) | Monitor GPU utilization across a Ray cluster and log metrics to a parent logger. | +| [`SwanlabConfig`](#nemo_rl-utils-logger-SwanlabConfig) | - | +| [`SwanlabLogger`](#nemo_rl-utils-logger-SwanlabLogger) | SwanLab logger backend. | +| [`TensorboardConfig`](#nemo_rl-utils-logger-TensorboardConfig) | - | +| [`TensorboardLogger`](#nemo_rl-utils-logger-TensorboardLogger) | Tensorboard logger backend. | +| [`WandbConfig`](#nemo_rl-utils-logger-WandbConfig) | - | +| [`WandbLogger`](#nemo_rl-utils-logger-WandbLogger) | Weights & Biases logger backend. | + +### Functions + +| Name | Description | +|------|-------------| +| [`configure_rich_logging`](#nemo_rl-utils-logger-configure_rich_logging) | Configure rich logging for more visually appealing log output. | +| [`flatten_dict`](#nemo_rl-utils-logger-flatten_dict) | Flatten a nested dictionary. | +| [`get_next_experiment_dir`](#nemo_rl-utils-logger-get_next_experiment_dir) | Create a new experiment directory with an incremented ID. | +| [`print_message_log_samples`](#nemo_rl-utils-logger-print_message_log_samples) | Visualization for message logs and rewards using a more visual approach with emoji indicators and horizontal layout. | + +### Data + +[`_rich_logging_configured`](#nemo_rl-utils-logger-_rich_logging_configured) + +### API + + + + + +```python +class nemo_rl.utils.logger.GPUMonitoringConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.utils.logger.GpuMetricSnapshot +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.utils.logger.Logger( + cfg: nemo_rl.utils.logger.LoggerConfig +) +``` + + + + + + +**Bases:** [LoggerInterface](#nemo_rl-utils-logger-LoggerInterface) + +Main logger class that delegates to multiple backend loggers. + + + + + + + + + + + +```python +nemo_rl.utils.logger.Logger.__del__() -> None +``` + + + + + + +Clean up resources when the logger is destroyed. + + + + + + + +```python +nemo_rl.utils.logger.Logger.log_batched_dict_as_jsonl( + to_log: nemo_rl.distributed.batched_data_dict.BatchedDataDict[typing.Any] | dict[str, typing.Any], + filename: str +) -> None +``` + + + + + + +Log a list of dictionaries to a JSONL file. + +**Parameters:** + + +BatchedDataDict to log + + + +Filename to log to (within the log directory) + + + + + + + + +```python +nemo_rl.utils.logger.Logger.log_histogram( + histogram: list[typing.Any], + step: int, + name: str +) -> None +``` + + + + + + +Log histogram metrics to all backends if available. + +**Parameters:** + + +List of histogram values + + + +Global step value + + + +Name of the metric + + + + + + + + +```python +nemo_rl.utils.logger.Logger.log_hyperparams( + params: typing.Mapping[str, typing.Any] +) -> None +``` + + + + + + +Log hyperparameters to all enabled backends. + +**Parameters:** + + +Dict of hyperparameters to log + + + + + + + + +```python +nemo_rl.utils.logger.Logger.log_metrics( + metrics: dict[str, typing.Any], + step: int, + prefix: typing.Optional[str] = '', + step_metric: typing.Optional[str] = None, + step_finished: bool = False +) -> None +``` + + + + + + +Log metrics to all enabled backends. + +**Parameters:** + + +Dict of metrics to log + + + +Global step value + + + +Optional prefix for metric names + + + +Optional name of a field in metrics to use as step instead + of the provided step value (currently only needed for wandb) + + + + + + + + +```python +nemo_rl.utils.logger.Logger.log_plot( + figure: matplotlib.pyplot.Figure, + step: int, + name: str +) -> None +``` + + + + + + +Log a matplotlib figure to all backends. + +**Parameters:** + + +Matplotlib figure to log + + + +Global step value + + + +Name of the plot + + + + + + + + +```python +nemo_rl.utils.logger.Logger.log_plot_per_worker_timeline_metrics( + metrics: dict[int, list[typing.Any]], + step: int, + prefix: str, + name: str, + timeline_interval: float +) -> None +``` + + + + + + +Log a plot of per-worker timeline metrics. + +**Parameters:** + + +Dictionary of metrics to log, where the keys are the worker IDs and the values are the lists of metric values + + + +dict[str, list[Any]] = {worker_id: [metric_value_1, metric_value_2, ...]} + + + +Global step value + + + +Name of the plot + + + +Interval between timeline points (in seconds) + + + + + + + + +```python +nemo_rl.utils.logger.Logger.log_plot_token_mult_prob_error( + data: dict[str, typing.Any], + step: int, + name: str +) -> None +``` + + + + + + +Log a plot of log probability errors in samples. + +This function logs & plots the per-token log-probabilities and errors over the sequence +for the sample with the highest multiplicative probability error in the batch. + +**Parameters:** + + +Dictionary of log probability samples + + + +Global step value + + + +Name of the plot + + + + + + + + +```python +nemo_rl.utils.logger.Logger.log_string_list_as_jsonl( + to_log: list[str], + filename: str +) -> None +``` + + + + + + +Log a list of strings to a JSONL file. + +**Parameters:** + + +list of strings to log + + + +Filename to log to (within the log directory) + + + + + + + + + + +```python +class nemo_rl.utils.logger.LoggerConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.utils.logger.LoggerInterface() +``` + + + + + + +Abstract + +Abstract base class for logger backends. + + + + + + +```python +nemo_rl.utils.logger.LoggerInterface.log_histogram( + histogram: list[typing.Any], + step: int, + name: str +) -> None +``` + + + + + + +abstract + +Log histogram metrics. + + + + + + + +```python +nemo_rl.utils.logger.LoggerInterface.log_hyperparams( + params: typing.Mapping[str, typing.Any] +) -> None +``` + + + + + + +abstract + +Log dictionary of hyperparameters. + + + + + + + +```python +nemo_rl.utils.logger.LoggerInterface.log_metrics( + metrics: dict[str, typing.Any], + step: int, + prefix: typing.Optional[str] = '', + step_metric: typing.Optional[str] = None, + step_finished: bool = False +) -> None +``` + + + + + + +abstract + +Log a dictionary of metrics. + + + + + + + +```python +nemo_rl.utils.logger.LoggerInterface.log_plot( + figure: matplotlib.pyplot.Figure, + step: int, + name: str +) -> None +``` + + + + + + +abstract + +Log a matplotlib figure. + + + + + + + + + +```python +class nemo_rl.utils.logger.MLflowConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +class nemo_rl.utils.logger.MLflowLogger( + cfg: nemo_rl.utils.logger.MLflowConfig, + log_dir: typing.Optional[str] = None +) +``` + + + + + + +**Bases:** [LoggerInterface](#nemo_rl-utils-logger-LoggerInterface) + +MLflow logger backend. + + + + + + + + +```python +nemo_rl.utils.logger.MLflowLogger.__del__() -> None +``` + + + + + + +Clean up resources when the logger is destroyed. + + + + + + + +```python +nemo_rl.utils.logger.MLflowLogger.log_histogram( + histogram: list[typing.Any], + step: int, + name: str +) -> None +``` + + + + + + +Log histogram metrics to MLflow. + + + + + + + +```python +nemo_rl.utils.logger.MLflowLogger.log_hyperparams( + params: typing.Mapping[str, typing.Any] +) -> None +``` + + + + + + +Log hyperparameters to MLflow. + +**Parameters:** + + +Dictionary of hyperparameters to log + + + + + + + + +```python +nemo_rl.utils.logger.MLflowLogger.log_metrics( + metrics: dict[str, typing.Any], + step: int, + prefix: typing.Optional[str] = '', + step_metric: typing.Optional[str] = None, + step_finished: bool = False +) -> None +``` + + + + + + +Log metrics to MLflow. + +**Parameters:** + + +Dict of metrics to log + + + +Global step value + + + +Optional prefix for metric names + + + +Optional step metric name (ignored in MLflow) + + + + + + + + +```python +nemo_rl.utils.logger.MLflowLogger.log_plot( + figure: matplotlib.pyplot.Figure, + step: int, + name: str +) -> None +``` + + + + + + +Log a plot to MLflow. + +**Parameters:** + + +Matplotlib figure to log + + + +Global step value + + + +Name of the plot + + + + + + + + + + +```python +class nemo_rl.utils.logger.RayGpuMonitorLogger( + collection_interval: int | float, + flush_interval: int | float, + metric_prefix: str, + step_metric: str, + parent_logger: typing.Optional[nemo_rl.utils.logger.Logger] = None +) +``` + + + + + + +Monitor GPU utilization across a Ray cluster and log metrics to a parent logger. + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.utils.logger.RayGpuMonitorLogger._collect( + metrics: bool = False, + sku: bool = False +) -> dict[str, typing.Any] +``` + + + + + + +Collect GPU metrics from all Ray nodes. + +**Returns:** `dict[str, Any]` + +Dictionary of collected metrics + + + + + + + +```python +nemo_rl.utils.logger.RayGpuMonitorLogger._collect_gpu_sku() -> dict[str, str] +``` + + + + + + +Collect GPU SKU from all Ray nodes. + +Note: This is an internal API and users are not expected to call this. + +**Returns:** `dict[str, str]` + +Dictionary of SKU types on all Ray nodes + + + + + + + +```python +nemo_rl.utils.logger.RayGpuMonitorLogger._collect_metrics() -> dict[str, typing.Any] +``` + + + + + + +Collect GPU metrics from all Ray nodes. + +**Returns:** `dict[str, Any]` + +Dictionary of collected metrics + + + + + + + +```python +nemo_rl.utils.logger.RayGpuMonitorLogger._collection_loop() -> None +``` + + + + + + +Main collection loop that runs in a separate thread. + + + + + + + +```python +nemo_rl.utils.logger.RayGpuMonitorLogger._fetch_and_parse_metrics( + node_idx: int, + metric_address: str, + parser_fn: typing.Callable +) +``` + + + + + + +Fetch metrics from a node and parse GPU metrics. + +**Parameters:** + + +Index of the node + + + +Address of the metrics endpoint + + +**Returns:** + +Dictionary of GPU metrics + + + + + + + +```python +nemo_rl.utils.logger.RayGpuMonitorLogger._parse_gpu_sku( + sample: prometheus_client.samples.Sample, + node_idx: int +) -> dict[str, str] +``` + + + + + + +Parse a GPU metric sample into a standardized format. + +**Parameters:** + + +Prometheus metric sample + + + +Index of the node + + +**Returns:** `dict[str, str]` + +Dictionary with metric name and value + + + + + + + +```python +nemo_rl.utils.logger.RayGpuMonitorLogger._parse_metric( + sample: prometheus_client.samples.Sample, + node_idx: int +) -> dict[str, typing.Any] +``` + + + + + + +Parse a metric sample into a standardized format. + +**Parameters:** + + +Prometheus metric sample + + + +Index of the node + + +**Returns:** `dict[str, Any]` + +Dictionary with metric name and value + + + + + + + +```python +nemo_rl.utils.logger.RayGpuMonitorLogger.flush() -> None +``` + + + + + + +Flush collected metrics to the parent logger. + + + + + + + +```python +nemo_rl.utils.logger.RayGpuMonitorLogger.start() -> None +``` + + + + + + +Start the GPU monitoring thread. + + + + + + + +```python +nemo_rl.utils.logger.RayGpuMonitorLogger.stop() -> None +``` + + + + + + +Stop the GPU monitoring thread. + + + + + + + + + +```python +class nemo_rl.utils.logger.SwanlabConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.utils.logger.SwanlabLogger( + cfg: nemo_rl.utils.logger.SwanlabConfig, + log_dir: typing.Optional[str] = None +) +``` + + + + + + +**Bases:** [LoggerInterface](#nemo_rl-utils-logger-LoggerInterface) + +SwanLab logger backend. + + + + + + + + +```python +nemo_rl.utils.logger.SwanlabLogger.log_histogram( + histogram: list[typing.Any], + step: int, + name: str +) -> None +``` + + + + + + +Log histogram metrics to swanlab. + + + + + + + +```python +nemo_rl.utils.logger.SwanlabLogger.log_hyperparams( + params: typing.Mapping[str, typing.Any] +) -> None +``` + + + + + + +Update the Swanlab run configuration with the provided hyperparameters. + +**Parameters:** + + +Mapping of hyperparameter names to values to store in the run configuration. + + + + + + + + +```python +nemo_rl.utils.logger.SwanlabLogger.log_metrics( + metrics: dict[str, typing.Any], + step: int, + prefix: typing.Optional[str] = '', + step_metric: typing.Optional[str] = None, + step_finished: bool = False +) -> None +``` + + + + + + +Log metrics to the associated Swanlab run. + +**Parameters:** + + +Mapping of metric names to metric values. + + + +Global step value to associate with all logged metrics. + + + +Optional prefix applied to metric names; metric names equal to `step_metric` are not prefixed. + + + +Name of a metric that should be excluded from prefixing. + + + + + + + + +```python +nemo_rl.utils.logger.SwanlabLogger.log_plot( + figure: matplotlib.pyplot.Figure, + step: int, + name: str +) -> None +``` + + + + + + +Log a plot to swanlab. + +**Parameters:** + + +Matplotlib figure to log + + + +Global step value + + + + + + + + + + +```python +class nemo_rl.utils.logger.TensorboardConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + +```python +class nemo_rl.utils.logger.TensorboardLogger( + cfg: nemo_rl.utils.logger.TensorboardConfig, + log_dir: typing.Optional[str] = None +) +``` + + + + + + +**Bases:** [LoggerInterface](#nemo_rl-utils-logger-LoggerInterface) + +Tensorboard logger backend. + + + + + + + + +```python +nemo_rl.utils.logger.TensorboardLogger._coerce_to_scalar( + value: typing.Any +) -> int | float | bool | str | None +``` + + + + + + +staticmethod + +Coerce a value to a Python scalar for TensorBoard logging. + +Returns the coerced value, or None if it can't be converted to a scalar. + + + + + + + +```python +nemo_rl.utils.logger.TensorboardLogger.log_histogram( + histogram: list[typing.Any], + step: int, + name: str +) -> None +``` + + + + + + +Log histogram metrics to Tensorboard. + + + + + + + +```python +nemo_rl.utils.logger.TensorboardLogger.log_hyperparams( + params: typing.Mapping[str, typing.Any] +) -> None +``` + + + + + + +Log hyperparameters to Tensorboard. + +**Parameters:** + + +Dictionary of hyperparameters to log + + + + + + + + +```python +nemo_rl.utils.logger.TensorboardLogger.log_metrics( + metrics: dict[str, typing.Any], + step: int, + prefix: typing.Optional[str] = '', + step_metric: typing.Optional[str] = None, + step_finished: bool = False +) -> None +``` + + + + + + +Log metrics to Tensorboard. + +**Parameters:** + + +Dict of metrics to log + + + +Global step value + + + +Optional prefix for metric names + + + +Optional step metric name (ignored in TensorBoard) + + + + + + + + +```python +nemo_rl.utils.logger.TensorboardLogger.log_plot( + figure: matplotlib.pyplot.Figure, + step: int, + name: str +) -> None +``` + + + + + + +Log a plot to Tensorboard. + +**Parameters:** + + +Dictionary of plot data + + + +Global step value + + + + + + + + + + +```python +class nemo_rl.utils.logger.WandbConfig +``` + + + + + + +**Bases:** `typing.TypedDict` + + + + + + + + + + + + + + + +```python +class nemo_rl.utils.logger.WandbLogger( + cfg: nemo_rl.utils.logger.WandbConfig, + log_dir: typing.Optional[str] = None +) +``` + + + + + + +**Bases:** [LoggerInterface](#nemo_rl-utils-logger-LoggerInterface) + +Weights & Biases logger backend. + + + + + + + + +```python +nemo_rl.utils.logger.WandbLogger._log_code() +``` + + + + + + +Log code that is tracked by git to wandb. + +This function gets a list of all files tracked by git in the project root +and manually uploads them to the current wandb run as an artifact. + + + + + + + +```python +nemo_rl.utils.logger.WandbLogger._log_diffs() +``` + + + + + + +Log git diffs to wandb. + +This function captures and logs two types of diffs: +1. Uncommitted changes (working tree diff against HEAD) +2. All changes (including uncommitted) against the main branch + +Each diff is saved as a text file in a wandb artifact. + + + + + + + +```python +nemo_rl.utils.logger.WandbLogger.define_metric( + name: str, + step_metric: typing.Optional[str] = None +) -> None +``` + + + + + + +Define a metric with custom step metric. + +**Parameters:** + + +Name of the metric or pattern (e.g. 'ray/*') + + + +Optional name of the step metric to use + + + + + + + + +```python +nemo_rl.utils.logger.WandbLogger.log_histogram( + histogram: list[typing.Any], + step: int, + name: str +) -> None +``` + + + + + + +Log histogram metrics to wandb. + +**Parameters:** + + +List of histogram values + + + +Global step value + + + +Name of the metric + + + + + + + + +```python +nemo_rl.utils.logger.WandbLogger.log_hyperparams( + params: typing.Mapping[str, typing.Any] +) -> None +``` + + + + + + +Log hyperparameters to wandb. + +**Parameters:** + + +Dict of hyperparameters to log + + + + + + + + +```python +nemo_rl.utils.logger.WandbLogger.log_metrics( + metrics: dict[str, typing.Any], + step: int, + prefix: typing.Optional[str] = '', + step_metric: typing.Optional[str] = None, + step_finished: bool = False +) -> None +``` + + + + + + +Log metrics to wandb. + +**Parameters:** + + +Dict of metrics to log + + + +Global step value + + + +Optional prefix for metric names + + + +Optional name of a field in metrics to use as step instead + of the provided step value + + + + + + + + +```python +nemo_rl.utils.logger.WandbLogger.log_plot( + figure: matplotlib.pyplot.Figure, + step: int, + name: str +) -> None +``` + + + + + + +Log a plot to wandb. + +**Parameters:** + + +Matplotlib figure to log + + + +Global step value + + + + + + + + + + +```python +nemo_rl.utils.logger.configure_rich_logging( + level: str = 'INFO', + show_time: bool = True, + show_path: bool = True +) -> None +``` + + + + + + +Configure rich logging for more visually appealing log output. + +**Parameters:** + + +The logging level to use + + + +Whether to show timestamps in logs + + + +Whether to show file paths in logs + + + + + + + + + +```python +nemo_rl.utils.logger.flatten_dict( + d: typing.Mapping[str, typing.Any], + sep: str = '.' +) -> dict[str, typing.Any] +``` + + + + + + +Flatten a nested dictionary. + +Handles nested dictionaries and lists by creating keys with separators. +For lists, the index is used as part of the key. + +**Parameters:** + + +Dictionary to flatten + + + +Separator to use between nested keys + + +**Returns:** `dict[str, Any]` + +Flattened dictionary with compound keys + +**Examples:** + + + +```python +>>> from nemo_rl.utils.logger import flatten_dict +>>> flatten_dict({"a": 1, "b": {"c": 2}}) +{'a': 1, 'b.c': 2} + +>>> flatten_dict({"a": [1, 2], "b": {"c": [3, 4]}}) +{'a.0': 1, 'a.1': 2, 'b.c.0': 3, 'b.c.1': 4} + +>>> flatten_dict({"a": [{"b": 1}, {"c": 2}]}) +{'a.0.b': 1, 'a.1.c': 2} +``` + + + + + + + + + + +```python +nemo_rl.utils.logger.get_next_experiment_dir( + base_log_dir: str +) -> str +``` + + + + + + +Create a new experiment directory with an incremented ID. + +**Parameters:** + + +The base log directory path + + +**Returns:** `str` + +Path to the new experiment directory with incremented ID + + + + + + + + +```python +nemo_rl.utils.logger.print_message_log_samples( + message_logs: list[nemo_rl.data.interfaces.LLMMessageLogType], + rewards: list[float], + num_samples: int = 5, + step: int = 0 +) -> None +``` + + + + + + +Visualization for message logs and rewards using a more visual approach with emoji indicators and horizontal layout. + +**Parameters:** + + +List of message logs to sample from + + + +List of rewards corresponding to each message log + + + +Number of samples to display (default: 5) + + + +Current training step (for display purposes) + + + + + + + + + +```python +nemo_rl.utils.logger._rich_logging_configured = False +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/memory_tracker.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/memory_tracker.mdx new file mode 100644 index 0000000..e06cd16 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/memory_tracker.mdx @@ -0,0 +1,122 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/memory_tracker +title: nemo_rl.utils.memory_tracker +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`MemoryTracker`](#nemo_rl-utils-memory_tracker-MemoryTracker) | - | +| [`MemoryTrackerDataPoint`](#nemo_rl-utils-memory_tracker-MemoryTrackerDataPoint) | - | + +### API + + + + + +```python +class nemo_rl.utils.memory_tracker.MemoryTracker() +``` + + + + + + +**Bases:** `BaseModel` + + + + + + + +```python +nemo_rl.utils.memory_tracker.MemoryTracker.model_post_init( + context +) +``` + + + + + + + + + + + + +```python +nemo_rl.utils.memory_tracker.MemoryTracker.snapshot_start_of_stage( + new_stage: str, + all_current_variables: typing.List[str] +) -> None +``` + + + + + + + + + + + + + + +```python +class nemo_rl.utils.memory_tracker.MemoryTrackerDataPoint() +``` + + + + + + +**Bases:** `BaseModel` + + + + + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.utils.memory_tracker.MemoryTrackerDataPoint.get_snapshot_str() -> str +``` + + + + + + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/native_checkpoint.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/native_checkpoint.mdx new file mode 100644 index 0000000..073652c --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/native_checkpoint.mdx @@ -0,0 +1,351 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/native_checkpoint +title: nemo_rl.utils.native_checkpoint +--- + +Checkpoint management utilities for HF models. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ModelState`](#nemo_rl-utils-native_checkpoint-ModelState) | Helper class for tracking model state in distributed checkpointing. | +| [`OptimizerState`](#nemo_rl-utils-native_checkpoint-OptimizerState) | Helper class for tracking optimizer state in distributed checkpointing. | + +### Functions + +| Name | Description | +|------|-------------| +| [`convert_dcp_to_hf`](#nemo_rl-utils-native_checkpoint-convert_dcp_to_hf) | Convert a Torch DCP checkpoint to a Hugging Face checkpoint. | +| [`load_checkpoint`](#nemo_rl-utils-native_checkpoint-load_checkpoint) | Load a model weights and optionally optimizer state. | +| [`save_checkpoint`](#nemo_rl-utils-native_checkpoint-save_checkpoint) | Save a checkpoint of the model and optionally optimizer state. | + +### API + + + + + +```python +class nemo_rl.utils.native_checkpoint.ModelState( + model: torch.nn.Module +) +``` + + + + + + +**Bases:** `Stateful` + +Helper class for tracking model state in distributed checkpointing. + +This class is compliant with the Stateful protocol, allowing DCP to automatically +call state_dict/load_state_dict as needed in the dcp.save/load APIs. + +**Parameters:** + + +The PyTorch model to track. + + + + + + + +```python +nemo_rl.utils.native_checkpoint.ModelState.load_state_dict( + state_dict: dict[str, typing.Any] +) -> None +``` + + + + + + +Load the state dictionary into the model. + +**Parameters:** + + +State dictionary to load. + + + + + + + + +```python +nemo_rl.utils.native_checkpoint.ModelState.state_dict() -> dict[str, typing.Any] +``` + + + + + + +Get the model's state dictionary. + +**Returns:** `dict[str, Any]` + +Dictionary containing the model's state dict with CPU offloading enabled. + + + + + + + + + +```python +class nemo_rl.utils.native_checkpoint.OptimizerState( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: typing.Optional[typing.Any] = None +) +``` + + + + + + +**Bases:** `Stateful` + +Helper class for tracking optimizer state in distributed checkpointing. + +This class is compliant with the Stateful protocol, allowing DCP to automatically +call state_dict/load_state_dict as needed in the dcp.save/load APIs. + +**Parameters:** + + +The PyTorch model associated with the optimizer. + + + +The optimizer to track. + + + +Optional learning rate scheduler. + + + + + + + +```python +nemo_rl.utils.native_checkpoint.OptimizerState.load_state_dict( + state_dict: dict[str, typing.Any] +) -> None +``` + + + + + + +Load the state dictionaries into the optimizer and scheduler. + +**Parameters:** + + +State dictionary containing optimizer and scheduler states to load. + + + + + + + + +```python +nemo_rl.utils.native_checkpoint.OptimizerState.state_dict() -> dict[str, typing.Any] +``` + + + + + + +Get the optimizer and scheduler state dictionaries. + +**Returns:** `dict[str, Any]` + +Dictionary containing the optimizer and scheduler state dicts with CPU offloading enabled. + + + + + + + + + +```python +nemo_rl.utils.native_checkpoint.convert_dcp_to_hf( + dcp_ckpt_path: str, + hf_ckpt_path: str, + model_name_or_path: str, + tokenizer_name_or_path: str, + overwrite: bool = False, + hf_overrides: typing.Optional[dict[str, typing.Any]] = {} +) -> str +``` + + + + + + +Convert a Torch DCP checkpoint to a Hugging Face checkpoint. + +This is not an optimized utility. If checkpoint is too large, consider saving DCP during training +and using this utility to convert to HF format. + +**Parameters:** + + +Path to DCP checkpoint + + + +Path to save HF checkpoint + + + +Model name or path for config + + + +Tokenizer name or path. + Defaults to model_name_or_path if None. + + + +Whether to overwrite existing checkpoint. Defaults to False. + + +**Returns:** `str` + +Path to the saved HF checkpoint + +**Raises:** + +- `FileExistsError`: If HF checkpoint already exists and overwrite is False + + + + + + + + +```python +nemo_rl.utils.native_checkpoint.load_checkpoint( + model: torch.nn.Module, + weights_path: str, + optimizer: typing.Optional[torch.optim.Optimizer] = None, + scheduler: typing.Optional[typing.Any] = None, + optimizer_path: typing.Optional[str] = None +) -> None +``` + + + + + + +Load a model weights and optionally optimizer state. + +**Parameters:** + + +The PyTorch model whose weights to update + + + +Path to load model weights from + + + +Optional optimizer to load state into + + + +Optional scheduler to load state into + + + +Path to load optimizer state from (required if optimizer provided) + + + + + + + + + +```python +nemo_rl.utils.native_checkpoint.save_checkpoint( + model: torch.nn.Module, + weights_path: str, + optimizer: typing.Optional[torch.optim.Optimizer] = None, + scheduler: typing.Optional[typing.Any] = None, + optimizer_path: typing.Optional[str] = None, + tokenizer: typing.Optional[typing.Any] = None, + tokenizer_path: typing.Optional[str] = None +) -> None +``` + + + + + + +Save a checkpoint of the model and optionally optimizer state. + +**Parameters:** + + +The PyTorch model to save + + + +Path to save model weights + + + +Optional optimizer to save + + + +Optional scheduler to save + + + +Path to save optimizer state (required if optimizer provided) + + + +Optional tokenizer to save + + + +Path to save tokenizer state (required if tokenizer provided) + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/nsys.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/nsys.mdx new file mode 100644 index 0000000..1c32117 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/nsys.mdx @@ -0,0 +1,138 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/nsys +title: nemo_rl.utils.nsys +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`ProfilablePolicy`](#nemo_rl-utils-nsys-ProfilablePolicy) | - | + +### Functions + +| Name | Description | +|------|-------------| +| [`maybe_gpu_profile_step`](#nemo_rl-utils-nsys-maybe_gpu_profile_step) | - | +| [`wrap_with_nvtx_name`](#nemo_rl-utils-nsys-wrap_with_nvtx_name) | A decorator to wrap a function with an NVTX range with the given name. | + +### Data + +[`NRL_NSYS_PROFILE_STEP_RANGE`](#nemo_rl-utils-nsys-NRL_NSYS_PROFILE_STEP_RANGE) + +[`NRL_NSYS_WORKER_PATTERNS`](#nemo_rl-utils-nsys-NRL_NSYS_WORKER_PATTERNS) + +### API + + + + + +```python +class nemo_rl.utils.nsys.ProfilablePolicy() +``` + + + + + + +Protocol + + + + + +```python +nemo_rl.utils.nsys.ProfilablePolicy.start_gpu_profiling() -> None +``` + + + + + + + + + + + + +```python +nemo_rl.utils.nsys.ProfilablePolicy.stop_gpu_profiling() -> None +``` + + + + + + + + + + + + + + +```python +nemo_rl.utils.nsys.maybe_gpu_profile_step( + policy: nemo_rl.utils.nsys.ProfilablePolicy, + step: int +) +``` + + + + + + + + + + + + + +```python +nemo_rl.utils.nsys.wrap_with_nvtx_name( + name: str +) +``` + + + + + + +A decorator to wrap a function with an NVTX range with the given name. + + + + + + + + +```python +nemo_rl.utils.nsys.NRL_NSYS_PROFILE_STEP_RANGE = os.environ.get('NRL_NSYS_PROFILE_STEP_RANGE', '') +``` + + + + + + + + + +```python +nemo_rl.utils.nsys.NRL_NSYS_WORKER_PATTERNS = os.environ.get('NRL_NSYS_WORKER_PATTERNS', '') +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/nvml.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/nvml.mdx new file mode 100644 index 0000000..a8ada45 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/nvml.mdx @@ -0,0 +1,100 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/nvml +title: nemo_rl.utils.nvml +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`device_id_to_physical_device_id`](#nemo_rl-utils-nvml-device_id_to_physical_device_id) | Convert a logical device ID to a physical device ID considering CUDA_VISIBLE_DEVICES. | +| [`get_device_uuid`](#nemo_rl-utils-nvml-get_device_uuid) | Get the UUID of a CUDA device using NVML. | +| [`get_free_memory_bytes`](#nemo_rl-utils-nvml-get_free_memory_bytes) | Get the free memory of a CUDA device in bytes using NVML. | +| [`nvml_context`](#nemo_rl-utils-nvml-nvml_context) | Context manager for NVML initialization and shutdown. | + +### API + + + + + +```python +nemo_rl.utils.nvml.device_id_to_physical_device_id( + device_id: int +) -> int +``` + + + + + + +Convert a logical device ID to a physical device ID considering CUDA_VISIBLE_DEVICES. + + + + + + + + +```python +nemo_rl.utils.nvml.get_device_uuid( + device_idx: int +) -> str +``` + + + + + + +Get the UUID of a CUDA device using NVML. + + + + + + + + +```python +nemo_rl.utils.nvml.get_free_memory_bytes( + device_idx: int +) -> float +``` + + + + + + +Get the free memory of a CUDA device in bytes using NVML. + + + + + + + + +```python +nemo_rl.utils.nvml.nvml_context() -> typing.Generator[None, None, None] +``` + + + + + + +Context manager for NVML initialization and shutdown. + +**Raises:** + +- `RuntimeError`: If NVML initialization fails + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/packed_tensor.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/packed_tensor.mdx new file mode 100644 index 0000000..786e38c --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/packed_tensor.mdx @@ -0,0 +1,140 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/packed_tensor +title: nemo_rl.utils.packed_tensor +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`get_num_buffers`](#nemo_rl-utils-packed_tensor-get_num_buffers) | - | +| [`get_target_packed_tensor_size`](#nemo_rl-utils-packed_tensor-get_target_packed_tensor_size) | - | +| [`packed_broadcast_consumer`](#nemo_rl-utils-packed_tensor-packed_broadcast_consumer) | Consume a packed tensor and unpack it into a list of tensors. | +| [`packed_broadcast_producer`](#nemo_rl-utils-packed_tensor-packed_broadcast_producer) | Broadcast a list of tensors in a packed manner. | + +### API + + + + + +```python +nemo_rl.utils.packed_tensor.get_num_buffers() +``` + + + + + + + + + + + + + +```python +nemo_rl.utils.packed_tensor.get_target_packed_tensor_size() +``` + + + + + + + + + + + + + +```python +nemo_rl.utils.packed_tensor.packed_broadcast_consumer( + iterator, + group, + src, + post_unpack_func +) +``` + + + + + + +Consume a packed tensor and unpack it into a list of tensors. + +**Parameters:** + + +iterator of model parameters. Returns a tuple of (name, tensor) + + + +process group (vllm PyNcclCommunicator) + + + +source rank (0 in current implementation) + + + +function to apply to each tensor after unpacking + + +**Returns:** + +None + + + + + + + + +```python +nemo_rl.utils.packed_tensor.packed_broadcast_producer( + iterator, + group, + src, + post_iter_func +) +``` + + + + + + +Broadcast a list of tensors in a packed manner. + +**Parameters:** + + +iterator of model parameters. Returns a tuple of (name, tensor) + + + +process group (vllm PyNcclCommunicator) + + + +source rank (0 in current implementation) + + + +function to apply to each tensor before packing, should return a tensor + + +**Returns:** + +None + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/prefetch_venvs.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/prefetch_venvs.mdx new file mode 100644 index 0000000..b0fb043 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/prefetch_venvs.mdx @@ -0,0 +1,108 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/prefetch_venvs +title: nemo_rl.utils.prefetch_venvs +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`create_frozen_environment_symlinks`](#nemo_rl-utils-prefetch_venvs-create_frozen_environment_symlinks) | Create python-{ClassName} wrapper scripts in /usr/local/bin for frozen environment support. | +| [`prefetch_venvs`](#nemo_rl-utils-prefetch_venvs-prefetch_venvs) | Prefetch all virtual environments that will be used by workers. | + +### Data + +[`args`](#nemo_rl-utils-prefetch_venvs-args) + +[`parser`](#nemo_rl-utils-prefetch_venvs-parser) + +### API + + + + + +```python +nemo_rl.utils.prefetch_venvs.create_frozen_environment_symlinks( + venv_configs +) +``` + + + + + + +Create python-{ClassName} wrapper scripts in /usr/local/bin for frozen environment support. + +Only runs in container (when NRL_CONTAINER=1 is set). + +**Parameters:** + + +Dictionary mapping py_executable to list of actor FQNs + + + + + + + + + +```python +nemo_rl.utils.prefetch_venvs.prefetch_venvs( + filters = None, + negative_filters = None +) +``` + + + + + + +Prefetch all virtual environments that will be used by workers. + +**Parameters:** + + +List of strings to match against actor FQNs. If provided, only + actors whose FQN contains at least one of the filter strings will + be prefetched. If None, all venvs are prefetched. + + + +List of strings to exclude from prefetching. Actors whose + FQN contains any of these strings will be skipped. + + + + + + + + + +```python +nemo_rl.utils.prefetch_venvs.args = parser.parse_args() +``` + + + + + + + + + +```python +nemo_rl.utils.prefetch_venvs.parser = argparse.ArgumentParser(description='Prefetch virtual environments for Ray actor... +``` + + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/timer.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/timer.mdx new file mode 100644 index 0000000..13c7047 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/timer.mdx @@ -0,0 +1,441 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/timer +title: nemo_rl.utils.timer +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`TimeoutChecker`](#nemo_rl-utils-timer-TimeoutChecker) | - | +| [`Timer`](#nemo_rl-utils-timer-Timer) | A utility for timing code execution. | + +### Functions + +| Name | Description | +|------|-------------| +| [`convert_to_seconds`](#nemo_rl-utils-timer-convert_to_seconds) | Converts a time string in the format 'DD:HH:MM:SS' to total seconds. | + +### API + + + + + +```python +class nemo_rl.utils.timer.TimeoutChecker( + timeout: typing.Optional[str] = '00:03:45:00', + fit_last_save_time: bool = False +) +``` + + + + + + + + + + + + + + + + + + + + + +```python +nemo_rl.utils.timer.TimeoutChecker.check_save() +``` + + + + + + + + + + + + +```python +nemo_rl.utils.timer.TimeoutChecker.mark_iteration() +``` + + + + + + + + + + + + +```python +nemo_rl.utils.timer.TimeoutChecker.start_iterations() +``` + + + + + + + + + + + + + + +```python +class nemo_rl.utils.timer.Timer() +``` + + + + + + +A utility for timing code execution. + +Supports two usage patterns: +1. Explicit start/stop: timer.start("label"), timer.stop("label") +2. Context manager: with timer.time("label"): ... + +The timer keeps track of multiple timing measurements for each label, +and supports different reductions on these measurements (mean, median, +min, max, std dev). + +Example usage: + + +```python +timer = Timer() + +# Method 1: start/stop +timer.start("load_data") +data = load_data() +timer.stop("load_data") + +# Method 2: context manager +with timer.time("model_forward"): + model_outputs = model(inputs) + +# Multiple timing measurements for the same operation +for batch in dataloader: + with timer.time("model_forward_multiple"): + outputs = model(batch) + +# Get all times for one label +model_forward_times = timer.get_elapsed("model_forward_multiple") + +# Get reductions for one label +mean_forward_time = timer.reduce("model_forward_multiple") +max_forward_time = timer.reduce("model_forward_multiple", "max") +``` + + + + + + + + + + + + + + + + +```python +nemo_rl.utils.timer.Timer.get_elapsed( + label: str +) -> list[float] +``` + + + + + + +Get all elapsed time measurements for a specific label. + +**Parameters:** + + +The timing label to get elapsed times for + + +**Returns:** `list[float]` + +A list of all elapsed time measurements in seconds + +**Raises:** + +- `KeyError`: If the label doesn't exist + + + + + + + +```python +nemo_rl.utils.timer.Timer.get_latest_elapsed( + label: str +) -> float +``` + + + + + + +Get the most recent elapsed time measurement for a specific label. + +**Parameters:** + + +The timing label to get the latest elapsed time for + + +**Returns:** `float` + +The most recent elapsed time measurement in seconds + +**Raises:** + +- `KeyError`: If the label doesn't exist +- `IndexError`: If the label exists but has no measurements + + + + + + + +```python +nemo_rl.utils.timer.Timer.get_timing_metrics( + reduction_op: typing.Union[str, dict[str, str]] = 'mean' +) -> dict[str, float | list[float]] +``` + + + + + + +Get all timing measurements with optional reduction. + +**Parameters:** + + +Either a string specifying a reduction operation to apply to all labels, + or a dictionary mapping specific labels to reduction operations. + Valid reduction operations are: "mean", "median", "min", "max", "std", "sum", "count". + If a label is not in the dictionary, no reduction is applied and all measurements are returned. + + +**Returns:** `dict[str, float | list[float]]` + +A dictionary mapping labels to either: + +**Raises:** + +- `ValueError`: If an invalid reduction operation is provided + + + + + + + +```python +nemo_rl.utils.timer.Timer.reduce( + label: str, + operation: str = 'mean' +) -> float +``` + + + + + + +Apply a reduction function to timing measurements for the specified label. + +**Parameters:** + + +The timing label to get reduction for + + + +The type of reduction to apply. Valid options are: +- "mean": Average time (default) +- "median": Median time +- "min": Minimum time +- "max": Maximum time +- "std": Standard deviation +- "sum": Total time +- "count": Number of measurements + + +**Returns:** `float` + +A single float with the reduction result + +**Raises:** + +- `KeyError`: If the label doesn't exist +- `ValueError`: If an invalid operation is provided + + + + + + + +```python +nemo_rl.utils.timer.Timer.reset( + label: typing.Optional[str] = None +) -> None +``` + + + + + + +Reset timings for the specified label or all labels. + +**Parameters:** + + +Optional label to reset. If None, resets all timers. + + + + + + + + +```python +nemo_rl.utils.timer.Timer.start( + label: str +) -> None +``` + + + + + + +Start timing for the given label. + + + + + + + +```python +nemo_rl.utils.timer.Timer.stop( + label: str +) -> float +``` + + + + + + +Stop timing for the given label and return the elapsed time. + +**Parameters:** + + +The label to stop timing for + + +**Returns:** `float` + +The elapsed time in seconds + +**Raises:** + +- `ValueError`: If the timer for the given label is not running + + + + + + + +```python +nemo_rl.utils.timer.Timer.time( + label: str +) -> typing.Generator[None, None, None] +``` + + + + + + +Context manager for timing a block of code. + +**Parameters:** + + +The label to use for this timing + + + + + + + + + + +```python +nemo_rl.utils.timer.convert_to_seconds( + time_string: str +) -> int +``` + + + + + + +Converts a time string in the format 'DD:HH:MM:SS' to total seconds. + +**Parameters:** + + +Time duration string, e.g., '00:03:45:00'. + + +**Returns:** `int` + +Total time in seconds. + + + diff --git a/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/venvs.mdx b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/venvs.mdx new file mode 100644 index 0000000..f2e6973 --- /dev/null +++ b/fern/library-docs/nemo-rl-docs/nemo-rl/nemo_rl/utils/venvs.mdx @@ -0,0 +1,177 @@ +--- +layout: overview +slug: nemo-rl/nemo_rl/utils/venvs +title: nemo_rl.utils.venvs +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`_env_builder`](#nemo_rl-utils-venvs-_env_builder) | - | +| [`create_local_venv`](#nemo_rl-utils-venvs-create_local_venv) | Create a virtual environment using uv and execute a command within it. | +| [`create_local_venv_on_each_node`](#nemo_rl-utils-venvs-create_local_venv_on_each_node) | Create a virtual environment on each Ray node. | + +### Data + +[`DEFAULT_VENV_DIR`](#nemo_rl-utils-venvs-DEFAULT_VENV_DIR) + +[`dir_path`](#nemo_rl-utils-venvs-dir_path) + +[`git_root`](#nemo_rl-utils-venvs-git_root) + +[`logger`](#nemo_rl-utils-venvs-logger) + +### API + + + + + +```python +nemo_rl.utils.venvs._env_builder( + py_executable: str, + venv_name: str, + node_idx: int, + force_rebuild: bool = False +) +``` + + + + + + + + + + + + + +```python +nemo_rl.utils.venvs.create_local_venv( + py_executable: str, + venv_name: str, + force_rebuild: bool = False +) -> str +``` + + + + + + +Create a virtual environment using uv and execute a command within it. + +The output can be used as a py_executable for a Ray worker assuming the worker +nodes also have access to the same file system as the head node. + +This function is cached to avoid multiple calls to uv to create the same venv, +which avoids duplicate logging. + +**Parameters:** + + +Command to run with the virtual environment (e.g., "uv.sh run --locked") + + + +Name of the virtual environment (e.g., "foobar.Worker") + + + +If True, force rebuild the venv even if it already exists + + +**Returns:** `str` + +Path to the python executable in the created virtual environment + + + + + + + + +```python +nemo_rl.utils.venvs.create_local_venv_on_each_node( + py_executable: str, + venv_name: str +) +``` + + + + + + +Create a virtual environment on each Ray node. + +**Parameters:** + + +Command to run with the virtual environment + + + +Name of the virtual environment + + +**Returns:** + +Path to the python executable in the created virtual environment + + + + + + + + +```python +nemo_rl.utils.venvs.DEFAULT_VENV_DIR = os.path.join(git_root, 'venvs') +``` + + + + + + + + + +```python +nemo_rl.utils.venvs.dir_path = os.path.dirname(os.path.abspath(__file__)) +``` + + + + + + + + + +```python +nemo_rl.utils.venvs.git_root = os.path.abspath(os.path.join(dir_path, '../..')) +``` + + + + + + + + + +```python +nemo_rl.utils.venvs.logger = logging.getLogger(__name__) +``` + + + + diff --git a/fern/library-docs/ttl-docs/_navigation.yml b/fern/library-docs/ttl-docs/_navigation.yml new file mode 100644 index 0000000..e4a8b13 --- /dev/null +++ b/fern/library-docs/ttl-docs/_navigation.yml @@ -0,0 +1,149 @@ +# AUTO-GENERATED by `fern docs md generate` — DO NOT EDIT +- type: section + title: _mlir_libs + slug: ttl/ttl/_mlir_libs + children: + - type: section + title: _site_initialize_1 + slug: ttl/ttl/_mlir_libs/_site_initialize_1 + children: + - type: page + title: _site_initialize_1 + slug: ttl/ttl/_mlir_libs/_site_initialize_1 + pageId: ttl/ttl/_mlir_libs/_site_initialize_1.mdx +- type: section + title: _src + slug: ttl/ttl/_src + children: + - type: section + title: auto_profile + slug: ttl/ttl/_src/auto_profile + children: + - type: page + title: auto_profile + slug: ttl/ttl/_src/auto_profile + pageId: ttl/ttl/_src/auto_profile.mdx + - type: section + title: tensor_registry + slug: ttl/ttl/_src/tensor_registry + children: + - type: page + title: tensor_registry + slug: ttl/ttl/_src/tensor_registry + pageId: ttl/ttl/_src/tensor_registry.mdx + - type: section + title: ttl_ast + slug: ttl/ttl/_src/ttl_ast + children: + - type: page + title: ttl_ast + slug: ttl/ttl/_src/ttl_ast + pageId: ttl/ttl/_src/ttl_ast.mdx +- type: section + title: circular_buffer + slug: ttl/ttl/circular_buffer + children: + - type: page + title: circular_buffer + slug: ttl/ttl/circular_buffer + pageId: ttl/ttl/circular_buffer.mdx +- type: section + title: constants + slug: ttl/ttl/constants + children: + - type: page + title: constants + slug: ttl/ttl/constants + pageId: ttl/ttl/constants.mdx +- type: section + title: diagnostics + slug: ttl/ttl/diagnostics + children: + - type: page + title: diagnostics + slug: ttl/ttl/diagnostics + pageId: ttl/ttl/diagnostics.mdx +- type: section + title: dialects + slug: ttl/ttl/dialects + children: + - type: section + title: _ods_common + slug: ttl/ttl/dialects/_ods_common + children: + - type: page + title: _ods_common + slug: ttl/ttl/dialects/_ods_common + pageId: ttl/ttl/dialects/_ods_common.mdx + - type: section + title: ttl + slug: ttl/ttl/dialects/ttl + children: + - type: page + title: ttl + slug: ttl/ttl/dialects/ttl + pageId: ttl/ttl/dialects/ttl.mdx +- type: section + title: dtype_utils + slug: ttl/ttl/dtype_utils + children: + - type: page + title: dtype_utils + slug: ttl/ttl/dtype_utils + pageId: ttl/ttl/dtype_utils.mdx +- type: section + title: kernel_runner + slug: ttl/ttl/kernel_runner + children: + - type: page + title: kernel_runner + slug: ttl/ttl/kernel_runner + pageId: ttl/ttl/kernel_runner.mdx +- type: section + title: layouts + slug: ttl/ttl/layouts + children: + - type: page + title: layouts + slug: ttl/ttl/layouts + pageId: ttl/ttl/layouts.mdx +- type: section + title: operators + slug: ttl/ttl/operators + children: + - type: page + title: operators + slug: ttl/ttl/operators + pageId: ttl/ttl/operators.mdx +- type: section + title: ttl + slug: ttl/ttl/ttl + children: + - type: page + title: ttl + slug: ttl/ttl/ttl + pageId: ttl/ttl/ttl.mdx +- type: section + title: ttl_api + slug: ttl/ttl/ttl_api + children: + - type: page + title: ttl_api + slug: ttl/ttl/ttl_api + pageId: ttl/ttl/ttl_api.mdx +- type: section + title: ttl_math + slug: ttl/ttl/ttl_math + children: + - type: page + title: ttl_math + slug: ttl/ttl/ttl_math + pageId: ttl/ttl/ttl_math.mdx +- type: section + title: ttl_utils + slug: ttl/ttl/ttl_utils + children: + - type: page + title: ttl_utils + slug: ttl/ttl/ttl_utils + pageId: ttl/ttl/ttl_utils.mdx diff --git a/fern/library-docs/ttl-docs/ttl/ttl.mdx b/fern/library-docs/ttl-docs/ttl/ttl.mdx new file mode 100644 index 0000000..30235a7 --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl.mdx @@ -0,0 +1,60 @@ +--- +layout: overview +slug: ttl/ttl +title: ttl +--- + +## Subpackages + +- **[`ttl._mlir_libs`](/ttl/ttl/_mlir_libs)** +- **[`ttl._src`](/ttl/ttl/_src)** +- **[`ttl.dialects`](/ttl/ttl/dialects)** + +## Submodules + +- **[`ttl.circular_buffer`](/ttl/ttl/circular_buffer)** +- **[`ttl.constants`](/ttl/ttl/constants)** +- **[`ttl.diagnostics`](/ttl/ttl/diagnostics)** +- **[`ttl.dtype_utils`](/ttl/ttl/dtype_utils)** +- **[`ttl.ir`](/ttl/ttl/ir)** +- **[`ttl.kernel_runner`](/ttl/ttl/kernel_runner)** +- **[`ttl.layouts`](/ttl/ttl/layouts)** +- **[`ttl.operators`](/ttl/ttl/operators)** +- **[`ttl.ttl`](/ttl/ttl/ttl)** +- **[`ttl.ttl_api`](/ttl/ttl/ttl_api)** +- **[`ttl.ttl_math`](/ttl/ttl/ttl_math)** +- **[`ttl.ttl_utils`](/ttl/ttl/ttl_utils)** + +## Package Contents + +### Data + +[`__all__`](#ttl-__all__) + +[`__version__`](#ttl-__version__) + +### API + + + + + +```python +ttl.__all__ = ['kernel', 'compute', 'datamovement', 'Program', 'CircularBuffer', 'TensorBlock'... +``` + + + + + + + + + +```python +ttl.__version__ = '0.1.0' +``` + + + + diff --git a/fern/library-docs/ttl-docs/ttl/ttl/_mlir_libs.mdx b/fern/library-docs/ttl-docs/ttl/ttl/_mlir_libs.mdx new file mode 100644 index 0000000..7dd8cd3 --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl/_mlir_libs.mdx @@ -0,0 +1,9 @@ +--- +layout: overview +slug: ttl/ttl/_mlir_libs +title: ttl._mlir_libs +--- + +## Submodules + +- **[`ttl._mlir_libs._site_initialize_1`](/ttl/ttl/_mlir_libs/_site_initialize_1)** diff --git a/fern/library-docs/ttl-docs/ttl/ttl/_mlir_libs/_site_initialize_1.mdx b/fern/library-docs/ttl-docs/ttl/ttl/_mlir_libs/_site_initialize_1.mdx new file mode 100644 index 0000000..414f0c4 --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl/_mlir_libs/_site_initialize_1.mdx @@ -0,0 +1,35 @@ +--- +layout: overview +slug: ttl/ttl/_mlir_libs/_site_initialize_1 +title: ttl._mlir_libs._site_initialize_1 +--- + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`register_dialects`](#ttl-_mlir_libs-_site_initialize_1-register_dialects) | Called by MLIR site initialization to add TTL dialects to the registry. | + +### API + + + + + +```python +ttl._mlir_libs._site_initialize_1.register_dialects( + registry +) +``` + + + + + + +Called by MLIR site initialization to add TTL dialects to the registry. + + + diff --git a/fern/library-docs/ttl-docs/ttl/ttl/_src.mdx b/fern/library-docs/ttl-docs/ttl/ttl/_src.mdx new file mode 100644 index 0000000..20050a7 --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl/_src.mdx @@ -0,0 +1,11 @@ +--- +layout: overview +slug: ttl/ttl/_src +title: ttl._src +--- + +## Submodules + +- **[`ttl._src.auto_profile`](/ttl/ttl/_src/auto_profile)** +- **[`ttl._src.tensor_registry`](/ttl/ttl/_src/tensor_registry)** +- **[`ttl._src.ttl_ast`](/ttl/ttl/_src/ttl_ast)** diff --git a/fern/library-docs/ttl-docs/ttl/ttl/_src/auto_profile.mdx b/fern/library-docs/ttl-docs/ttl/ttl/_src/auto_profile.mdx new file mode 100644 index 0000000..6e2407b --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl/_src/auto_profile.mdx @@ -0,0 +1,479 @@ +--- +layout: overview +slug: ttl/ttl/_src/auto_profile +title: ttl._src.auto_profile +--- + +Auto-profiling infrastructure for tt-lang kernels. + +Enabled via TTLANG_AUTO_PROFILE=1 environment variable. +Automatically instruments every operation with signposts and generates +a visual profile report showing cycle counts per source line. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`Colors`](#ttl-_src-auto_profile-Colors) | ANSI color codes for terminal output. | +| [`ProfileResult`](#ttl-_src-auto_profile-ProfileResult) | Represents profiling results for a single signpost. | +| [`SourceLineMapper`](#ttl-_src-auto_profile-SourceLineMapper) | Maps signpost markers back to source code lines. | + +### Functions + +| Name | Description | +|------|-------------| +| [`build_cb_wait_to_dma_map`](#ttl-_src-auto_profile-build_cb_wait_to_dma_map) | Build mapping from cb_wait locations to DMA barrier locations. | +| [`build_dma_producer_to_cb_map`](#ttl-_src-auto_profile-build_dma_producer_to_cb_map) | Build mapping from DMA barrier locations to CB index. | +| [`generate_signpost_name`](#ttl-_src-auto_profile-generate_signpost_name) | Generate before/after signpost names for an operation. | +| [`get_line_mapper`](#ttl-_src-auto_profile-get_line_mapper) | Get the global line mapper instance. | +| [`is_auto_profile_enabled`](#ttl-_src-auto_profile-is_auto_profile_enabled) | Check if auto-profiling is enabled via environment variable. | +| [`load_cb_flow_graph`](#ttl-_src-auto_profile-load_cb_flow_graph) | Load CB flow graph JSON from same directory as CSV. | +| [`parse_device_profile_csv`](#ttl-_src-auto_profile-parse_device_profile_csv) | Parse the device profile CSV and extract signpost timing data. | +| [`parse_signpost_name`](#ttl-_src-auto_profile-parse_signpost_name) | Parse op name and implicit flag from signpost name. | +| [`print_profile_report`](#ttl-_src-auto_profile-print_profile_report) | Print a profile report organized by thread. | + +### Data + +[`_global_line_mapper`](#ttl-_src-auto_profile-_global_line_mapper) + +### API + + + + + +```python +class ttl._src.auto_profile.Colors() +``` + + + + + + +ANSI color codes for terminal output. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +ttl._src.auto_profile.Colors.cb_bg( + cb_index: int +) -> str +``` + + + + + + +classmethod + +Get background color for a CB index, or empty if out of range. + + + + + + + + + +```python +class ttl._src.auto_profile.ProfileResult( + signpost: str, + thread: str, + cycles: int, + lineno: int, + source: str +) +``` + + + + + + +Represents profiling results for a single signpost. + + + + + + + + + + +```python +class ttl._src.auto_profile.SourceLineMapper() +``` + + + + + + +Maps signpost markers back to source code lines. + + + + + + + + + + + + + + +```python +ttl._src.auto_profile.SourceLineMapper.get_line_info( + signpost_name: str +) -> typing.Optional[typing.Tuple[int, str]] +``` + + + + + + +Get line number and source for a signpost. + + + + + + + +```python +ttl._src.auto_profile.SourceLineMapper.register_signpost( + signpost_name: str, + lineno: int, + source: str +) +``` + + + + + + +Register a signpost with its source line information. + + + + + + + +```python +ttl._src.auto_profile.SourceLineMapper.set_source( + source_lines: typing.List[str] +) +``` + + + + + + +Set the source code lines for display. + + + + + + + + + +```python +ttl._src.auto_profile.build_cb_wait_to_dma_map( + cb_flow: typing.Optional[typing.Dict] +) -> typing.Dict[typing.Tuple[str, int], typing.Tuple[str, int, int]] +``` + + + + + + +Build mapping from cb_wait locations to DMA barrier locations. + +Only maps consumers waiting for DMA reads (data flowing into CB). +cb_wait ops waiting for compute output (where DMA is a write) are not mapped. + +**Returns:** `Dict[Tuple[str, int], Tuple[str, int, int]]` + +Dict mapping (kernel, line) of cb_wait -> (barrier_kernel, barrier_line, cb_index) + + + + + + + + +```python +ttl._src.auto_profile.build_dma_producer_to_cb_map( + cb_flow: typing.Optional[typing.Dict] +) -> typing.Dict[typing.Tuple[str, int], int] +``` + + + + + + +Build mapping from DMA barrier locations to CB index. + +**Returns:** `Dict[Tuple[str, int], int]` + +Dict mapping (kernel, line) of DMA read barrier -> cb_index + + + + + + + + +```python +ttl._src.auto_profile.generate_signpost_name( + operation: str, + lineno: int, + col: int +) -> typing.Tuple[str, str] +``` + + + + + + +Generate before/after signpost names for an operation. + +**Returns:** `Tuple[str, str]` + +Tuple of (before_name, after_name) + + + + + + + + +```python +ttl._src.auto_profile.get_line_mapper() -> ttl._src.auto_profile.SourceLineMapper +``` + + + + + + +Get the global line mapper instance. + + + + + + + + +```python +ttl._src.auto_profile.is_auto_profile_enabled() -> bool +``` + + + + + + +Check if auto-profiling is enabled via environment variable. + + + + + + + + +```python +ttl._src.auto_profile.load_cb_flow_graph( + csv_path: pathlib.Path +) -> typing.Optional[typing.Dict] +``` + + + + + + +Load CB flow graph JSON from same directory as CSV. + + + + + + + + +```python +ttl._src.auto_profile.parse_device_profile_csv( + csv_path: pathlib.Path, + line_mapper: ttl._src.auto_profile.SourceLineMapper +) -> typing.List[ttl._src.auto_profile.ProfileResult] +``` + + + + + + +Parse the device profile CSV and extract signpost timing data. + +**Parameters:** + + +Path to profile_log_device.csv + + + +Mapper to correlate signposts to source lines + + +**Returns:** `List[ProfileResult]` + +List of ProfileResult objects sorted by line number + + + + + + + + +```python +ttl._src.auto_profile.parse_signpost_name( + signpost: str +) -> typing.Tuple[typing.Optional[str], bool] +``` + + + + + + +Parse op name and implicit flag from signpost name. + +Returns (op_name, is_implicit) where op_name is None for line-only signposts. +Examples: + "line_52_before" -> (None, False) + "line_52_cb_wait_before" -> ("cb_wait", False) + "line_52_implicit_cb_pop_before" -> ("cb_pop", True) + + + + + + + + +```python +ttl._src.auto_profile.print_profile_report( + results: typing.List[ttl._src.auto_profile.ProfileResult], + all_source_lines: typing.Dict[str, typing.List[str]], + thread_to_kernel: typing.Dict[str, str], + line_mapper: typing.Optional[ttl._src.auto_profile.SourceLineMapper] = None, + cb_wait_to_dma: typing.Optional[typing.Dict[typing.Tuple[str, int], typing.Tuple[str, int, int]]] = None, + dma_producer_to_cb: typing.Optional[typing.Dict[typing.Tuple[str, int], int]] = None, + kernel_line_offsets: typing.Optional[typing.Dict[str, int]] = None +) +``` + + + + + + +Print a profile report organized by thread. + +Shows full source context with cycle annotations where available. +Each thread displays its corresponding kernel's source code. + +**Parameters:** + + +List of ProfileResult from CSV parsing + + + +Dict mapping kernel name to source lines + + + +Dict mapping RISC thread name to kernel name + + + +Optional SourceLineMapper with line offset info + + + +Optional mapping from (kernel, line) -> (dma_kernel, dma_line, cb_index) + + + +Optional mapping from (kernel, line) -> cb_index for DMA producers + + + +Optional mapping from kernel name to line offset + + + + + + + + + +```python +ttl._src.auto_profile._global_line_mapper = SourceLineMapper() +``` + + + + diff --git a/fern/library-docs/ttl-docs/ttl/ttl/_src/tensor_registry.mdx b/fern/library-docs/ttl-docs/ttl/ttl/_src/tensor_registry.mdx new file mode 100644 index 0000000..0c2cdf6 --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl/_src/tensor_registry.mdx @@ -0,0 +1,169 @@ +--- +layout: overview +slug: ttl/ttl/_src/tensor_registry +title: ttl._src.tensor_registry +--- + +Registry for tensor global names, used to track tensor parameter names. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`get_tensor_global_index`](#ttl-_src-tensor_registry-get_tensor_global_index) | Get the global index for a tensor. | +| [`get_tensor_global_name`](#ttl-_src-tensor_registry-get_tensor_global_name) | Get the global name for a tensor, checking registry first then attribute. | +| [`get_tensor_source`](#ttl-_src-tensor_registry-get_tensor_source) | Get the source location where a tensor was assigned, if tracked. | +| [`register_tensor_name`](#ttl-_src-tensor_registry-register_tensor_name) | Register a global name and index for a tensor. | +| [`register_tensor_source`](#ttl-_src-tensor_registry-register_tensor_source) | Register the source location where a tensor variable was assigned. | + +### Data + +[`_tensor_index_registry`](#ttl-_src-tensor_registry-_tensor_index_registry) + +[`_tensor_name_registry`](#ttl-_src-tensor_registry-_tensor_name_registry) + +[`_tensor_source_registry`](#ttl-_src-tensor_registry-_tensor_source_registry) + +### API + + + + + +```python +ttl._src.tensor_registry.get_tensor_global_index( + tensor +) -> int +``` + + + + + + +Get the global index for a tensor. + + + + + + + + +```python +ttl._src.tensor_registry.get_tensor_global_name( + tensor +) -> str +``` + + + + + + +Get the global name for a tensor, checking registry first then attribute. + + + + + + + + +```python +ttl._src.tensor_registry.get_tensor_source( + tensor +) -> typing.Optional[typing.Tuple[str, int]] +``` + + + + + + +Get the source location where a tensor was assigned, if tracked. + + + + + + + + +```python +ttl._src.tensor_registry.register_tensor_name( + tensor, + name: str, + index: int = -1 +) -> None +``` + + + + + + +Register a global name and index for a tensor. + + + + + + + + +```python +ttl._src.tensor_registry.register_tensor_source( + tensor, + source_file: str, + line: int +) -> None +``` + + + + + + +Register the source location where a tensor variable was assigned. + + + + + + + + +```python +ttl._src.tensor_registry._tensor_index_registry: Dict[int, int] = {} +``` + + + + + + + + + +```python +ttl._src.tensor_registry._tensor_name_registry: Dict[int, str] = {} +``` + + + + + + + + + +```python +ttl._src.tensor_registry._tensor_source_registry: Dict[int, Tuple[str, int]] = {} +``` + + + + diff --git a/fern/library-docs/ttl-docs/ttl/ttl/_src/ttl_ast.mdx b/fern/library-docs/ttl-docs/ttl/ttl/_src/ttl_ast.mdx new file mode 100644 index 0000000..cc33586 --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl/_src/ttl_ast.mdx @@ -0,0 +1,731 @@ +--- +layout: overview +slug: ttl/ttl/_src/ttl_ast +title: ttl._src.ttl_ast +--- + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`CompilerContext`](#ttl-_src-ttl_ast-CompilerContext) | Immutable compilation context for TTL kernels. | +| [`TTLGenericCompiler`](#ttl-_src-ttl_ast-TTLGenericCompiler) | Compiler that generates TTL dialect ops from Python AST. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_build_tensor_type`](#ttl-_src-ttl_ast-_build_tensor_type) | Build MLIR tensor type for a ttnn tensor with TTNNLayoutAttr. | +| [`_get_annotation_name`](#ttl-_src-ttl_ast-_get_annotation_name) | Extract the type name from an annotation node. | +| [`_make_file_loc`](#ttl-_src-ttl_ast-_make_file_loc) | Create an MLIR file location from an AST node. | +| [`_raise_tensor_error`](#ttl-_src-ttl_ast-_raise_tensor_error) | Raise TTLangCompileError with tensor source location if available. | +| [`syntax`](#ttl-_src-ttl_ast-syntax) | - | + +### API + + + + + +```python +class ttl._src.ttl_ast.CompilerContext( + grid: typing.List[int], + memory_space: str, + tiled: bool +) +``` + + + + + + +Dataclass + +Immutable compilation context for TTL kernels. + + + + + + + + + + + + + + + + +```python +class ttl._src.ttl_ast.TTLGenericCompiler( + name, + kernel_type = None, + captures = {}, + args = (), + kwargs = {} +) +``` + + + + + + +**Bases:** `TTCompilerBase` + +Compiler that generates TTL dialect ops from Python AST. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._build_index_or_range( + node +) +``` + + + + + + +Convert AST node to (start_value, is_range) tuple. + +For slice syntax (start:end), returns (start_value, True). +For index syntax (value), returns (value, False). + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._close_final_signpost() +``` + + + + + + +Close the final signpost at the end of function body. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._emit_cb_from_capture( + cb +) +``` + + + + + + +Emit ttl.bind_cb for a captured CircularBuffer instance. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._emit_entry( + node +) +``` + + + + + + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._emit_line_signpost_if_needed( + node +) +``` + + + + + + +Emit signposts at line boundaries for auto-profiling. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._emit_op_signposts( + op_name: str, + node, + op_fn, + implicit = False +) +``` + + + + + + +Emit signposts for CB operations with op name included. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._emit_signpost( + name: str +) +``` + + + + + + +Emit a signpost operation into the MLIR. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._get_cb_tensor_type( + cb_val, + node = None +) +``` + + + + + + +Extract the tensor type from a TTL CB type. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._is_ttl_math_access( + node +) +``` + + + + + + +Check if node is ttl.math.XXX access pattern. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._is_ttl_module_access( + node +) +``` + + + + + + +Check if node is ttl.XXX access pattern. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._loc_for_node( + node +) +``` + + + + + + +Return file location for node if debug_locations enabled, else name location. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._raise_error( + node, + message: str +) +``` + + + + + + +Raise a TTLangCompileError with source location from AST node. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._resolve_ttl_function( + node, + func_args, + kwargs +) +``` + + + + + + +Resolve and call a ttl.XXX or ttl.math.XXX function. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._to_index_value( + node +) +``` + + + + + + +Convert AST node to MLIR index Value. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler._try_emit_auto_signposts( + node, + visit_fn +) +``` + + + + + + +Emit line-based signposts if auto-profiling is enabled. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_Assign( + node +) +``` + + + + + + +Handle tuple unpacking for TTL functions like core(dims=2). + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_AsyncFunctionDef( + node +) +``` + + + + + + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_Attribute( + node, + func_args = [], + kwargs = {} +) +``` + + + + + + +Override to set location context and catch errors for method calls. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_BinOp( + node +) +``` + + + + + + +Override to inject auto-profiling and provide better error messages. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_Call( + node +) +``` + + + + + + +Override to set location context, catch errors, and inject auto-profiling. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_Constant( + node +) +``` + + + + + + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_FunctionDef( + node +) +``` + + + + + + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_List( + node +) +``` + + + + + + +Parse a list of constants. Returns a Python list, not MLIR values. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_Name( + node +) +``` + + + + + + +Override to check function globals for simple constants. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_Subscript( + node +) +``` + + + + + + +Handle tensor[row, col] or tensor[r0:r1, c0:c1] indexing. + + + + + + + +```python +ttl._src.ttl_ast.TTLGenericCompiler.visit_With( + node +) +``` + + + + + + +Handle 'with' for CircularBuffer acquire/release. + +Acquire ops (wait/reserve) are generated left-to-right. +Release ops (pop/push) are generated in reverse order at scope end. + + + + + + + + + +```python +ttl._src.ttl_ast._build_tensor_type( + ctx, + tensor, + grid, + tiled, + memory_space +) +``` + + + + + + +Build MLIR tensor type for a ttnn tensor with TTNNLayoutAttr. + + + + + + + + +```python +ttl._src.ttl_ast._get_annotation_name( + annotation +) +``` + + + + + + +Extract the type name from an annotation node. + +Handles both simple names (CircularBuffer) and qualified names (ttl.CircularBuffer). +Returns the simple type name (e.g., 'CircularBuffer') in both cases. + + + + + + + + +```python +ttl._src.ttl_ast._make_file_loc( + ctx, + source_file: str, + node, + line_offset: int = 0 +) -> Location +``` + + + + + + +Create an MLIR file location from an AST node. + + + + + + + + +```python +ttl._src.ttl_ast._raise_tensor_error( + tensor, + message: str +) +``` + + + + + + +Raise TTLangCompileError with tensor source location if available. + + + + + + + + +```python +ttl._src.ttl_ast.syntax( + syntax_name +) +``` + + + + + + + + diff --git a/fern/library-docs/ttl-docs/ttl/ttl/circular_buffer.mdx b/fern/library-docs/ttl-docs/ttl/ttl/circular_buffer.mdx new file mode 100644 index 0000000..d86d376 --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl/circular_buffer.mdx @@ -0,0 +1,239 @@ +--- +layout: overview +slug: ttl/ttl/circular_buffer +title: ttl.circular_buffer +--- + +Circular buffer operations for inter-thread communication. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`CircularBuffer`](#ttl-circular_buffer-CircularBuffer) | Circular buffer for inter-thread communication. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_cb_tensor_type`](#ttl-circular_buffer-_get_cb_tensor_type) | Extract the tensor type from a TTL CB type. | +| [`_next_cb_index`](#ttl-circular_buffer-_next_cb_index) | Get next CB index and increment counter. | +| [`_reset_cb_counter`](#ttl-circular_buffer-_reset_cb_counter) | Reset the CB index counter. Called at kernel start. | +| [`get_cb_count`](#ttl-circular_buffer-get_cb_count) | Return number of CBs allocated so far. | +| [`make_dataflow_buffer_like`](#ttl-circular_buffer-make_dataflow_buffer_like) | Create a circular buffer with properties derived from a tensor. | + +### Data + +[`_cb_index_counter`](#ttl-circular_buffer-_cb_index_counter) + +### API + + + + + +```python +class ttl.circular_buffer.CircularBuffer( + tensor: typing.Any, + shape: typing.Tuple[int, int], + buffer_factor: int +) +``` + + + + + + +Circular buffer for inter-thread communication. + +Circular buffers provide producer-consumer synchronization between +compute and data movement threads. + +Can be instantiated via make_dataflow_buffer_like() in kernel body, +then captured by thread closures. Methods generate TTL ops during compilation. + + + + + + + + +```python +ttl.circular_buffer.CircularBuffer.reserve( + ast_self: ttl.circular_buffer.CircularBuffer +) -> ttl.ttl_api.TensorBlock +``` + + + + + + +Reserve space in the circular buffer (producer acquire). + +Use in producer threads to acquire space for writing. Must be followed +by push() to signal data is ready. + +**Returns:** `TensorBlock` + +The reserved space with CB association. + + + + + + + +```python +ttl.circular_buffer.CircularBuffer.wait( + ast_self: ttl.circular_buffer.CircularBuffer +) -> ttl.ttl_api.TensorBlock +``` + + + + + + +Wait for data from the circular buffer (consumer acquire). + +Use in consumer threads to acquire data. Must be followed by pop() +to signal consumption is complete. + +**Returns:** `TensorBlock` + +The acquired data with CB association. + + + + + + + + + +```python +ttl.circular_buffer._get_cb_tensor_type( + cb_val +) +``` + + + + + + +Extract the tensor type from a TTL CB type. + + + + + + + + +```python +ttl.circular_buffer._next_cb_index() +``` + + + + + + +Get next CB index and increment counter. + + + + + + + + +```python +ttl.circular_buffer._reset_cb_counter() +``` + + + + + + +Reset the CB index counter. Called at kernel start. + + + + + + + + +```python +ttl.circular_buffer.get_cb_count() +``` + + + + + + +Return number of CBs allocated so far. + + + + + + + + +```python +ttl.circular_buffer.make_dataflow_buffer_like( + tensor: typing.Any, + shape: typing.Tuple[int, int], + buffer_factor: int = 2 +) -> ttl.circular_buffer.CircularBuffer +``` + + + + + + +Create a circular buffer with properties derived from a tensor. + +**Parameters:** + + +Tensor that determines the CB's data type + + + +(rows, cols) in tiles for wait/reserve operations + + + +Capacity multiplier (default 2 for double-buffering) + + +**Returns:** `CircularBuffer` + +CircularBuffer for use in thread function closures + + + + + + + + +```python +ttl.circular_buffer._cb_index_counter = 0 +``` + + + + diff --git a/fern/library-docs/ttl-docs/ttl/ttl/constants.mdx b/fern/library-docs/ttl-docs/ttl/ttl/constants.mdx new file mode 100644 index 0000000..5aa63c2 --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl/constants.mdx @@ -0,0 +1,41 @@ +--- +layout: overview +slug: ttl/ttl/constants +title: ttl.constants +--- + +Constants used throughout the DSL. + +## Module Contents + +### Data + +[`DEFAULT_TILE_SIZE`](#ttl-constants-DEFAULT_TILE_SIZE) + +[`SUPPORTED_MEMORY_SPACES`](#ttl-constants-SUPPORTED_MEMORY_SPACES) + +### API + + + + + +```python +ttl.constants.DEFAULT_TILE_SIZE = 32 +``` + + + + + + + + + +```python +ttl.constants.SUPPORTED_MEMORY_SPACES = frozenset(['L1', 'DRAM']) +``` + + + + diff --git a/fern/library-docs/ttl-docs/ttl/ttl/diagnostics.mdx b/fern/library-docs/ttl-docs/ttl/ttl/diagnostics.mdx new file mode 100644 index 0000000..709211b --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl/diagnostics.mdx @@ -0,0 +1,466 @@ +--- +layout: overview +slug: ttl/ttl/diagnostics +title: ttl.diagnostics +--- + +Diagnostic utilities for formatting compiler errors with source context. + +This module provides Rust/Swift-style error formatting that displays +source code snippets with ASCII arrows pointing to the error location. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`SourceDiagnostic`](#ttl-diagnostics-SourceDiagnostic) | Format errors with source context and ASCII arrows. | +| [`TTLangCompileError`](#ttl-diagnostics-TTLangCompileError) | Exception for tt-lang compilation errors with source context. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_extract_core_message`](#ttl-diagnostics-_extract_core_message) | Extract the core error message from MLIR diagnostic output. | +| [`_extract_note`](#ttl-diagnostics-_extract_note) | Extract any note from the MLIR error message. | +| [`_read_file_lines`](#ttl-diagnostics-_read_file_lines) | Read source lines from a file if it exists. | +| [`_verbose_errors_enabled`](#ttl-diagnostics-_verbose_errors_enabled) | Check if verbose MLIR error output is enabled. | +| [`extract_location_from_mlir_error`](#ttl-diagnostics-extract_location_from_mlir_error) | Extract source location from an MLIR error message. | +| [`find_variable_assignment`](#ttl-diagnostics-find_variable_assignment) | Find the line where a variable was assigned, searching backwards. | +| [`format_mlir_error`](#ttl-diagnostics-format_mlir_error) | Format an MLIR error with source context if location is available. | +| [`format_python_error`](#ttl-diagnostics-format_python_error) | Format a Python error with source context. | +| [`parse_mlir_location`](#ttl-diagnostics-parse_mlir_location) | Parse an MLIR location string to extract file, line, and column. | + +### API + + + + + +```python +class ttl.diagnostics.SourceDiagnostic( + source_lines: typing.List[str], + filename: str +) +``` + + + + + + +Format errors with source context and ASCII arrows. + +Produces error messages in the style of modern compilers (Rust, Swift): + + error: type mismatch in add operation + --> kernel.py:43:16 + | + 43 | result = l + r + | ^^^ expected bf16, got f32 + | + + + + + + +```python +ttl.diagnostics.SourceDiagnostic.format_error( + line: int, + col: int, + message: str, + label: str = 'error', + span_length: int = 1, + note: typing.Optional[str] = None +) -> str +``` + + + + + + +Format an error with source context. + +**Parameters:** + + +1-based line number + + + +1-based column number + + + +Main error message + + + +Error label (e.g., "error", "warning") + + + +Length of the underline (^^^) + + + +Optional additional note + + +**Returns:** `str` + +Formatted error string with source context + + + + + + + +```python +ttl.diagnostics.SourceDiagnostic.format_error_chain( + errors: typing.List[typing.Tuple[int, int, str, typing.Optional[str]]] +) -> str +``` + + + + + + +Format multiple related errors. + +**Parameters:** + + +List of (line, col, message, note) tuples + + +**Returns:** `str` + +Formatted error chain + + + + + + + + + +```python +class ttl.diagnostics.TTLangCompileError( + message: str, + source_file: typing.Optional[str] = None, + line: typing.Optional[int] = None, + col: typing.Optional[int] = None, + source_lines: typing.Optional[typing.List[str]] = None +) +``` + + + + + + +Exception + +**Bases:** `Exception` + +Exception for tt-lang compilation errors with source context. + +This exception carries enough information to produce pretty error messages +pointing to the exact source location where the error occurred. + + + + + + +```python +ttl.diagnostics.TTLangCompileError.format() -> str +``` + + + + + + +Format error with source context if available. + + + + + + + + + +```python +ttl.diagnostics._extract_core_message( + error_msg: str +) -> str +``` + + + + + + +Extract the core error message from MLIR diagnostic output. + +This extracts: "expects transfer handle to be synchronized with ttl.wait" + + + + + + + + +```python +ttl.diagnostics._extract_note( + error_msg: str +) -> typing.Optional[str] +``` + + + + + + +Extract any note from the MLIR error message. + + + + + + + + +```python +ttl.diagnostics._read_file_lines( + filepath: str +) -> typing.Optional[typing.List[str]] +``` + + + + + + +Read source lines from a file if it exists. + + + + + + + + +```python +ttl.diagnostics._verbose_errors_enabled() -> bool +``` + + + + + + +Check if verbose MLIR error output is enabled. + + + + + + + + +```python +ttl.diagnostics.extract_location_from_mlir_error( + error_msg: str +) -> typing.Optional[typing.Tuple[str, int, int]] +``` + + + + + + +Extract source location from an MLIR error message. + +**Parameters:** + + +Full MLIR error message + + +**Returns:** `Optional[Tuple[str, int, int]]` + +Tuple of (filename, line, col) or None if no location found + + + + + + + + +```python +ttl.diagnostics.find_variable_assignment( + source_lines: typing.List[str], + var_name: str, + before_line: int +) -> typing.Optional[int] +``` + + + + + + +Find the line where a variable was assigned, searching backwards. + +**Parameters:** + + +List of source lines (0-indexed) + + + +Variable name to search for + + + +Search backwards from this 1-based line number + + +**Returns:** `Optional[int]` + +1-based line number where assignment was found, or None + + + + + + + + +```python +ttl.diagnostics.format_mlir_error( + error_msg: str, + source_lines: typing.Optional[typing.List[str]] = None, + source_file: typing.Optional[str] = None +) -> str +``` + + + + + + +Format an MLIR error with source context if location is available. + +**Parameters:** + + +The MLIR error message + + + +Original Python source lines (optional, will read from file if needed) + + + +Source filename (optional, extracted from error if not provided) + + +**Returns:** `str` + +Formatted error message, with source context if available + + + + + + + + +```python +ttl.diagnostics.format_python_error( + error: Exception, + source_file: str, + line: int, + source_lines: typing.Optional[typing.List[str]] = None +) -> str +``` + + + + + + +Format a Python error with source context. + +**Parameters:** + + +The Python exception + + + +Source file path + + + +Line number in source file + + + +Source lines (will read from file if not provided) + + +**Returns:** `str` + +Formatted error message with source context + + + + + + + + +```python +ttl.diagnostics.parse_mlir_location( + loc_str: str +) -> typing.Optional[typing.Tuple[str, int, int]] +``` + + + + + + +Parse an MLIR location string to extract file, line, and column. + +MLIR locations can appear in several formats: +- loc("filename":line:col) +- loc("filename":line:col to :line:col) +- loc(#loc1) with #loc1 = loc("filename":line:col) + +**Parameters:** + + +MLIR location string + + +**Returns:** `Optional[Tuple[str, int, int]]` + +Tuple of (filename, line, col) or None if not parseable + + + diff --git a/fern/library-docs/ttl-docs/ttl/ttl/dialects.mdx b/fern/library-docs/ttl-docs/ttl/ttl/dialects.mdx new file mode 100644 index 0000000..2eb4b50 --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl/dialects.mdx @@ -0,0 +1,12 @@ +--- +layout: overview +slug: ttl/ttl/dialects +title: ttl.dialects +--- + +TTLang dialect modules. + +## Submodules + +- **[`ttl.dialects._ods_common`](/ttl/ttl/dialects/_ods_common)** +- **[`ttl.dialects.ttl`](/ttl/ttl/dialects/ttl)** diff --git a/fern/library-docs/ttl-docs/ttl/ttl/dialects/_ods_common.mdx b/fern/library-docs/ttl-docs/ttl/ttl/dialects/_ods_common.mdx new file mode 100644 index 0000000..7b2c71e --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl/dialects/_ods_common.mdx @@ -0,0 +1,39 @@ +--- +layout: overview +slug: ttl/ttl/dialects/_ods_common +title: ttl.dialects._ods_common +--- + +## Module Contents + +### Data + +[`__all__`](#ttl-dialects-_ods_common-__all__) + +[`_cext`](#ttl-dialects-_ods_common-_cext) + +### API + + + + + +```python +ttl.dialects._ods_common.__all__ = ['_cext'] +``` + + + + + + + + + +```python +ttl.dialects._ods_common._cext = _upstream._cext +``` + + + + diff --git a/fern/library-docs/ttl-docs/ttl/ttl/dialects/ttl.mdx b/fern/library-docs/ttl-docs/ttl/ttl/dialects/ttl.mdx new file mode 100644 index 0000000..d8c69a1 --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl/dialects/ttl.mdx @@ -0,0 +1,81 @@ +--- +layout: overview +slug: ttl/ttl/dialects/ttl +title: ttl.dialects.ttl +--- + +TTL (TT-Lang) dialect Python bindings. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`ensure_dialects_registered`](#ttl-dialects-ttl-ensure_dialects_registered) | Ensure TTL dialect is registered with the given MLIR context. | + +### Data + +[`CircularBufferType`](#ttl-dialects-ttl-CircularBufferType) + +[`SliceAttr`](#ttl-dialects-ttl-SliceAttr) + +[`__all__`](#ttl-dialects-ttl-__all__) + +### API + + + + + +```python +ttl.dialects.ttl.ensure_dialects_registered( + ctx +) +``` + + + + + + +Ensure TTL dialect is registered with the given MLIR context. + + + + + + + + +```python +ttl.dialects.ttl.CircularBufferType = ir.CircularBufferType +``` + + + + + + + + + +```python +ttl.dialects.ttl.SliceAttr = ir.SliceAttr +``` + + + + + + + + + +```python +ttl.dialects.ttl.__all__ = [*[name for name in (globals().keys()) if not name.startswith('_')]] +``` + + + + diff --git a/fern/library-docs/ttl-docs/ttl/ttl/dtype_utils.mdx b/fern/library-docs/ttl-docs/ttl/ttl/dtype_utils.mdx new file mode 100644 index 0000000..fce940c --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl/dtype_utils.mdx @@ -0,0 +1,212 @@ +--- +layout: overview +slug: ttl/ttl/dtype_utils +title: ttl.dtype_utils +--- + +Data type conversion utilities between PyTorch, TTNN, and MLIR types. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`is_ttnn_tensor`](#ttl-dtype_utils-is_ttnn_tensor) | Check if tensor is a ttnn.Tensor. | +| [`tensor_dtype_to_ttcore_datatype`](#ttl-dtype_utils-tensor_dtype_to_ttcore_datatype) | Convert tensor dtype to ttcore.DataType, supporting both torch and ttnn dtypes. | +| [`tile_bytes_from_dtype`](#ttl-dtype_utils-tile_bytes_from_dtype) | Calculate tile size in bytes from ttnn dtype. | +| [`torch_dtype_to_ttcore_datatype`](#ttl-dtype_utils-torch_dtype_to_ttcore_datatype) | Convert PyTorch dtype to ttcore.DataType enum. | +| [`torch_dtype_to_ttnn_datatype`](#ttl-dtype_utils-torch_dtype_to_ttnn_datatype) | Convert PyTorch dtype to ttnn.DataType enum. | +| [`ttnn_dtype_to_ttcore_datatype`](#ttl-dtype_utils-ttnn_dtype_to_ttcore_datatype) | Convert ttnn.DataType to ttcore.DataType enum. | + +### API + + + + + +```python +ttl.dtype_utils.is_ttnn_tensor( + tensor +) -> bool +``` + + + + + + +Check if tensor is a ttnn.Tensor. + + + + + + + + +```python +ttl.dtype_utils.tensor_dtype_to_ttcore_datatype( + dtype +) +``` + + + + + + +Convert tensor dtype to ttcore.DataType, supporting both torch and ttnn dtypes. + +**Parameters:** + + +Either torch dtype or ttnn.DataType + + +**Returns:** + +ttcore.DataType enum value + + + + + + + + +```python +ttl.dtype_utils.tile_bytes_from_dtype( + dtype +) -> int +``` + + + + + + +Calculate tile size in bytes from ttnn dtype. + +For tiled tensors, each tile is 32x32 elements. The byte size depends on +the data type's element size plus any format-specific overhead. + +**Parameters:** + + +ttnn.DataType enum value + + +**Returns:** `int` + +Tile size in bytes + +**Raises:** + +- `ValueError`: If dtype is not supported + + + + + + + + +```python +ttl.dtype_utils.torch_dtype_to_ttcore_datatype( + torch_dtype +) +``` + + + + + + +Convert PyTorch dtype to ttcore.DataType enum. + +**Parameters:** + + +PyTorch dtype (torch.float32, torch.int32, etc.) + + +**Returns:** + +ttcore.DataType enum value + +**Raises:** + +- `ValueError`: If dtype is not supported + + + + + + + + +```python +ttl.dtype_utils.torch_dtype_to_ttnn_datatype( + torch_dtype +) +``` + + + + + + +Convert PyTorch dtype to ttnn.DataType enum. + +**Parameters:** + + +PyTorch dtype (torch.float32, torch.bfloat16, etc.) + + +**Returns:** + +ttnn.DataType enum value + +**Raises:** + +- `ImportError`: If ttnn is not available +- `ValueError`: If dtype is not supported + + + + + + + + +```python +ttl.dtype_utils.ttnn_dtype_to_ttcore_datatype( + ttnn_dtype +) +``` + + + + + + +Convert ttnn.DataType to ttcore.DataType enum. + +**Parameters:** + + +ttnn.DataType enum value + + +**Returns:** + +ttcore.DataType enum value + +**Raises:** + +- `ValueError`: If dtype is not supported + + + diff --git a/fern/library-docs/ttl-docs/ttl/ttl/kernel_runner.mdx b/fern/library-docs/ttl-docs/ttl/ttl/kernel_runner.mdx new file mode 100644 index 0000000..a87c92a --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl/kernel_runner.mdx @@ -0,0 +1,274 @@ +--- +layout: overview +slug: ttl/ttl/kernel_runner +title: ttl.kernel_runner +--- + +Shared kernel execution logic for tt-lang. + +Provides functions for building kernel descriptors, CB descriptors, and +executing kernels on device via ttnn.generic_op. Used by both the Python +DSL (CompiledTTNNKernel) and ME2E tests. + +This module provides a single reusable implementation of kernel argument +building and execution. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`KernelSpec`](#ttl-kernel_runner-KernelSpec) | Specification for a single kernel to execute. | + +### Functions + +| Name | Description | +|------|-------------| +| [`build_cb_descriptors`](#ttl-kernel_runner-build_cb_descriptors) | Build circular buffer descriptors for ttnn.generic_op. | +| [`build_kernel_descriptors`](#ttl-kernel_runner-build_kernel_descriptors) | Build kernel descriptors for ttnn.generic_op. | +| [`build_tensor_accessor_args`](#ttl-kernel_runner-build_tensor_accessor_args) | Build compile-time args for tensor accessors. | +| [`run_kernel_on_device`](#ttl-kernel_runner-run_kernel_on_device) | Execute kernels on device using ttnn.generic_op. | + +### Data + +[`__all__`](#ttl-kernel_runner-__all__) + +### API + + + + + +```python +class ttl.kernel_runner.KernelSpec( + path: str, + thread_type: str, + tensor_indices: typing.List[int], + config: typing.Any +) +``` + + + + + + +Dataclass + +Specification for a single kernel to execute. + + + + + + + + + + + + + + + + +```python +ttl.kernel_runner.build_cb_descriptors( + tensors: typing.List[typing.Any], + cb_configs: typing.List[typing.Any], + core_ranges: typing.Any +) -> typing.List[typing.Any] +``` + + + + + + +Build circular buffer descriptors for ttnn.generic_op. + +**Parameters:** + + +List of ttnn.Tensor objects. Each tensor's position (0, 1, 2, ...) +corresponds to its CB index. For intermediate CBs (not backed by +input/output tensors), pass None in the corresponding position. + + + +List of CircularBuffer objects for each CB, indexed by CB index. +Each CB has shape, buffer_factor, tensor (for dtype), and _cb_index attributes. + + + +ttnn.CoreRangeSet for CB allocation. + + +**Returns:** `List[Any]` + +List of ttnn.CBDescriptor objects. + + + + + + + + +```python +ttl.kernel_runner.build_kernel_descriptors( + kernel_specs: typing.List[ttl.kernel_runner.KernelSpec], + tensors: typing.List[typing.Any], + tensor_accessor_args: typing.List[int], + core_ranges: typing.Any, + grid_cols: int, + grid_rows: int, + num_cbs: int +) -> typing.List[typing.Any] +``` + + + + + + +Build kernel descriptors for ttnn.generic_op. + +**Parameters:** + + +List of kernel specifications. + + + +List of ttnn.Tensor objects. Position in this list determines +the global tensor index. Individual kernels access subsets via +tensor_indices in each KernelSpec. + + + +Flattened compile-time args from all tensors. + + + +ttnn.CoreRangeSet for kernel execution. + + + +Number of grid columns (x dimension). + + + +Number of grid rows (y dimension). + + + +Total number of circular buffers (including intermediate CBs). + + +**Returns:** `List[Any]` + +List of ttnn.KernelDescriptor objects. + + + + + + + + +```python +ttl.kernel_runner.build_tensor_accessor_args( + tensors: typing.List[typing.Any] +) -> typing.List[int] +``` + + + + + + +Build compile-time args for tensor accessors. + +**Parameters:** + + +List of ttnn.Tensor objects on device. + + +**Returns:** `List[int]` + +List of compile-time args (flattened TensorAccessorArgs for all tensors). + + + + + + + + +```python +ttl.kernel_runner.run_kernel_on_device( + kernel_specs: typing.List[ttl.kernel_runner.KernelSpec], + tensors: typing.List[typing.Any], + cb_configs: typing.List[typing.Any], + core_ranges: typing.Any, + program_hash: int = None +) -> typing.Any +``` + + + + + + +Execute kernels on device using ttnn.generic_op. + +This is the main entry point for kernel execution. It builds all +descriptors and runs the program. + +**Parameters:** + + +List of kernel specifications (path, thread_type, tensor_indices, config). + + + +List of ttnn.Tensor objects. Position in this list determines the +global tensor index. Individual kernels access subsets via tensor_indices +in each KernelSpec. + + + +List of CircularBuffer objects for each CB, indexed by CB index. +Includes both tensor-backed CBs and intermediate CBs. Each CB has shape, +buffer_factor, tensor (for dtype), and _cb_index attributes. + + + +ttnn.CoreRangeSet for kernel execution. + + + +Hash for tt-metal program cache (not yet used). + + +**Returns:** `Any` + +Result from ttnn.generic_op (typically None or output tensor). + + + + + + + + +```python +ttl.kernel_runner.__all__ = ['KernelSpec', 'build_tensor_accessor_args', 'build_kernel_descriptors', 'build_... +``` + + + + diff --git a/fern/library-docs/ttl-docs/ttl/ttl/layouts.mdx b/fern/library-docs/ttl-docs/ttl/ttl/layouts.mdx new file mode 100644 index 0000000..41784be --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl/layouts.mdx @@ -0,0 +1,126 @@ +--- +layout: overview +slug: ttl/ttl/layouts +title: ttl.layouts +--- + +Layout creation utilities for tensor distribution across cores. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`TTNNLayoutConfig`](#ttl-layouts-TTNNLayoutConfig) | Configuration for TTNN layout creation. Supports L1/DRAM interleaved tiled layouts. | + +### Functions + +| Name | Description | +|------|-------------| +| [`create_ttnn_layout`](#ttl-layouts-create_ttnn_layout) | Create a TTNNLayoutAttr for L1 interleaved tiled tensors. | + +### Data + +[`_TTNN_BUFFER_TYPE_L1`](#ttl-layouts-_TTNN_BUFFER_TYPE_L1) + +[`_TTNN_TENSOR_MEMORY_LAYOUT_INTERLEAVED`](#ttl-layouts-_TTNN_TENSOR_MEMORY_LAYOUT_INTERLEAVED) + +### API + + + + + +```python +class ttl.layouts.TTNNLayoutConfig( + logical_shape: typing.List[int], + grid: typing.List[int], + dtype: str +) +``` + + + + + + +Dataclass + +Configuration for TTNN layout creation. Supports L1/DRAM interleaved tiled layouts. + + + + + + + + + + + + + + + + +```python +ttl.layouts.create_ttnn_layout( + ctx, + config: ttl.layouts.TTNNLayoutConfig +) +``` + + + + + + +Create a TTNNLayoutAttr for L1 interleaved tiled tensors. + +Supports: L1/DRAM memory, Interleaved layout, tiled (32x32 tiles). + +**Parameters:** + + +MLIR context + + + +Configuration with logical_shape, grid, and dtype + + +**Returns:** + +TTNNLayoutAttr + +**Raises:** + +- `ValueError`: If configuration is unsupported + + + + + + + + +```python +ttl.layouts._TTNN_BUFFER_TYPE_L1 = 1 +``` + + + + + + + + + +```python +ttl.layouts._TTNN_TENSOR_MEMORY_LAYOUT_INTERLEAVED = 0 +``` + + + + diff --git a/fern/library-docs/ttl-docs/ttl/ttl/operators.mdx b/fern/library-docs/ttl-docs/ttl/ttl/operators.mdx new file mode 100644 index 0000000..387aa56 --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl/operators.mdx @@ -0,0 +1,714 @@ +--- +layout: overview +slug: ttl/ttl/operators +title: ttl.operators +--- + +DSL operators for tensor operations and data movement. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`CopyTransferHandler`](#ttl-operators-CopyTransferHandler) | Transfer handle for asynchronous copy operations. | +| [`TensorBlock`](#ttl-operators-TensorBlock) | Represents a block of tensor data in the TTL dialect. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_get_cb_from_block`](#ttl-operators-_get_cb_from_block) | Extract the CB from a block (result of ttl.attach_cb). | +| [`_get_cb_shape`](#ttl-operators-_get_cb_shape) | Extract the block shape from a CB value. | +| [`_get_constant_int`](#ttl-operators-_get_constant_int) | Extract Python int from MLIR arith.ConstantOp or return as-is if already int. | +| [`_get_current_grid`](#ttl-operators-_get_current_grid) | Get the current grid dimensions. | +| [`_is_block`](#ttl-operators-_is_block) | Check if a value is a block (result of cb.reserve() or cb.wait()). | +| [`_make_tensor_slice`](#ttl-operators-_make_tensor_slice) | Create a ttl.tensor_slice from a tensor, tile indices, and shape. | +| [`_process_tensor_subscript`](#ttl-operators-_process_tensor_subscript) | Process tensor subscript and create tensor slice. | +| [`_set_current_grid`](#ttl-operators-_set_current_grid) | Set the current grid dimensions. Called before compiling threads. | +| [`broadcast`](#ttl-operators-broadcast) | Broadcast over specified dimensions. | +| [`copy`](#ttl-operators-copy) | Initiate an asynchronous data transfer using ttl.copy. | +| [`core`](#ttl-operators-core) | Get the coordinates of the current core. | +| [`grid_size`](#ttl-operators-grid_size) | Get the size of the grid. | +| [`signpost`](#ttl-operators-signpost) | Emit a profiling marker visible in Tracy. | + +### Data + +[`CoreCoordinate`](#ttl-operators-CoreCoordinate) + +[`IndexedTensor`](#ttl-operators-IndexedTensor) + +[`__all__`](#ttl-operators-__all__) + +[`_current_grid`](#ttl-operators-_current_grid) + +### API + + + + + +```python +class ttl.operators.CopyTransferHandler() +``` + + + + + + +Transfer handle for asynchronous copy operations. + +CopyTransferHandler objects are returned by copy() calls and must be +explicitly waited on to ensure transfer completion. + + + + + + +```python +ttl.operators.CopyTransferHandler.wait( + ast_self: ttl.operators.CopyTransferHandler +) +``` + + + + + + +Block until the copy operation completes. + + + + + + + + + +```python +class ttl.operators.TensorBlock( + shape, + dtype +) +``` + + + + + + +Represents a block of tensor data in the TTL dialect. + +TensorBlock supports arithmetic operations through operator +overloading. Operations generate TTL high-level ops that get lowered +to ttl.compute blocks. + + + + + + +```python +ttl.operators.TensorBlock.__add__( + ast_self: ttl.operators.TensorBlock, + rhs: ttl.operators.TensorBlock +) -> ttl.operators.TensorBlock +``` + + + + + + +Element-wise addition using ttl.add. + +**Parameters:** + + +Right operand tensor. Must have the same shape as self. + + +**Returns:** `TensorBlock` + +Result tensor with the same shape as inputs. + + + + + + + +```python +ttl.operators.TensorBlock.__matmul__( + ast_self: ttl.operators.TensorBlock, + rhs: ttl.operators.TensorBlock +) -> ttl.operators.TensorBlock +``` + + + + + + +Matrix multiplication is not yet supported in TTL mode. + + + + + + + +```python +ttl.operators.TensorBlock.__mul__( + ast_self: ttl.operators.TensorBlock, + rhs: ttl.operators.TensorBlock +) -> ttl.operators.TensorBlock +``` + + + + + + +Element-wise multiplication using ttl.mul. + + + + + + + +```python +ttl.operators.TensorBlock.__sub__( + ast_self: ttl.operators.TensorBlock, + rhs: ttl.operators.TensorBlock +) -> ttl.operators.TensorBlock +``` + + + + + + +Element-wise subtraction using ttl.sub. + + + + + + + +```python +ttl.operators.TensorBlock.__truediv__( + ast_self: ttl.operators.TensorBlock, + rhs: ttl.operators.TensorBlock +) -> ttl.operators.TensorBlock +``` + + + + + + +Element-wise division using ttl.div. + + + + + + + +```python +ttl.operators.TensorBlock.pop( + ast_self: ttl.operators.TensorBlock +) -> None +``` + + + + + + +Signal that data has been consumed (consumer release). + +Finalizes a wait() operation by signaling that the block has been +consumed and space is available for producers. This operation is non-blocking. + +Must be called on a block acquired via wait(). + + + + + + + +```python +ttl.operators.TensorBlock.push( + ast_self: ttl.operators.TensorBlock +) -> None +``` + + + + + + +Signal that data is ready in the circular buffer (producer release). + +Finalizes a reserve() operation by signaling that the block has been +written and is ready for consumers. This operation is non-blocking. + +Must be called on a block acquired via reserve(). + + + + + + + +```python +ttl.operators.TensorBlock.store( + ast_self: ttl.operators.TensorBlock, + rhs: ttl.operators.TensorBlock +) -> None +``` + + + + + + +Store result tensor to the output CB reserve view. + +Emits ttl.store (the actual store) and ttl.attach_cb (CB association +for compute formation). The store carries its destination as an SSA +operand; no downstream pass infers or synthesizes stores. + + + + + + + + + +```python +ttl.operators._get_cb_from_block( + block +) +``` + + + + + + +Extract the CB from a block (result of ttl.attach_cb). + +The attach_cb op has signature: (tensor, cb) -> tensor +So the CB is operand[1]. + + + + + + + + +```python +ttl.operators._get_cb_shape( + cb_val +) +``` + + + + + + +Extract the block shape from a CB value. + + + + + + + + +```python +ttl.operators._get_constant_int( + val +) +``` + + + + + + +Extract Python int from MLIR arith.ConstantOp or return as-is if already int. + + + + + + + + +```python +ttl.operators._get_current_grid() -> typing.Tuple[int, int] +``` + + + + + + +Get the current grid dimensions. + + + + + + + + +```python +ttl.operators._is_block( + value +) -> bool +``` + + + + + + +Check if a value is a block (result of cb.reserve() or cb.wait()). + +A block is a tensor with an attached CB, produced by ttl.attach_cb. + + + + + + + + +```python +ttl.operators._make_tensor_slice( + tensor, + indices, + slice_shape +) +``` + + + + + + +Create a ttl.tensor_slice from a tensor, tile indices, and shape. + +**Parameters:** + + +The source tensor to slice from + + + +(row, col) tile indices for the slice start position + + + +(rows, cols) shape for the slice in tiles + + + + + + + + + +```python +ttl.operators._process_tensor_subscript( + subscript_tuple, + cb_shape +) +``` + + + + + + +Process tensor subscript and create tensor slice. + +**Parameters:** + + +(tensor, indices) where indices are [(value, is_range), ...] + + + +[rows, cols] shape from the CB + + +**Returns:** + +Tensor slice with shape matching cb_shape + + + + + + + + +```python +ttl.operators._set_current_grid( + grid: typing.Tuple[int, int] +) -> None +``` + + + + + + +Set the current grid dimensions. Called before compiling threads. + + + + + + + + +```python +ttl.operators.broadcast( + input: ttl.operators.TensorBlock, + output: ttl.operators.TensorBlock, + dims: typing.List[int] +) -> ttl.operators.TensorBlock +``` + + + + + + +Broadcast over specified dimensions. + +**Parameters:** + + +Input tensor (CB-attached) + + + +Output tensor (CB-attached, used for output CB tracking) + + + +Dimensions to broadcast over + + +**Returns:** `TensorBlock` + +Result tensor with broadcast values + + + + + + + + +```python +ttl.operators.copy( + src, + dst +) -> ttl.operators.CopyTransferHandler +``` + + + + + + +Initiate an asynchronous data transfer using ttl.copy. + +For multi-tile CBs (shape > 1x1), use range syntax: tensor[0:2, 0:2] +For single-tile CBs (shape 1x1), use index syntax: tensor[0, 0] + +**Parameters:** + + +Source tensor/slice (for reads) or block (for writes) + + + +Destination block (for reads) or tensor/slice (for writes) + + +**Returns:** `CopyTransferHandler` + +CopyTransferHandler handle that must be waited on for completion + + + + + + + + +```python +ttl.operators.core( + dims +) +``` + + + + + + +Get the coordinates of the current core. + +Currently only dims=2 is supported (temporary restriction). + +**Parameters:** + + +Number of dimensions to return (must be 2) + + +**Returns:** + +For dims=2: Tuple (x, y) where x is column coordinate and y is row coordinate + +**Raises:** + +- `ValueError`: If dims is not 2 + + + + + + + + +```python +ttl.operators.grid_size( + dims +) +``` + + + + + + +Get the size of the grid. + +Currently only dims=2 is supported (temporary restriction). + +**Parameters:** + + +Number of dimensions to return (must be 2) + + +**Returns:** + +For dims=2: Tuple (x_size, y_size) where x_size is columns and y_size is rows + +**Raises:** + +- `ValueError`: If dims is not 2 + + + + + + + + +```python +ttl.operators.signpost( + name: str +) +``` + + + + + + +Emit a profiling marker visible in Tracy. + +The marker creates a DeviceZoneScopedN in the generated C++ code, +which will appear in Tracy profiler traces when TT_METAL_DEVICE_PROFILER=1. + +**Parameters:** + + +Name for the profiling region (must be a string literal) + + + + + + + + + +```python +ttl.operators.CoreCoordinate = Tuple[int, int] +``` + + + + + + + + + +```python +ttl.operators.IndexedTensor = Union['TensorBlock', Tuple['TensorBlock', Tuple[int, ...]]] +``` + + + + + + + + + +```python +ttl.operators.__all__ = ['TensorBlock', 'CopyTransferHandler', 'copy', 'core', 'grid_size', 'signpost', ... +``` + + + + + + + + + +```python +ttl.operators._current_grid: Tuple[int, int] = (-1, -1) +``` + + + + diff --git a/fern/library-docs/ttl-docs/ttl/ttl/ttl.mdx b/fern/library-docs/ttl-docs/ttl/ttl/ttl.mdx new file mode 100644 index 0000000..f55fc50 --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl/ttl.mdx @@ -0,0 +1,27 @@ +--- +layout: overview +slug: ttl/ttl/ttl +title: ttl.ttl +--- + +TTL DSL module providing the unified ttl.* API namespace. + +## Module Contents + +### Data + +[`__all__`](#ttl-ttl-__all__) + +### API + + + + + +```python +ttl.ttl.__all__ = ['kernel', 'compute', 'datamovement', 'Program', 'make_dataflow_buffer_like', 'c... +``` + + + + diff --git a/fern/library-docs/ttl-docs/ttl/ttl/ttl_api.mdx b/fern/library-docs/ttl-docs/ttl/ttl/ttl_api.mdx new file mode 100644 index 0000000..e890735 --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl/ttl_api.mdx @@ -0,0 +1,907 @@ +--- +layout: overview +slug: ttl/ttl/ttl_api +title: ttl.ttl_api +--- + +Main API for the TTL dialect Python DSL. + +## Module Contents + +### Classes + +| Name | Description | +|------|-------------| +| [`CompiledTTNNKernel`](#ttl-ttl_api-CompiledTTNNKernel) | A compiled tt-lang kernel ready for execution via ttnn.generic_op. | +| [`Program`](#ttl-ttl_api-Program) | Immutable container for kernel threads and their arguments. | + +### Functions + +| Name | Description | +|------|-------------| +| [`_clear_thread_registry`](#ttl-ttl_api-_clear_thread_registry) | Clear the thread registry before kernel execution. | +| [`_collect_captures`](#ttl-ttl_api-_collect_captures) | Collect and convert captured variables from function closure. | +| [`_collect_cb_configs`](#ttl-ttl_api-_collect_cb_configs) | Extract CircularBuffer objects from thread closures, indexed by cb_index. | +| [`_compile`](#ttl-ttl_api-_compile) | Internal decorator for compiling kernel threads. | +| [`_compile_kernel`](#ttl-ttl_api-_compile_kernel) | Compile kernel function to MLIR and return CompiledTTNNKernel. | +| [`_compile_ttnn_kernel`](#ttl-ttl_api-_compile_ttnn_kernel) | Compile kernel to CompiledTTNNKernel for execution via ttnn.generic_op. | +| [`_detect_memory_space_from_tensor`](#ttl-ttl_api-_detect_memory_space_from_tensor) | Detect memory space (L1/DRAM) from a ttnn tensor's buffer type. | +| [`_get_registered_threads`](#ttl-ttl_api-_get_registered_threads) | Get all registered threads and clear the registry. | +| [`_get_source_line_offset`](#ttl-ttl_api-_get_source_line_offset) | Get the line offset to convert parsed AST line numbers to actual file lines. | +| [`_get_tensor_cache_info`](#ttl-ttl_api-_get_tensor_cache_info) | Extract cache-relevant info from a tensor: (shape, dtype, memory_space, layout). | +| [`_has_float32_args`](#ttl-ttl_api-_has_float32_args) | Check if any input tensor uses float32 dtype. | +| [`_is_interleaved_tensor`](#ttl-ttl_api-_is_interleaved_tensor) | Check if a ttnn tensor has interleaved memory layout. | +| [`_make_cache_key`](#ttl-ttl_api-_make_cache_key) | Create cache key from tensor properties and runtime compute config parameters. | +| [`_register_thread`](#ttl-ttl_api-_register_thread) | Register a thread function during decoration. | +| [`_resolve_grid`](#ttl-ttl_api-_resolve_grid) | Resolve grid, evaluating callable or 'auto' if needed. | +| [`_run_profiling_pipeline`](#ttl-ttl_api-_run_profiling_pipeline) | Read device profiler data and display profile report. | +| [`_should_execute`](#ttl-ttl_api-_should_execute) | Check if kernel execution should proceed (not compile-only mode). | +| [`_track_tensor_sources`](#ttl-ttl_api-_track_tensor_sources) | Track source locations for tensor arguments. | +| [`_write_kernel_to_tmp`](#ttl-ttl_api-_write_kernel_to_tmp) | Write kernel source to /tmp and return the file path. | +| [`compute`](#ttl-ttl_api-compute) | Decorator for compute thread functions. | +| [`datamovement`](#ttl-ttl_api-datamovement) | Decorator for data movement thread functions. | +| [`pykernel_gen`](#ttl-ttl_api-pykernel_gen) | Decorator for generating TTL kernels from Python functions. | + +### Data + +[`__all__`](#ttl-ttl_api-__all__) + +[`_thread_registry`](#ttl-ttl_api-_thread_registry) + +[`kernel`](#ttl-ttl_api-kernel) + +### API + + + + + +```python +class ttl.ttl_api.CompiledTTNNKernel( + kernel_paths, + kernel_configs, + kernel_arg_specs, + num_tensors, + core_ranges, + kernel_tensor_indices, + cb_configs = None, + program_hash = None, + source_lines = None, + all_source_lines = None, + thread_to_kernel = None, + kernel_line_offsets = None +) +``` + + + + + + +A compiled tt-lang kernel ready for execution via ttnn.generic_op. + +Caches compilation artifacts (kernel paths, CB descriptors) so the kernel +can be executed multiple times with different tensors without recompiling. + + + + + + + + + + + + + + + + + +```python +ttl.ttl_api.CompiledTTNNKernel.__call__( + args = () +) +``` + + + + + + +Execute the kernel with the given tensors. + + + + + + + + + +```python +class ttl.ttl_api.Program( + threads = (), + args = (), + kwargs = None +) +``` + + + + + + +Immutable container for kernel threads and their arguments. + +A Program encapsulates compute and data movement threads along with +the arguments to be passed during execution. After construction, all +fields should be treated as read-only. + + + + + + + + + + + + + + + + + +```python +ttl.ttl_api.Program.__call__( + args = (), + kwargs = {} +) +``` + + + + + + + + + + + + + + +```python +ttl.ttl_api._clear_thread_registry() -> None +``` + + + + + + +Clear the thread registry before kernel execution. + + + + + + + + +```python +ttl.ttl_api._collect_captures( + f: typing.Callable +) -> typing.Dict[str, typing.Union[int, ttl.circular_buffer.CircularBuffer]] +``` + + + + + + +Collect and convert captured variables from function closure. + +**Parameters:** + + +Function with closure to inspect + + +**Returns:** `Dict[str, Union[int, CircularBuffer]]` + +Dictionary mapping variable names to converted values + +**Raises:** + +- `TypeError`: If closure contains unsupported variable types + + + + + + + + +```python +ttl.ttl_api._collect_cb_configs( + threads +) +``` + + + + + + +Extract CircularBuffer objects from thread closures, indexed by cb_index. + +Returns a list of CircularBuffer objects indexed by cb_index. Each CB has +shape, buffer_factor, tensor (for dtype), and _cb_index attributes. + + + + + + + + +```python +ttl.ttl_api._compile( + kernel_type: typing.Optional[str] = None, + verbose: bool = False +) -> typing.Callable +``` + + + + + + +Internal decorator for compiling kernel threads. + +**Parameters:** + + +Type of kernel ("compute" or "datamovement") + + + +Enable verbose compilation output + + +**Returns:** `Callable` + +Decorator function for kernel compilation + + + + + + + + +```python +ttl.ttl_api._compile_kernel( + f: typing.Callable, + args: tuple, + kwargs: dict, + grid: typing.Union[tuple, typing.List[int]], + indexing_maps: typing.List[typing.Callable], + iterator_types: typing.List[str], + num_outs: int, + memory_space: str, + tiled: bool, + program_hash: int, + fp32_dest_acc_en: typing.Optional[bool] = None, + dst_full_sync_en: typing.Optional[bool] = None +) -> typing.Optional[ttl.ttl_api.CompiledTTNNKernel] +``` + + + + + + +Compile kernel function to MLIR and return CompiledTTNNKernel. + +**Parameters:** + + +User kernel function + + + +Positional arguments for the kernel + + + +Keyword arguments for the kernel + + + +Grid dimensions + + + +List of lambda functions for indexing + + + +List of iterator type strings + + + +Number of output arguments + + + +"L1" or "DRAM" + + + +Whether to use tiled layout + + + +Hash for tt-metal program cache + + + +Optional override for fp32_dest_acc_en + + + +Optional override for dst_full_sync_en + + +**Returns:** `Optional[CompiledTTNNKernel]` + +CompiledTTNNKernel ready for execution + + + + + + + + +```python +ttl.ttl_api._compile_ttnn_kernel( + module, + args, + grid, + num_outs, + thread_tensor_indices, + cb_configs = None, + program_hash = None, + fp32_dest_acc_en: typing.Optional[bool] = None, + dst_full_sync_en: typing.Optional[bool] = None, + verbose = True, + source_lines = None, + all_source_lines = None, + kernel_line_offsets = None +) +``` + + + + + + +Compile kernel to CompiledTTNNKernel for execution via ttnn.generic_op. + +Builds kernel paths, configs, and CB descriptors from compiled MLIR module. + +**Parameters:** + + +MLIR module after D2M pipeline (with EmitC kernels) + + + +Input/output tensors (used for shape/dtype info) + + + +Grid dimensions tuple + + + +Number of output tensors + + + +Hash for tt-metal program cache + + + +Print compilation info + + + +Source code lines for auto-profiling reports + + +**Returns:** + +CompiledTTNNKernel ready for execution + + + + + + + + +```python +ttl.ttl_api._detect_memory_space_from_tensor( + tensor, + default: str +) -> str +``` + + + + + + +Detect memory space (L1/DRAM) from a ttnn tensor's buffer type. + + + + + + + + +```python +ttl.ttl_api._get_registered_threads() -> typing.List[typing.Callable] +``` + + + + + + +Get all registered threads and clear the registry. + + + + + + + + +```python +ttl.ttl_api._get_source_line_offset( + f +) -> int +``` + + + + + + +Get the line offset to convert parsed AST line numbers to actual file lines. + + + + + + + + +```python +ttl.ttl_api._get_tensor_cache_info( + tensor +) -> tuple +``` + + + + + + +Extract cache-relevant info from a tensor: (shape, dtype, memory_space, layout). + + + + + + + + +```python +ttl.ttl_api._has_float32_args( + args +) -> bool +``` + + + + + + +Check if any input tensor uses float32 dtype. + +Inspects the tensor arguments to detect float32. This is used to +automatically enable fp32_dest_acc_en configuration for compute kernels. + +**Parameters:** + + +List of tensor arguments (torch or ttnn) + + +**Returns:** `bool` + +True if any tensor uses float32 dtype, False otherwise + + + + + + + + +```python +ttl.ttl_api._is_interleaved_tensor( + tensor +) -> bool +``` + + + + + + +Check if a ttnn tensor has interleaved memory layout. + + + + + + + + +```python +ttl.ttl_api._make_cache_key( + args: tuple, + fp32_dest_acc_en: typing.Optional[bool], + dst_full_sync_en: typing.Optional[bool] +) -> tuple +``` + + + + + + +Create cache key from tensor properties and runtime compute config parameters. + + + + + + + + +```python +ttl.ttl_api._register_thread( + thread_fn: typing.Callable +) -> None +``` + + + + + + +Register a thread function during decoration. + + + + + + + + +```python +ttl.ttl_api._resolve_grid( + grid, + args, + kwargs +) +``` + + + + + + +Resolve grid, evaluating callable or 'auto' if needed. + + + + + + + + +```python +ttl.ttl_api._run_profiling_pipeline( + tensors: tuple, + all_source_lines: typing.Dict[str, typing.List[str]], + thread_to_kernel: typing.Dict[str, str], + kernel_line_offsets: typing.Optional[typing.Dict[str, int]] = None +) +``` + + + + + + +Read device profiler data and display profile report. + +Called after kernel execution when auto-profiling is enabled. + +**Parameters:** + + +Tuple of tensor arguments passed to the kernel + + + +Dict mapping kernel name to source lines + + + +Dict mapping RISC thread name to kernel name + + + + + + + + + +```python +ttl.ttl_api._should_execute() -> bool +``` + + + + + + +Check if kernel execution should proceed (not compile-only mode). + + + + + + + + +```python +ttl.ttl_api._track_tensor_sources( + f_params, + args, + source_file: str +) -> None +``` + + + + + + +Track source locations for tensor arguments. + +Searches backwards from the kernel call site to find where each +tensor variable was assigned, then registers that location. + + + + + + + + +```python +ttl.ttl_api._write_kernel_to_tmp( + name: str, + source: str +) -> str +``` + + + + + + +Write kernel source to /tmp and return the file path. + + + + + + + + +```python +ttl.ttl_api.compute( + verbose: bool = False +) -> typing.Callable +``` + + + + + + +Decorator for compute thread functions. + +Compute threads execute on Tensix cores and perform mathematical operations. + +**Parameters:** + + +Enable verbose compilation output + + +**Returns:** `Callable` + +Decorator for compute kernel compilation + + + + + + + + +```python +ttl.ttl_api.datamovement( + verbose: bool = False +) -> typing.Callable +``` + + + + + + +Decorator for data movement thread functions. + +Data movement threads handle DMA operations between memory hierarchies. + +**Parameters:** + + +Enable verbose compilation output + + +**Returns:** `Callable` + +Decorator for data movement kernel compilation + + + + + + + + +```python +ttl.ttl_api.pykernel_gen( + grid: typing.Optional[typing.Union[tuple, typing.Callable]] = None, + indexing_maps: typing.Optional[typing.List[typing.Callable]] = None, + iterator_types: typing.Optional[typing.List[str]] = None, + num_outs: int = 1, + memory_space: str = 'L1', + tiled: bool = True, + fp32_dest_acc_en: typing.Optional[bool] = None, + dst_full_sync_en: typing.Optional[bool] = None +) -> typing.Callable +``` + + + + + + +Decorator for generating TTL kernels from Python functions. + +This decorator compiles Python functions into TTL dialect operations, +handling thread compilation, stream creation, and pipeline execution. +Kernels are compiled to C++ for execution via ttnn.generic_op. + +**Parameters:** + + +Grid dimensions as tuple (e.g., (2, 2)) or callable + + + +List of lambda functions for indexing (optional) + + + +List of iterator types ("parallel", "reduction") + + + +Number of output arguments + + + +"L1" or "DRAM" + + + +Whether to use tiled layout + + + +Optional override for fp32_dest_acc_en + + + +Optional override for dst_full_sync_en + + +**Returns:** `Callable` + +Decorated function that compiles and executes the kernel + +**Raises:** + +- `AssertionError`: If required parameters are missing or invalid + + + + + + + + +```python +ttl.ttl_api.__all__ = ['pykernel_gen', 'kernel', 'Program', 'compute', 'datamovement', 'TensorBlock', ... +``` + + + + + + + + + +```python +ttl.ttl_api._thread_registry: List[Callable] = [] +``` + + + + + + + + + +```python +ttl.ttl_api.kernel = pykernel_gen +``` + + + + diff --git a/fern/library-docs/ttl-docs/ttl/ttl/ttl_math.mdx b/fern/library-docs/ttl-docs/ttl/ttl/ttl_math.mdx new file mode 100644 index 0000000..9472960 --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl/ttl_math.mdx @@ -0,0 +1,29 @@ +--- +layout: overview +slug: ttl/ttl/ttl_math +title: ttl.ttl_math +--- + +TTL math operations namespace (ttl.math). + +Re-exports elementwise operations from the generated module. + +## Module Contents + +### Data + +[`__all__`](#ttl-ttl_math-__all__) + +### API + + + + + +```python +ttl.ttl_math.__all__ = ['broadcast', *_generated_all] +``` + + + + diff --git a/fern/library-docs/ttl-docs/ttl/ttl/ttl_utils.mdx b/fern/library-docs/ttl-docs/ttl/ttl/ttl_utils.mdx new file mode 100644 index 0000000..813bd24 --- /dev/null +++ b/fern/library-docs/ttl-docs/ttl/ttl/ttl_utils.mdx @@ -0,0 +1,70 @@ +--- +layout: overview +slug: ttl/ttl/ttl_utils +title: ttl.ttl_utils +--- + +Utility functions for tt-lang. + +## Module Contents + +### Functions + +| Name | Description | +|------|-------------| +| [`get_thread_type_string`](#ttl-ttl_utils-get_thread_type_string) | Map kernel type to thread type string. | + +### Data + +[`_KERNEL_TYPE_TO_THREAD_TYPE`](#ttl-ttl_utils-_KERNEL_TYPE_TO_THREAD_TYPE) + +### API + + + + + +```python +ttl.ttl_utils.get_thread_type_string( + input: typing.Union[str, object] +) -> str +``` + + + + + + +Map kernel type to thread type string. + +Handles both string kernel types and MLIR ThreadTypeAttr. + +**Parameters:** + + +Either a string kernel type ("compute", "datamovement", "ethernet") + or a ttkernel.ThreadTypeAttr from MLIR IR + + +**Returns:** `str` + +Thread type string: "compute", "noc", "ethernet" + +**Raises:** + +- `ValueError`: If input is a string that's not a valid kernel type + + + + + + + + +```python +ttl.ttl_utils._KERNEL_TYPE_TO_THREAD_TYPE = {'compute': 'compute', 'datamovement': 'noc', 'ethernet': 'ethernet'} +``` + + + + diff --git a/fern/pages/cub/block_reduce_v3.mdx b/fern/pages/cub/block_reduce_v5.mdx similarity index 89% rename from fern/pages/cub/block_reduce_v3.mdx rename to fern/pages/cub/block_reduce_v5.mdx index 8bd6d0d..970bc46 100644 --- a/fern/pages/cub/block_reduce_v3.mdx +++ b/fern/pages/cub/block_reduce_v5.mdx @@ -86,7 +86,7 @@ The thread block length in threads along the X dimension ### BlockReduce inline - + Collective constructor using a private static allocation of shared memory as temporary storage. @@ -145,9 +145,14 @@ T cub::BlockReduce::Reduce( *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The return value is undefined in threads other than thread0. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The return value is undefined in threads other than thread0. + + + +Assumes threads are in row-major order. +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -206,9 +211,14 @@ T cub::BlockReduce::Reduce( *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The return value is undefined in threads other than thread0. -- Performance is sensitive to the degree of data movement across the block. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The return value is undefined in threads other than thread0. + + + +Performance is sensitive to the degree of data movement across the block. +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -272,9 +282,14 @@ T cub::BlockReduce::Reduce( *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The return value is undefined in threads other than thread0. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The return value is undefined in threads other than thread0. + + + +Assumes threads are in row-major order. +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -344,9 +359,14 @@ T cub::BlockReduce::Sum( *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The return value is undefined in threads other than thread0. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The return value is undefined in threads other than thread0. + + + +Assumes threads are in row-major order. +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** @@ -394,9 +414,14 @@ T cub::BlockReduce::Sum( *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The return value is undefined in threads other than thread0. -- Performance is sensitive to the degree of data movement across the block. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The return value is undefined in threads other than thread0. + + + +Performance is sensitive to the degree of data movement across the block. +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -450,9 +475,14 @@ T cub::BlockReduce::Sum( *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The return value is undefined in threads other than thread0. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The return value is undefined in threads other than thread0. + + + +Assumes threads are in row-major order. +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** @@ -547,4 +577,4 @@ struct cub::BlockReduce::TempStorage The operations exposed by [BlockReduce](/library/api/cub::BlockReduce) require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. -**Inherits from:** `Uninitialized< _TempStorage >` (public) \ No newline at end of file +**Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/pages/cub/block_reduce.mdx b/fern/pages/cub/block_reduce_v6.mdx similarity index 64% rename from fern/pages/cub/block_reduce.mdx rename to fern/pages/cub/block_reduce_v6.mdx index 013402c..7f82146 100644 --- a/fern/pages/cub/block_reduce.mdx +++ b/fern/pages/cub/block_reduce_v6.mdx @@ -3,39 +3,24 @@ title: cub::BlockReduce description: "Collective methods for computing parallel reductions across a CUDA thread block." --- -# BlockReduce +The BlockReduce class provides collective methods for computing a parallel reduction of items partitioned across a CUDA thread block. -The `BlockReduce` class provides collective methods for computing a parallel reduction of items partitioned across a CUDA thread block. - -A [reduction](http://en.wikipedia.org/wiki/Reduce_(higher-order_function)) (or *fold*) uses a binary combining operator to compute a single aggregate from a list of input elements. Threads are assumed to be in row-major order. - -`BlockReduce` can be optionally specialized by algorithm to accommodate different latency/throughput workload profiles: - -1. [`cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY`](/library/api/cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY): - An efficient "raking" reduction algorithm that only supports commutative reduction operators. -2. [`cub::BLOCK_REDUCE_RAKING`](/library/api/cub::BLOCK_REDUCE_RAKING): - An efficient "raking" reduction algorithm that supports commutative and non-commutative reduction operators. -3. [`cub::BLOCK_REDUCE_WARP_REDUCTIONS`](/library/api/cub::BLOCK_REDUCE_WARP_REDUCTIONS): - A quick "tiled warp-reductions" reduction algorithm that supports commutative and non-commutative reduction operators. -4. [`cub::BLOCK_REDUCE_WARP_REDUCTIONS_NONDETERMINISTIC`](/library/api/cub::BLOCK_REDUCE_WARP_REDUCTIONS_NONDETERMINISTIC): - A quick "tiled warp-reductions" reduction algorithm that supports commutative and non-commutative reduction operators. This variant uses atomic operations to reduce the warp-wide reduction results, making it non-deterministic, i.e. the order of reduction operations is not guaranteed to be the same across different invocations of the same kernel. - -### Performance considerations +## Performance considerations - Performance is sensitive to the degree of data movement across the block. - Very efficient (only one synchronization barrier). -- Incurs zero bank conflicts for most types. +- Incurs zero bank conflicts for most types - Computation is slightly more efficient (i.e., having lower instruction overhead) for: - Summation (vs. generic reduction) - `BLOCK_THREADS` is a multiple of the architecture's warp size - Every thread has a valid input (i.e., full vs. partial-tiles) -- See `cub::BlockReduceAlgorithm` for performance details regarding algorithmic alternatives. +- See cub::BlockReduceAlgorithm for performance details regarding algorithmic alternatives -### Example +## Example The code snippet below illustrates a sum reduction of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. -```cpp +```cpp showLineNumbers={false} #include // or equivalently __global__ void ExampleKernel(...) @@ -55,6 +40,7 @@ __global__ void ExampleKernel(...) } ``` + @@ -78,6 +64,7 @@ The thread block length in threads along the X dimension + --- @@ -86,13 +73,15 @@ The thread block length in threads along the X dimension ### BlockReduce inline - + Collective constructor using a private static allocation of shared memory as temporary storage. -```cpp -cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ >::BlockReduce() + +```cpp showLineNumbers={false} +cub::BlockReduce::BlockReduce() ``` + *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* @@ -101,13 +90,17 @@ cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ >::BlockReduce() Collective constructor using the specified memory allocation as temporary storage. -```cpp -cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ >::BlockReduce(TempStorage &temp_storage) + +```cpp showLineNumbers={false} +cub::BlockReduce::BlockReduce( + TempStorage &temp_storage +) ``` + *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -#### Parameters +**Parameters** Reference to memory allocation having layout type [TempStorage](/library/api/cub::BlockReduce::TempStorage) @@ -127,26 +120,34 @@ Reference to memory allocation having layout type [TempStorage](/library/api/cub Computes a block-wide reduction for thread0 using the specified binary reduction functor. Each thread contributes one input element. -```cpp + +```cpp showLineNumbers={false} template -T cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ >::Reduce( +T cub::BlockReduce::Reduce( T input, - ReductionOp reduction_op) + ReductionOp reduction_op +) ``` + *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The return value is undefined in threads other than thread0. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Assumes threads are in row-major order. +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + -#### Template parameters + +The return value is undefined in threads other than thread0. + + +**Template parameters** **[inferred]** Binary reduction functor type having member `T operator()(const T &a, const T &b)` -#### Parameters +**Parameters** Calling thread's input @@ -156,11 +157,11 @@ Calling thread's input Binary reduction functor -#### Example +**Example** The code snippet below illustrates a max reduction of 128 integer items that are partitioned across 128 threads. -```cpp +```cpp showLineNumbers={false} #include // or equivalently __global__ void ExampleKernel(...) @@ -185,20 +186,28 @@ __global__ void ExampleKernel(...) Computes a block-wide reduction for thread0 using the specified binary reduction functor. Each thread contributes an array of consecutive input elements. -```cpp + +```cpp showLineNumbers={false} template -T cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ >::Reduce( - T(&inputs)[ITEMS_PER_THREAD], - ReductionOp reduction_op) +T cub::BlockReduce::Reduce( + T (&inputs)[ITEMS_PER_THREAD], + ReductionOp reduction_op +) ``` + *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The return value is undefined in threads other than thread0. -- Performance is sensitive to the degree of data movement across the block. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Performance is sensitive to the degree of data movement across the block. +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + + +The return value is undefined in threads other than thread0. + -#### Template parameters +**Template parameters** **[inferred]** The number of consecutive items partitioned onto each thread. @@ -208,7 +217,7 @@ T cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ >::Reduce( **[inferred]** Binary reduction functor type having member `T operator()(const T &a, const T &b)` -#### Parameters +**Parameters** Calling thread's input segment @@ -218,11 +227,11 @@ Calling thread's input segment Binary reduction functor -#### Example +**Example** The code snippet below illustrates a max reduction of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. -```cpp +```cpp showLineNumbers={false} #include // or equivalently __global__ void ExampleKernel(...) @@ -247,27 +256,35 @@ __global__ void ExampleKernel(...) Computes a block-wide reduction for thread0 using the specified binary reduction functor. The first `num_valid` threads each contribute one input element. -```cpp + +```cpp showLineNumbers={false} template -T cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ >::Reduce( +T cub::BlockReduce::Reduce( T input, ReductionOp reduction_op, - int num_valid) + int num_valid +) ``` + *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The return value is undefined in threads other than thread0. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Assumes threads are in row-major order. +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + -#### Template parameters + +The return value is undefined in threads other than thread0. + + +**Template parameters** **[inferred]** Binary reduction functor type having member `T operator()(const T &a, const T &b)` -#### Parameters +**Parameters** Calling thread's input @@ -281,11 +298,11 @@ Binary reduction functor Number of threads containing valid elements (may be less than BLOCK_THREADS) -#### Example +**Example** The code snippet below illustrates a max reduction of a partially-full tile of integer items that are partitioned across 128 threads. -```cpp +```cpp showLineNumbers={false} #include // or equivalently __global__ void ExampleKernel(int num_valid, ...) @@ -319,27 +336,36 @@ __global__ void ExampleKernel(int num_valid, ...) Computes a block-wide reduction for thread0 using addition (+) as the reduction operator. Each thread contributes one input element. -```cpp -T cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ >::Sum(T input) + +```cpp showLineNumbers={false} +T cub::BlockReduce::Sum( + T input +) ``` + *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The return value is undefined in threads other than thread0. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Assumes threads are in row-major order. +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + + +The return value is undefined in threads other than thread0. + -#### Parameters +**Parameters** Calling thread's input -#### Example +**Example** The code snippet below illustrates a sum reduction of 128 integer items that are partitioned across 128 threads. -```cpp +```cpp showLineNumbers={false} #include // or equivalently __global__ void ExampleKernel(...) @@ -364,35 +390,43 @@ __global__ void ExampleKernel(...) Computes a block-wide reduction for thread0 using addition (+) as the reduction operator. Each thread contributes an array of consecutive input elements. -```cpp + +```cpp showLineNumbers={false} template -T cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ >::Sum( - T(&inputs)[ITEMS_PER_THREAD]) +T cub::BlockReduce::Sum( + T (&inputs)[ITEMS_PER_THREAD] +) ``` + *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The return value is undefined in threads other than thread0. -- Performance is sensitive to the degree of data movement across the block. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Performance is sensitive to the degree of data movement across the block. +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + -#### Template parameters + +The return value is undefined in threads other than thread0. + + +**Template parameters** **[inferred]** The number of consecutive items partitioned onto each thread. -#### Parameters +**Parameters** Calling thread's input segment -#### Example +**Example** The code snippet below illustrates a sum reduction of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. -```cpp +```cpp showLineNumbers={false} #include // or equivalently __global__ void ExampleKernel(...) @@ -417,19 +451,27 @@ __global__ void ExampleKernel(...) Computes a block-wide reduction for thread0 using addition (+) as the reduction operator. The first `num_valid` threads each contribute one input element. -```cpp -T cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ >::Sum( + +```cpp showLineNumbers={false} +T cub::BlockReduce::Sum( T input, - int num_valid) + int num_valid +) ``` + *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The return value is undefined in threads other than thread0. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Assumes threads are in row-major order. +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + + +The return value is undefined in threads other than thread0. + -#### Parameters +**Parameters** Calling thread's input @@ -439,11 +481,11 @@ Calling thread's input Number of threads containing valid elements (may be less than BLOCK_THREADS) -#### Example +**Example** The code snippet below illustrates a sum reduction of a partially-full tile of integer items that are partitioned across 128 threads. -```cpp +```cpp showLineNumbers={false} #include // or equivalently __global__ void ExampleKernel(int num_valid, ...) @@ -475,11 +517,11 @@ __global__ void ExampleKernel(int num_valid, ...) Internal storage allocator. -```cpp -_TempStorage & cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ >::PrivateStorage() + +```cpp showLineNumbers={false} +_TempStorage & cub::BlockReduce::PrivateStorage() ``` - -**Returns:** Reference to [_TempStorage](/library/api/cub::BlockReduce::_TempStorage) + --- @@ -489,12 +531,12 @@ _TempStorage & cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ > | Name | Definition | Description | |---|---|---| -| `InternalBlockReduce` | `::cuda::std::_If< Algorithm==`[`BLOCK_REDUCE_WARP_REDUCTIONS`](/library/api/cub::BLOCK_REDUCE_WARP_REDUCTIONS)`, `[`WarpReductions`](/library/api/cub::BlockReduce::WarpReductions)`, ::cuda::std::_If< Algorithm==`[`BLOCK_REDUCE_WARP_REDUCTIONS_NONDETERMINISTIC`](/library/api/cub::BLOCK_REDUCE_WARP_REDUCTIONS_NONDETERMINISTIC)`, `[`WarpReductionsNondeterministic`](/library/api/cub::BlockReduce::WarpReductionsNondeterministic)`, ::cuda::std::_If< Algorithm==`[`BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY`](/library/api/cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY)`, `[`RakingCommutativeOnly`](/library/api/cub::BlockReduce::RakingCommutativeOnly)`, `[`Raking`](/library/api/cub::BlockReduce::Raking)` > > >` | Internal specialization type. | -| `_TempStorage` | `typename InternalBlockReduce::TempStorage` | Shared memory storage layout type for [BlockReduce](/library/api/cub::BlockReduce). | | `WarpReductions` | `detail::BlockReduceWarpReductions< T, BlockDimX, BlockDimY, BlockDimZ >` | | | `WarpReductionsNondeterministic` | `detail::BlockReduceWarpReductions< T, BlockDimX, BlockDimY, BlockDimZ, false >` | | | `RakingCommutativeOnly` | `detail::BlockReduceRakingCommutativeOnly< T, BlockDimX, BlockDimY, BlockDimZ >` | | | `Raking` | `detail::BlockReduceRaking< T, BlockDimX, BlockDimY, BlockDimZ >` | | +| `InternalBlockReduce` | `::cuda::std::_If< Algorithm==BLOCK_REDUCE_WARP_REDUCTIONS, WarpReductions, ::cuda::std::_If< Algorithm==BLOCK_REDUCE_WARP_REDUCTIONS_NONDETERMINISTIC, WarpReductionsNondeterministic, ::cuda::std::_If< Algorithm==BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY, RakingCommutativeOnly, Raking > > >` | Internal specialization type. | +| `_TempStorage` | `typename InternalBlockReduce::TempStorage` | Shared memory storage layout type for `BlockReduce`. | --- @@ -502,8 +544,8 @@ _TempStorage & cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ > | Name | Type | Description | |---|---|---| -| `BLOCK_THREADS` static constexpr | `int` | The thread block size in threads. | -| `temp_storage` | [`_TempStorage`](/library/api/cub::BlockReduce::_TempStorage) `&` | Shared storage reference. | +| `BLOCK_THREADS` static constexpr | `int` | The thread block size in threads. | +| `temp_storage` | `_TempStorage &` | Shared storage reference. | | `linear_tid` | `unsigned int` | Linear thread-id. | --- @@ -512,10 +554,12 @@ _TempStorage & cub::BlockReduce< T, BlockDimX, Algorithm, BlockDimY, BlockDimZ > ### TempStorage -```cpp + +```cpp showLineNumbers={false} struct cub::BlockReduce::TempStorage ``` + -The operations exposed by [BlockReduce](/library/api/cub::BlockReduce) require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. +The operations exposed by `BlockReduce` require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. **Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/pages/cub/block_scan_v4.mdx b/fern/pages/cub/block_scan_v5.mdx similarity index 82% rename from fern/pages/cub/block_scan_v4.mdx rename to fern/pages/cub/block_scan_v5.mdx index 0b94a1f..8956281 100644 --- a/fern/pages/cub/block_scan_v4.mdx +++ b/fern/pages/cub/block_scan_v5.mdx @@ -84,7 +84,7 @@ The thread block length in threads along the X dimension ### BlockScan inline - + Collective constructor using a private static allocation of shared memory as temporary storage. @@ -142,9 +142,14 @@ void cub::BlockScan::ExclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Uses the identity element (zero) as the initial value. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Uses the identity element (zero) as the initial value. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** @@ -197,9 +202,14 @@ void cub::BlockScan::ExclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Uses the identity element (zero) as the initial value. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Uses the identity element (zero) as the initial value. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** @@ -258,10 +268,15 @@ void cub::BlockScan::ExclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Uses the identity element (zero) as the initial value. -- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Uses the identity element (zero) as the initial value. +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -300,9 +315,14 @@ void cub::BlockScan::ExclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Uses the identity element (zero) as the initial value. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Uses the identity element (zero) as the initial value. +Data is in a blocked arrangement across threads. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -362,9 +382,14 @@ void cub::BlockScan::ExclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Uses the identity element (zero) as the initial value. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Uses the identity element (zero) as the initial value. +Data is in a blocked arrangement across threads. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -404,10 +429,15 @@ void cub::BlockScan::ExclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Uses the identity element (zero) as the initial value. -- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Uses the identity element (zero) as the initial value. +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Data is in a blocked arrangement across threads. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -461,9 +491,14 @@ void cub::BlockScan::ExclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Supports non-commutative scan operators. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Supports non-commutative scan operators. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -533,14 +568,16 @@ void cub::BlockScan::ExclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Supports non-commutative scan operators. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. - +Supports non-commutative scan operators. +Assumes threads are in row-major order. `initial_value` is not applied to the block-wide aggregate. + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + **Template parameters** @@ -588,10 +625,15 @@ void cub::BlockScan::ExclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. -- Supports non-commutative scan operators. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Supports non-commutative scan operators. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -640,9 +682,14 @@ void cub::BlockScan::ExclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Supports non-commutative scan operators. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -692,14 +739,16 @@ void cub::BlockScan::ExclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Supports non-commutative scan operators. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. - +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. `initial_value` is not applied to the block-wide aggregate. + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + **Template parameters** @@ -751,10 +800,15 @@ void cub::BlockScan::ExclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. -- Supports non-commutative scan operators. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -813,8 +867,13 @@ void cub::BlockScan::InclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** @@ -867,8 +926,13 @@ void cub::BlockScan::InclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** @@ -902,9 +966,14 @@ void cub::BlockScan::InclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -943,8 +1012,13 @@ void cub::BlockScan::InclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Data is in a blocked arrangement across threads. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -980,8 +1054,13 @@ void cub::BlockScan::InclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Data is in a blocked arrangement across threads. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -1021,9 +1100,14 @@ void cub::BlockScan::InclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Data is in a blocked arrangement across threads. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -1076,9 +1160,14 @@ void cub::BlockScan::InclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Supports non-commutative scan operators. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Supports non-commutative scan operators. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -1143,9 +1232,14 @@ void cub::BlockScan::InclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Supports non-commutative scan operators. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Supports non-commutative scan operators. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -1215,10 +1309,15 @@ void cub::BlockScan::InclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. -- Supports non-commutative scan operators. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Supports non-commutative scan operators. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -1266,9 +1365,14 @@ void cub::BlockScan::InclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Supports non-commutative scan operators. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -1313,9 +1417,14 @@ void cub::BlockScan::InclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Supports non-commutative scan operators. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -1364,9 +1473,14 @@ void cub::BlockScan::InclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Supports non-commutative scan operators. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -1441,14 +1555,16 @@ void cub::BlockScan::InclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Supports non-commutative scan operators. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. - +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. `initial_value` is not applied to the block-wide aggregate. + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + **Template parameters** @@ -1500,10 +1616,15 @@ void cub::BlockScan::InclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. -- Supports non-commutative scan operators. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** diff --git a/fern/pages/cub/block_scan.mdx b/fern/pages/cub/block_scan_v6.mdx similarity index 62% rename from fern/pages/cub/block_scan.mdx rename to fern/pages/cub/block_scan_v6.mdx index fd3b17c..b778b5f 100644 --- a/fern/pages/cub/block_scan.mdx +++ b/fern/pages/cub/block_scan_v6.mdx @@ -3,53 +3,22 @@ title: cub::BlockScan description: "Collective methods for computing parallel prefix sums/scans across a CUDA thread block." --- -The `BlockScan` class provides collective methods for computing a parallel prefix sum/scan of items partitioned across a CUDA thread block. - -Given a list of input elements and a binary reduction operator, a [prefix scan](http://en.wikipedia.org/wiki/Prefix_sum) produces an output list where each element is computed to be the reduction of the elements occurring earlier in the input list. *Prefix sum* connotes a prefix scan with the addition operator. The term *inclusive* indicates that the *i*th output reduction incorporates the *i*th input. The term *exclusive* indicates the *i*th input is not incorporated into the *i*th output reduction. Threads are assumed to be in row-major order. - -`BlockScan` can be optionally specialized by algorithm to accommodate different workload profiles: - -1. [`cub::BLOCK_SCAN_RAKING`](/library/api/cub::BLOCK_SCAN_RAKING): - An efficient (high throughput) "raking reduce-then-scan" prefix scan algorithm. -2. [`cub::BLOCK_SCAN_RAKING_MEMOIZE`](/library/api/cub::BLOCK_SCAN_RAKING_MEMOIZE): - Similar to `cub::BLOCK_SCAN_RAKING`, but having higher throughput at the expense of additional register pressure for intermediate storage. -3. [`cub::BLOCK_SCAN_WARP_SCANS`](/library/api/cub::BLOCK_SCAN_WARP_SCANS): - A quick (low latency) "tiled warpscans" prefix scan algorithm. +The BlockScan class provides collective methods for computing a parallel prefix sum/scan of items partitioned across a CUDA thread block. ## Performance considerations -- Uses special instructions when applicable (e.g., warp `SHFL` instructions) +- Performance is sensitive to the degree of data movement across the block. +- Uses special instructions when applicable (e.g., warp `SHFL`) - Uses synchronization-free communication between warp lanes when applicable -- Invokes a minimal number of minimal block-wide synchronization barriers (only one or two depending on algorithm selection) +- Invokes a minimal number of minimal block-wide synchronization barriers (only + one or two depending on algorithm selection) - Incurs zero bank conflicts for most types - Computation is slightly more efficient (i.e., having lower instruction overhead) for: + - Prefix sum variants (vs. generic scan) - `BLOCK_THREADS` is a multiple of the architecture's warp size -- See `cub::BlockScanAlgorithm` for performance details regarding algorithmic alternatives - -## Example - -The code snippet below illustrates an exclusive prefix sum of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. - -```cpp showLineNumbers={false} -#include // or equivalently - -__global__ void ExampleKernel(...) -{ - // Specialize BlockScan for a 1D block of 128 threads of type int - using BlockScan = cub::BlockScan; - - // Allocate shared memory for BlockScan - __shared__ typename BlockScan::TempStorage temp_storage; - // Obtain a segment of consecutive items that are blocked across threads - int thread_data[4]; - ... - - // Collectively compute the block-wide exclusive prefix sum - BlockScan(temp_storage).ExclusiveSum(thread_data, thread_data); -} -``` +- See cub::BlockScanAlgorithm for performance details regarding algorithmic alternatives @@ -84,7 +53,7 @@ The thread block length in threads along the X dimension ### BlockScan inline - + Collective constructor using a private static allocation of shared memory as temporary storage. @@ -142,9 +111,14 @@ void cub::BlockScan::ExclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Uses the identity element (zero) as the initial value. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Uses the identity element (zero) as the initial value. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** @@ -156,30 +130,6 @@ Calling thread's input item Calling thread's output item (may be aliased to `input`) -**Example** - -The code snippet below illustrates an exclusive prefix sum of 128 integer items that are partitioned across 128 threads. - -```cpp showLineNumbers={false} -#include // or equivalently - -__global__ void ExampleKernel(...) -{ - // Specialize BlockScan for a 1D block of 128 threads of type int - using BlockScan = cub::BlockScan; - - // Allocate shared memory for BlockScan - __shared__ typename BlockScan::TempStorage temp_storage; - - // Obtain input item for each thread - int thread_data; - ... - - // Collectively compute the block-wide exclusive prefix sum - BlockScan(temp_storage).ExclusiveSum(thread_data, thread_data); -} -``` - @@ -197,9 +147,14 @@ void cub::BlockScan::ExclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Uses the identity element (zero) as the initial value. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Uses the identity element (zero) as the initial value. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** @@ -212,34 +167,9 @@ Calling thread's output item (may be aliased to `input`) -block-wide aggregate reduction of input items +Block-wide aggregate reduction of input items -**Example** - -The code snippet below illustrates an exclusive prefix sum of 128 integer items that are partitioned across 128 threads. - -```cpp showLineNumbers={false} -#include // or equivalently - -__global__ void ExampleKernel(...) -{ - // Specialize BlockScan for a 1D block of 128 threads of type int - using BlockScan = cub::BlockScan; - - // Allocate shared memory for BlockScan - __shared__ typename BlockScan::TempStorage temp_storage; - - // Obtain input item for each thread - int thread_data; - ... - - // Collectively compute the block-wide exclusive prefix sum - int block_aggregate; - BlockScan(temp_storage).ExclusiveSum(thread_data, thread_data, block_aggregate); -} -``` - @@ -258,10 +188,15 @@ void cub::BlockScan::ExclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Uses the identity element (zero) as the initial value. -- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Uses the identity element (zero) as the initial value. +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -280,7 +215,10 @@ Calling thread's output item (may be aliased to `input`) -*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence +Embed:rst:leading-asterisk +//! *warp*\ :sub:`0` only call-back functor for specifying a block-wide prefix to be applied to +//! the logical input sequence. +//! @@ -300,9 +238,15 @@ void cub::BlockScan::ExclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Uses the identity element (zero) as the initial value. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Uses the identity element (zero) as the initial value. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -320,30 +264,6 @@ Calling thread's input items Calling thread's output items (may be aliased to `input`) -**Example** - -The code snippet below illustrates an exclusive prefix sum of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. - -```cpp showLineNumbers={false} -#include // or equivalently - -__global__ void ExampleKernel(...) -{ - // Specialize BlockScan for a 1D block of 128 threads of type int - using BlockScan = cub::BlockScan; - - // Allocate shared memory for BlockScan - __shared__ typename BlockScan::TempStorage temp_storage; - - // Obtain a segment of consecutive items that are blocked across threads - int thread_data[4]; - ... - - // Collectively compute the block-wide exclusive prefix sum - BlockScan(temp_storage).ExclusiveSum(thread_data, thread_data); -} -``` - @@ -362,9 +282,15 @@ void cub::BlockScan::ExclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Uses the identity element (zero) as the initial value. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Uses the identity element (zero) as the initial value. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -383,7 +309,7 @@ Calling thread's output items (may be aliased to `input`) -block-wide aggregate reduction of input items +Block-wide aggregate reduction of input items @@ -404,10 +330,16 @@ void cub::BlockScan::ExclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Uses the identity element (zero) as the initial value. -- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Uses the identity element (zero) as the initial value. +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -430,7 +362,10 @@ Calling thread's output items (may be aliased to `input`) -*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence +Embed:rst:leading-asterisk +//! *warp*\ :sub:`0` only call-back functor for specifying a block-wide prefix to be applied to +//! the logical input sequence. +//! @@ -461,9 +396,14 @@ void cub::BlockScan::ExclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Supports non-commutative scan operators. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Supports non-commutative scan operators. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -482,37 +422,15 @@ Calling thread's output item (may be aliased to `input`) -Initial value to seed the exclusive scan (and is assigned to `output` in *thread*0) +Embed:rst:leading-asterisk +//! Initial value to seed the exclusive scan (and is assigned to `output[0]` in *thread*\ :sub:`0`) +//! Binary scan functor -**Example** - -The code snippet below illustrates an exclusive prefix max scan of 128 integer items that are partitioned across 128 threads. - -```cpp showLineNumbers={false} -#include // or equivalently - -__global__ void ExampleKernel(...) -{ - // Specialize BlockScan for a 1D block of 128 threads of type int - using BlockScan = cub::BlockScan; - - // Allocate shared memory for BlockScan - __shared__ typename BlockScan::TempStorage temp_storage; - - // Obtain input item for each thread - int thread_data; - ... - - // Collectively compute the block-wide exclusive prefix max scan - BlockScan(temp_storage).ExclusiveScan(thread_data, thread_data, INT_MIN, cuda::maximum<>{}); -} -``` - @@ -533,13 +451,14 @@ void cub::BlockScan::ExclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Supports non-commutative scan operators. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Supports non-commutative scan operators. +Assumes threads are in row-major order. + - -`initial_value` is not applied to the block-wide aggregate. - + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -558,7 +477,11 @@ Calling thread's output items (may be aliased to `input`) -Initial value to seed the exclusive scan (and is assigned to `output` in *thread*0) +Embed:rst:leading-asterisk +//! Initial value to seed the exclusive scan (and is assigned to ``output[0]`` in *thread*\ :sub:`0`). It is not +//! taken into account for ``block_aggregate``. +//! +//! @@ -566,7 +489,7 @@ Binary scan functor -block-wide aggregate reduction of input items +Block-wide aggregate reduction of input items @@ -588,10 +511,15 @@ void cub::BlockScan::ExclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. -- Supports non-commutative scan operators. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Supports non-commutative scan operators. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -618,11 +546,71 @@ Binary scan functor -*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence +Embed:rst:leading-asterisk +//! *warp*\ :sub:`0` only call-back functor for specifying a block-wide prefix to be applied to +//! the logical input sequence. +//! +**Example** + +The code snippet below illustrates a single thread block that progressively computes an exclusive prefix max scan over multiple "tiles" of input using a prefix functor to maintain a running total between block-wide scans. Each tile consists of 128 integer items that are partitioned across 128 threads. + +The corresponding output for the first segment will be `INT_MIN, 0, 0, 2, ..., 124, 126`. The output for the second segment will be `126, 128, 128, 130, ..., 252, 254`. + +```cpp showLineNumbers={false} +#include // or equivalently + +// A stateful callback functor that maintains a running prefix to be applied +// during consecutive scan operations. +struct BlockPrefixCallbackOp +{ + // Running prefix + int running_total; + + // Constructor + __device__ BlockPrefixCallbackOp(int running_total) : running_total(running_total) {} + + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. + __device__ int operator()(int block_aggregate) + { + int old_prefix = running_total; + running_total = (block_aggregate > old_prefix) ? block_aggregate : old_prefix; + return old_prefix; + } +}; + +__global__ void ExampleKernel(int *d_data, int num_items, ...) +{ + // Specialize BlockScan for a 1D block of 128 threads + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Initialize running total + BlockPrefixCallbackOp prefix_op(INT_MIN); + + // Have the block iterate over segments of items + for (int block_offset = 0; block_offset < num_items; block_offset += 128) + { + // Load a segment of consecutive items that are blocked across threads + int thread_data = d_data[block_offset + threadIdx.x]; + + // Collectively compute the block-wide exclusive prefix max scan + BlockScan(temp_storage).ExclusiveScan( + thread_data, thread_data, INT_MIN, cuda::maximum<>{}, prefix_op); + __syncthreads(); + + // Store scanned items to output segment + d_data[block_offset + threadIdx.x] = thread_data; + } +} +``` + - + Computes an exclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. @@ -640,9 +628,15 @@ void cub::BlockScan::ExclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Supports non-commutative scan operators. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -665,7 +659,9 @@ Calling thread's output items (may be aliased to `input`) -Initial value to seed the exclusive scan (and is assigned to `output[0]` in *thread*0) +Embed:rst:leading-asterisk +//! Initial value to seed the exclusive scan (and is assigned to `output[0]` in *thread*\ :sub:`0`) +//! @@ -673,7 +669,7 @@ Binary scan functor - + Computes an exclusive block-wide prefix scan using the specified binary `scan_op` functor. Each thread contributes an array of consecutive input elements. Also provides every thread with the block-wide `block_aggregate` of all inputs. @@ -692,13 +688,15 @@ void cub::BlockScan::ExclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Supports non-commutative scan operators. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + - -`initial_value` is not applied to the block-wide aggregate. - + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -721,7 +719,10 @@ Calling thread's output items (may be aliased to `input`) -Initial value to seed the exclusive scan (and is assigned to `output[0]` in *thread*0) +Embed:rst:leading-asterisk +//! Initial value to seed the exclusive scan (and is assigned to `output[0]` in *thread*\ :sub:`0`). It is not taken +//! into account for ``block_aggregate``. +//! @@ -729,9 +730,40 @@ Binary scan functor -block-wide aggregate reduction of input items +Block-wide aggregate reduction of input items +**Example** + +The code snippet below illustrates an exclusive prefix max scan of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +`{ [0,-1,2,-3], [4,-5,6,-7], ..., [508,-509,510,-511] }`. The corresponding output `thread_data` in those threads will be `{ [INT_MIN,0,0,2], [2,4,4,6], ..., [506,508,508,510] }`. Furthermore the value `510` will be stored in `block_aggregate` for all threads. + +.. note:: + +`initial_value` is not applied to the block-wide aggregate. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + // Specialize BlockScan for a 1D block of 128 threads of type int + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Obtain a segment of consecutive items that are blocked across threads + int thread_data[4]; + ... + + // Collectively compute the block-wide exclusive prefix max scan + int block_aggregate; + BlockScan(temp_storage).ExclusiveScan( + thread_data, thread_data, INT_MIN, cuda::maximum<>{}, block_aggregate); +``` + @@ -751,10 +783,16 @@ void cub::BlockScan::ExclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. -- Supports non-commutative scan operators. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -785,7 +823,10 @@ Binary scan functor -*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence +Embed:rst:leading-asterisk +//! *warp*\ :sub:`0` only call-back functor for specifying a block-wide prefix to be applied to +//! the logical input sequence. +//! @@ -813,8 +854,13 @@ void cub::BlockScan::InclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** @@ -826,30 +872,6 @@ Calling thread's input item Calling thread's output item (may be aliased to `input`) -**Example** - -The code snippet below illustrates an inclusive prefix sum of 128 integer items that are partitioned across 128 threads. - -```cpp showLineNumbers={false} -#include // or equivalently - -__global__ void ExampleKernel(...) -{ - // Specialize BlockScan for a 1D block of 128 threads of type int - using BlockScan = cub::BlockScan; - - // Allocate shared memory for BlockScan - __shared__ typename BlockScan::TempStorage temp_storage; - - // Obtain input item for each thread - int thread_data; - ... - - // Collectively compute the block-wide inclusive prefix sum - BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); -} -``` - @@ -867,8 +889,13 @@ void cub::BlockScan::InclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** @@ -881,13 +908,13 @@ Calling thread's output item (may be aliased to `input`) -block-wide aggregate reduction of input items +Block-wide aggregate reduction of input items -Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. The call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. +Computes an inclusive block-wide prefix scan using addition (+) as the scan operator. Each thread contributes one input element. Instead of using 0 as the block-wide prefix, the call-back functor `block_prefix_callback_op` is invoked by the first warp in the block, and the value returned by *lane*0 in that warp is used as the "seed" value that logically prefixes the thread block's scan inputs. ```cpp showLineNumbers={false} @@ -902,9 +929,14 @@ void cub::BlockScan::InclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -923,9 +955,68 @@ Calling thread's output item (may be aliased to `input`) -*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence +Embed:rst:leading-asterisk +//! *warp*\ :sub:`0` only call-back functor for specifying a block-wide prefix to be applied +//! to the logical input sequence. +//! +**Example** + +The code snippet below illustrates a single thread block that progressively computes an inclusive prefix sum over multiple "tiles" of input using a prefix functor to maintain a running total between block-wide scans. Each tile consists of 128 integer items that are partitioned across 128 threads. + +The corresponding output for the first segment will be `1, 2, ..., 128`. The output for the second segment will be `129, 130, ..., 256`. + +```cpp showLineNumbers={false} +#include // or equivalently + +// A stateful callback functor that maintains a running prefix to be applied +// during consecutive scan operations. +struct BlockPrefixCallbackOp +{ + // Running prefix + int running_total; + + // Constructor + __device__ BlockPrefixCallbackOp(int running_total) : running_total(running_total) {} + + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. + __device__ int operator()(int block_aggregate) + { + int old_prefix = running_total; + running_total += block_aggregate; + return old_prefix; + } +}; + +__global__ void ExampleKernel(int *d_data, int num_items, ...) +{ + // Specialize BlockScan for a 1D block of 128 threads + using BlockScan = cub::BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + + // Initialize running total + BlockPrefixCallbackOp prefix_op(0); + + // Have the block iterate over segments of items + for (int block_offset = 0; block_offset < num_items; block_offset += 128) + { + // Load a segment of consecutive items that are blocked across threads + int thread_data = d_data[block_offset + threadIdx.x]; + + // Collectively compute the block-wide inclusive prefix sum + BlockScan(temp_storage).InclusiveSum( + thread_data, thread_data, prefix_op); + __syncthreads(); + + // Store scanned items to output segment + d_data[block_offset + threadIdx.x] = thread_data; + } +``` + @@ -943,8 +1034,14 @@ void cub::BlockScan::InclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -980,8 +1077,14 @@ void cub::BlockScan::InclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -1000,7 +1103,7 @@ Calling thread's output items (may be aliased to `input`) -block-wide aggregate reduction of input items +Block-wide aggregate reduction of input items @@ -1021,9 +1124,15 @@ void cub::BlockScan::InclusiveSum *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -1046,7 +1155,10 @@ Calling thread's output items (may be aliased to `input`) -*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence +Embed:rst:leading-asterisk +//! *warp*\ :sub:`0` only call-back functor for specifying a block-wide prefix to be applied to the +//! logical input sequence. +//! @@ -1076,9 +1188,14 @@ void cub::BlockScan::InclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Supports non-commutative scan operators. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Supports non-commutative scan operators. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -1100,30 +1217,6 @@ Calling thread's output item (may be aliased to `input`) Binary scan functor -**Example** - -The code snippet below illustrates an inclusive prefix max scan of 128 integer items that are partitioned across 128 threads. - -```cpp showLineNumbers={false} -#include // or equivalently - -__global__ void ExampleKernel(...) -{ - // Specialize BlockScan for a 1D block of 128 threads of type int - using BlockScan = cub::BlockScan; - - // Allocate shared memory for BlockScan - __shared__ typename BlockScan::TempStorage temp_storage; - - // Obtain input item for each thread - int thread_data; - ... - - // Collectively compute the block-wide inclusive prefix max scan - BlockScan(temp_storage).InclusiveScan(thread_data, thread_data, cuda::maximum<>{}); -} -``` - @@ -1143,9 +1236,14 @@ void cub::BlockScan::InclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Supports non-commutative scan operators. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Supports non-commutative scan operators. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -1168,13 +1266,15 @@ Binary scan functor -block-wide aggregate reduction of input items +Block-wide aggregate reduction of input items **Example** The code snippet below illustrates an inclusive prefix max scan of 128 integer items that are partitioned across 128 threads. +`0, -1, 2, -3, ..., 126, -127`. The corresponding output `thread_data` in those threads will be `0, 0, 2, 2, ..., 126, 126`. Furthermore the value `126` will be stored in `block_aggregate` for all threads. + ```cpp showLineNumbers={false} #include // or equivalently @@ -1193,7 +1293,6 @@ __global__ void ExampleKernel(...) // Collectively compute the block-wide inclusive prefix max scan int block_aggregate; BlockScan(temp_storage).InclusiveScan(thread_data, thread_data, cuda::maximum<>{}, block_aggregate); -} ``` @@ -1215,10 +1314,15 @@ void cub::BlockScan::InclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. -- Supports non-commutative scan operators. -- Assumes threads are in row-major order. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor's input parameter The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Supports non-commutative scan operators. +Assumes threads are in row-major order. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -1245,7 +1349,10 @@ Binary scan functor -*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence +Embed:rst:leading-asterisk +//! *warp*\ :sub:`0` only call-back functor for specifying a block-wide prefix to be applied to +//! the logical input sequence. +//! @@ -1266,9 +1373,15 @@ void cub::BlockScan::InclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Supports non-commutative scan operators. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -1313,9 +1426,15 @@ void cub::BlockScan::InclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Supports non-commutative scan operators. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -1338,7 +1457,7 @@ Calling thread's output items (may be aliased to `input`) -Initial value to seed the inclusive scan +Initial value to seed the inclusive scan (uniform across block) @@ -1364,9 +1483,15 @@ void cub::BlockScan::InclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Supports non-commutative scan operators. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -1393,12 +1518,14 @@ Binary scan functor -block-wide aggregate reduction of input items +Block-wide aggregate reduction of input items **Example** -The code snippet below illustrates an inclusive prefix max scan of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. +The code snippet below illustrates an inclusive prefix max scan of 512 integer items that are partitioned in a [blocked arrangement](../index.html#sec5sec3) across 128 threads where each thread owns 4 consecutive items. + +`{ [0,-1,2,-3], [4,-5,6,-7], ..., [508,-509,510,-511] }`. The corresponding output `thread_data` in those threads will be `{ [0,0,2,2], [4,4,6,6], ..., [508,508,510,510] }`. Furthermore the value `510` will be stored in `block_aggregate` for all threads. ```cpp showLineNumbers={false} #include // or equivalently @@ -1418,7 +1545,6 @@ __global__ void ExampleKernel(...) // Collectively compute the block-wide inclusive prefix max scan int block_aggregate; BlockScan(temp_storage).InclusiveScan(thread_data, thread_data, cuda::maximum<>{}, block_aggregate); -} ``` @@ -1441,13 +1567,15 @@ void cub::BlockScan::InclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- Supports non-commutative scan operators. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + - -`initial_value` is not applied to the block-wide aggregate. - + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -1470,7 +1598,7 @@ Calling thread's output items (may be aliased to `input`) -Initial value to seed the inclusive scan +Initial value to seed the inclusive scan (uniform across block). It is not taken into account for `block_aggregate`. @@ -1478,7 +1606,7 @@ Binary scan functor -block-wide aggregate reduction of input items +Block-wide aggregate reduction of input items @@ -1500,10 +1628,16 @@ void cub::BlockScan::InclusiveSca *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. -- Supports non-commutative scan operators. -- Data is in a blocked arrangement across threads. -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The `block_prefix_callback_op` functor must implement a member function `T operator()(T block_aggregate)`. The functor will be invoked by the first warp of threads in the block, however only the return value from *lane*0 is applied as the block-wide prefix. Can be stateful. +Supports non-commutative scan operators. +Data is in a blocked arrangement across threads. +Performance is sensitive to the degree of data movement across the block. + + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -1534,9 +1668,77 @@ Binary scan functor -*warp*0 only call-back functor for specifying a block-wide prefix to be applied to the logical input sequence +Embed:rst:leading-asterisk +//! *warp*\ :sub:`0` only call-back functor for specifying a block-wide prefix to be applied to +//! the logical input sequence. +//! +**Example** + +The code snippet below illustrates a single thread block that progressively computes an inclusive prefix max scan over multiple "tiles" of input using a prefix functor to maintain a running total between block-wide scans. Each tile consists of 128 integer items that are partitioned across 128 threads. + +The corresponding output for the first segment will be `0, 0, 2, 2, 4, 4, ..., 510, 510`. The output for the second segment will be `512, 512, 514, 514, 516, 516, ..., 1022, 1022`. + +```cpp showLineNumbers={false} +#include // or equivalently + +// A stateful callback functor that maintains a running prefix to be applied +// during consecutive scan operations. +struct BlockPrefixCallbackOp +{ + // Running prefix + int running_total; + + // Constructor + __device__ BlockPrefixCallbackOp(int running_total) : running_total(running_total) {} + + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. + __device__ int operator()(int block_aggregate) + { + int old_prefix = running_total; + running_total = (block_aggregate > old_prefix) ? block_aggregate : old_prefix; + return old_prefix; + } +}; + +__global__ void ExampleKernel(int *d_data, int num_items, ...) +{ + // Specialize BlockLoad, BlockStore, and BlockScan for a 1D block of 128 threads, 4 ints per thread + using BlockLoad = cub::BlockLoad ; + using BlockStore = cub::BlockStore ; + using BlockScan = cub::BlockScan ; + + // Allocate aliased shared memory for BlockLoad, BlockStore, and BlockScan + __shared__ union { + typename BlockLoad::TempStorage load; + typename BlockScan::TempStorage scan; + typename BlockStore::TempStorage store; + } temp_storage; + + // Initialize running total + BlockPrefixCallbackOp prefix_op(0); + + // Have the block iterate over segments of items + for (int block_offset = 0; block_offset < num_items; block_offset += 128 * 4) + { + // Load a segment of consecutive items that are blocked across threads + int thread_data[4]; + BlockLoad(temp_storage.load).Load(d_data + block_offset, thread_data); + __syncthreads(); + + // Collectively compute the block-wide inclusive prefix max scan + BlockScan(temp_storage.scan).InclusiveScan( + thread_data, thread_data, cuda::maximum<>{}, prefix_op); + __syncthreads(); + + // Store scanned items to output segment + BlockStore(temp_storage.store).Store(d_data + block_offset, thread_data); + __syncthreads(); + } +``` + @@ -1550,12 +1752,10 @@ Internal storage allocator. ```cpp showLineNumbers={false} -_TempStorage& cub::BlockScan::PrivateStorage() +_TempStorage & cub::BlockScan::PrivateStorage() ``` -**Returns:** Reference to [_TempStorage](/library/api/cub::BlockScan::_TempStorage) - --- ## Types @@ -1564,10 +1764,21 @@ _TempStorage& cub::BlockScan::Pri | Name | Definition | Description | |---|---|---| -| `InternalBlockScan` | `::cuda::std::_If< SAFE_ALGORITHM==`[`BLOCK_SCAN_WARP_SCANS`](/library/api/cub::BLOCK_SCAN_WARP_SCANS)`, `[`WarpScans`](/library/api/cub::BlockScan::WarpScans)`, `[`Raking`](/library/api/cub::BlockScan::Raking)` >` | Define the delegate type for the desired algorithm. | -| `_TempStorage` | `typename InternalBlockScan::TempStorage` | Shared memory storage layout type for [BlockScan](/library/api/cub::BlockScan). | | `WarpScans` | `detail::BlockScanWarpScans< T, BlockDimX, BlockDimY, BlockDimZ >` | | -| `Raking` | `detail::BlockScanRaking< T, BlockDimX, BlockDimY, BlockDimZ, (SAFE_ALGORITHM==`[`BLOCK_SCAN_RAKING_MEMOIZE`](/library/api/cub::BLOCK_SCAN_RAKING_MEMOIZE)`) >` | | +| `Raking` | `detail::BlockScanRaking< T, BlockDimX, BlockDimY, BlockDimZ,(SAFE_ALGORITHM==BLOCK_SCAN_RAKING_MEMOIZE)>` | | +| `InternalBlockScan` | `::cuda::std::_If< SAFE_ALGORITHM==BLOCK_SCAN_WARP_SCANS, WarpScans, Raking >` | Define the delegate type for the desired algorithm. | +| `_TempStorage` | `typename InternalBlockScan::TempStorage` | Shared memory storage layout type for `BlockScan`. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | The thread block size in threads. | +| `SAFE_ALGORITHM` static constexpr | `BlockScanAlgorithm` | Ensure the template parameterization meets the requirements of the specified algorithm. | +| `temp_storage` | `_TempStorage &` | Shared storage reference. | +| `linear_tid` | `unsigned int` | Linear thread-id. | --- @@ -1581,6 +1792,6 @@ struct cub::BlockScan::TempStorage ``` -The operations exposed by [BlockScan](/library/api/cub::BlockScan) require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. +The operations exposed by `BlockScan` require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. **Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/pages/cub/simple_struct_v4.mdx b/fern/pages/cub/simple_struct_v4.mdx deleted file mode 100644 index a956942..0000000 --- a/fern/pages/cub/simple_struct_v4.mdx +++ /dev/null @@ -1,48 +0,0 @@ ---- -title: cub::ArgMax -description: "Arg max functor that keeps the value and offset of the first occurrence of the larger item." ---- - -Arg max functor (keeps the value and offset of the first occurrence of the larger item). - -`ArgMax` is a binary functor that operates on [`KeyValuePair`](/library/api/cub::KeyValuePair) instances, returning the pair with the larger value. In case of ties, the pair with the smaller offset is preferred. - ---- - -## Methods - -### operator() inline const - -Boolean max operator, preferring the item having the smaller offset in case of ties. - - -```cpp showLineNumbers={false} -template -KeyValuePair cub::ArgMax::operator()( - const KeyValuePair &a, - const KeyValuePair &b -) const -``` - - -**Template parameters** - - -**[inferred]** Value type - - - -**[inferred]** Offset type - - -**Parameters** - - -First input key-value pair - - - -Second input key-value pair - - -**Returns:** The [`KeyValuePair`](/library/api/cub::KeyValuePair) with the larger value (ties broken by smaller offset) diff --git a/fern/pages/cub/simple_struct.mdx b/fern/pages/cub/simple_struct_v5.mdx similarity index 100% rename from fern/pages/cub/simple_struct.mdx rename to fern/pages/cub/simple_struct_v5.mdx diff --git a/fern/pages/cub/simple_struct_v6.mdx b/fern/pages/cub/simple_struct_v6.mdx new file mode 100644 index 0000000..88572c9 --- /dev/null +++ b/fern/pages/cub/simple_struct_v6.mdx @@ -0,0 +1,24 @@ +--- +title: cub::ArgMax +description: "Arg max functor that keeps the value and offset of the first occurrence of the larger item." +--- + +Arg max functor (keeps the value and offset of the first occurrence of the larger item). + +--- + +## Methods + +### operator() inline const + +Boolean max operator, preferring the item having the smaller offset in case of ties. + + +```cpp showLineNumbers={false} +template +KeyValuePair cub::ArgMax::operator()( + const KeyValuePair &a, + const KeyValuePair &b +) const +``` + diff --git a/fern/pages/cub/warp_reduce_v4.mdx b/fern/pages/cub/warp_reduce_v5.mdx similarity index 89% rename from fern/pages/cub/warp_reduce_v4.mdx rename to fern/pages/cub/warp_reduce_v5.mdx index 2676b1d..61e792b 100644 --- a/fern/pages/cub/warp_reduce_v4.mdx +++ b/fern/pages/cub/warp_reduce_v5.mdx @@ -7,8 +7,10 @@ The `WarpReduce` class provides collective methods for computing a parallel redu A [reduction](http://en.wikipedia.org/wiki/Reduce_(higher-order_function)) (or *fold*) uses a binary combining operator to compute a single aggregate from a list of input elements. -- Supports "logical" warps smaller than the physical warp size (e.g., logical warps of 8 threads) -- The number of entrant threads must be a multiple of `LogicalWarpThreads` + +Supports "logical" warps smaller than the physical warp size (e.g., logical warps of 8 threads). +The number of entrant threads must be a multiple of `LogicalWarpThreads`. + ## Performance considerations @@ -103,7 +105,9 @@ T cub::WarpReduce::Reduce( *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -203,7 +207,9 @@ T cub::WarpReduce::Reduce( *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -275,7 +281,9 @@ T cub::WarpReduce::Sum( *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** @@ -350,7 +358,9 @@ T cub::WarpReduce::Sum( *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** @@ -409,7 +419,9 @@ T cub::WarpReduce::Max( ``` -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** @@ -461,7 +473,9 @@ T cub::WarpReduce::Max( ``` -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** @@ -495,7 +509,9 @@ T cub::WarpReduce::Min( ``` -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** @@ -547,7 +563,9 @@ T cub::WarpReduce::Min( ``` -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** @@ -582,7 +600,9 @@ T cub::WarpReduce::HeadSegmentedSum( *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** @@ -637,7 +657,9 @@ T cub::WarpReduce::TailSegmentedSum( *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** @@ -695,7 +717,9 @@ T cub::WarpReduce::HeadSegmentedReduce( *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -763,7 +787,9 @@ T cub::WarpReduce::TailSegmentedReduce( *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** diff --git a/fern/pages/cub/warp_reduce.mdx b/fern/pages/cub/warp_reduce_v6.mdx similarity index 68% rename from fern/pages/cub/warp_reduce.mdx rename to fern/pages/cub/warp_reduce_v6.mdx index 0ccffc2..2e03010 100644 --- a/fern/pages/cub/warp_reduce.mdx +++ b/fern/pages/cub/warp_reduce_v6.mdx @@ -5,10 +5,7 @@ description: "Collective methods for computing parallel reductions across a CUDA The `WarpReduce` class provides collective methods for computing a parallel reduction of items partitioned across a CUDA thread warp. -A [reduction](http://en.wikipedia.org/wiki/Reduce_(higher-order_function)) (or *fold*) uses a binary combining operator to compute a single aggregate from a list of input elements. - -- Supports "logical" warps smaller than the physical warp size (e.g., logical warps of 8 threads) -- The number of entrant threads must be a multiple of `LogicalWarpThreads` +![](../../img/warp_reduce_logo.png) ## Performance considerations @@ -16,13 +13,20 @@ A [reduction](http://en.wikipedia.org/wiki/Reduce_(higher-order_function)) (or * - Uses synchronization-free communication between warp lanes when applicable - Incurs zero bank conflicts for most types - Computation is slightly more efficient (i.e., having lower instruction overhead) for: - - Summation (vs. generic reduction) + + - Summation (**vs.** generic reduction) - The architecture's warp size is a whole multiple of `LogicalWarpThreads` ## Example The code snippet below illustrates four concurrent warp sum reductions within a block of 128 threads (one per each of the 32-thread warps). +The corresponding output `aggregate` in threads 0, 32, 64, and 96 will be `496`, `1520`, `2544`, and `3568`, respectively (and is undefined in other threads). + +The code snippet below illustrates a single warp sum reduction within a block of 128 threads. + +The corresponding output `aggregate` in thread0 will be `496` (and is undefined in other threads). + ```cpp showLineNumbers={false} #include @@ -38,19 +42,36 @@ __global__ void ExampleKernel(...) int warp_id = threadIdx.x / 32; int aggregate = WarpReduce(temp_storage[warp_id]).Sum(thread_data); } -``` -Suppose the set of input `thread_data` across the block of threads is `{0, 1, 2, 3, ..., 127}`. The corresponding output `aggregate` in threads 0, 32, 64, and 96 will be `496`, `1520`, `2544`, and `3568`, respectively (and is undefined in other threads). +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + ... + // Only the first warp performs a reduction + if (threadIdx.x < 32) + { + // Obtain one input item per thread + int thread_data = ... + // Return the warp-wide sum to lane0 + int aggregate = WarpReduce(temp_storage).Sum(thread_data); + } +} +``` -Data type being reduced +The reduction input/output element type -**[optional]** The number of threads per "logical" warp (may be less than the number of hardware warp threads). Default is the warp size of the targeted CUDA compute architecture (e.g., 32 threads for SM3x). +**[optional]** The number of threads per "logical" warp (may be less than the number of hardware warp threads). Default is the warp size of the targeted CUDA compute-capability (e.g., 32 threads for SM20). @@ -82,7 +103,7 @@ Reference to memory allocation having layout type [TempStorage](/library/api/cub ## Summation reductions -### Sum inline +### Sum inline nodiscard @@ -99,18 +120,16 @@ T cub::WarpReduce::Sum( *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. - -**Parameters** - - -Calling thread's input - + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Example** The code snippet below illustrates four concurrent warp sum reductions within a block of 128 threads (one per each of the 32-thread warps). +The corresponding output `aggregate` in threads 0, 32, 64, and 96 will `496`, `1520`, `2544`, and `3568`, respectively (and is undefined in other threads). + ```cpp showLineNumbers={false} #include @@ -131,31 +150,15 @@ __global__ void ExampleKernel(...) -Computes a warp-wide sum in the calling warp. Each thread contributes a fixed-size array of consecutive input elements. The output is valid in warp *lane*0. - ```cpp showLineNumbers={false} -template , int> = 0> +template T cub::WarpReduce::Sum( const InputType &input ) ``` -**Template parameters** - - -**[inferred]** Input type, must be a fixed-size random access range - - -**Parameters** - - -Calling thread's input - - @@ -174,7 +177,9 @@ T cub::WarpReduce::Sum( *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** @@ -190,6 +195,8 @@ Total number of valid items in the calling thread's logical warp (may be less th The code snippet below illustrates a sum reduction within a single, partially-full block of 32 threads (one warp). +The corresponding output `aggregate` in *lane*0 is `6` (and is undefined in other threads). + ```cpp showLineNumbers={false} #include @@ -214,17 +221,11 @@ __global__ void ExampleKernel(int *d_data, int valid_items) ---- - -## Max reductions - -### Max inline +### Max inline nodiscard -Computes a warp-wide maximum in the calling warp. The output is valid in warp *lane*0. - ```cpp showLineNumbers={false} T cub::WarpReduce::Max( @@ -233,49 +234,21 @@ T cub::WarpReduce::Max( ``` -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. - -**Parameters** - - -Calling thread's input - - -Computes a warp-wide maximum in the calling warp. Each thread contributes a fixed-size array of consecutive input elements. The output is valid in warp *lane*0. - ```cpp showLineNumbers={false} -template , int> = 0> +template T cub::WarpReduce::Max( const InputType &input ) ``` -**Template parameters** - - -**[inferred]** Input type, must be a fixed-size random access range - - -**Parameters** - - -Calling thread's input - - -Computes a partially-full warp-wide maximum in the calling warp. The output is valid in warp *lane*0. - -All threads across the calling warp must agree on the same value for `valid_items`. Otherwise the result is undefined. - ```cpp showLineNumbers={false} T cub::WarpReduce::Max( @@ -285,32 +258,14 @@ T cub::WarpReduce::Max( ``` -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. - -**Parameters** - - -Calling thread's input - - - -Total number of valid items in the calling thread's logical warp (may be less than `LogicalWarpThreads`) - - ---- - -## Min reductions - -### Min inline +### Min inline nodiscard -Computes a warp-wide minimum in the calling warp. The output is valid in warp *lane*0. - ```cpp showLineNumbers={false} T cub::WarpReduce::Min( @@ -319,59 +274,109 @@ T cub::WarpReduce::Min( ``` -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + + -**Parameters** + +```cpp showLineNumbers={false} +template +T cub::WarpReduce::Min( + const InputType &input +) +``` + - -Calling thread's input - + + + + +```cpp showLineNumbers={false} +T cub::WarpReduce::Min( + T input, + int valid_items +) +``` + - + + +### HeadSegmentedSum inline nodiscard -Computes a warp-wide minimum in the calling warp. Each thread contributes a fixed-size array of consecutive input elements. The output is valid in warp *lane*0. +Computes a segmented sum in the calling warp where segments are defined by head-flags. The sum of each segment is returned to the first lane in that segment (which always includes *lane*0). ```cpp showLineNumbers={false} -template , int> = 0> -T cub::WarpReduce::Min( - const InputType &input +template +T cub::WarpReduce::HeadSegmentedSum( + T input, + FlagT head_flag ) ``` -**Template parameters** +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* - -**[inferred]** Input type, must be a fixed-size random access range - + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** - + Calling thread's input - - + +Head flag denoting whether or not `input` is the start of a new segment + -Computes a partially-full warp-wide minimum in the calling warp. The output is valid in warp *lane*0. +**Example** -All threads across the calling warp must agree on the same value for `valid_items`. Otherwise the result is undefined. +The code snippet below illustrates a head-segmented warp sum reduction within a block of 32 threads (one warp). + +is `{0, 1, 2, 3, ..., 31` and is `{1, 0, 0, 0, 1, 0, 0, 0, ..., 1, 0, 0, 0`, respectively. The corresponding output `aggregate` in threads 0, 4, 8, etc. will be `6`, `22`, `38`, etc. (and is undefined in other threads). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + + // Obtain one input item and flag per thread + int thread_data = ... + int head_flag = ... + + // Return the warp-wide sums to each lane0 + int aggregate = WarpReduce(temp_storage).HeadSegmentedSum( + thread_data, head_flag); +} +``` + +### TailSegmentedSum inline nodiscard + +Computes a segmented sum in the calling warp where segments are defined by tail-flags. The sum of each segment is returned to the first lane in that segment (which always includes *lane*0). ```cpp showLineNumbers={false} -T cub::WarpReduce::Min( +template +T cub::WarpReduce::TailSegmentedSum( T input, - int valid_items + FlagT tail_flag ) ``` -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Parameters** @@ -379,25 +384,48 @@ T cub::WarpReduce::Min( Calling thread's input - -Total number of valid items in the calling thread's logical warp (may be less than `LogicalWarpThreads`) + +Head flag denoting whether or not `input` is the start of a new segment - - +**Example** + +The code snippet below illustrates a tail-segmented warp sum reduction within a block of 32 threads (one warp). + +is `{0, 1, 2, 3, ..., 31}` and is `{0, 0, 0, 1, 0, 0, 0, 1, ..., 0, 0, 0, 1}`, respectively. The corresponding output `aggregate` in threads 0, 4, 8, etc. will be `6`, `22`, `38`, etc. (and is undefined in other threads). + +```cpp showLineNumbers={false} +#include + +__global__ void ExampleKernel(...) +{ + // Specialize WarpReduce for type int + using WarpReduce = cub::WarpReduce; + + // Allocate WarpReduce shared memory for one warp + __shared__ typename WarpReduce::TempStorage temp_storage; + + // Obtain one input item and flag per thread + int thread_data = ... + int tail_flag = ... + + // Return the warp-wide sums to each lane0 + int aggregate = WarpReduce(temp_storage).TailSegmentedSum( + thread_data, tail_flag); +``` --- ## Generic reductions -### Reduce inline +### Reduce inline nodiscard Computes a warp-wide reduction in the calling warp using the specified binary reduction functor. The output is valid in warp *lane*0. -Supports non-commutative reduction operators. +Supports non-commutative reduction operators ```cpp showLineNumbers={false} @@ -411,7 +439,9 @@ T cub::WarpReduce::Reduce( *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -433,6 +463,8 @@ Binary reduction operator The code snippet below illustrates four concurrent warp max reductions within a block of 128 threads (one per each of the 32-thread warps). +`{0, 1, 2, 3, ..., 127}`. The corresponding output `aggregate` in threads 0, 32, 64, and 96 will be `31`, `63`, `95`, and `127`, respectively (and is undefined in other threads). + ```cpp showLineNumbers={false} #include @@ -451,14 +483,11 @@ __global__ void ExampleKernel(...) int warp_id = threadIdx.x / 32; int aggregate = WarpReduce(temp_storage[warp_id]).Reduce( thread_data, cuda::maximum<>{}); -} ``` -Computes a warp-wide reduction in the calling warp using the specified binary reduction functor. Each thread contributes a fixed-size array of consecutive input elements. The output is valid in warp *lane*0. - ```cpp showLineNumbers={false} template @@ -469,26 +498,6 @@ T cub::WarpReduce::Reduce( ``` -**Template parameters** - - -**[inferred]** Input type, must be a fixed-size random access range - - - -**[inferred]** Binary reduction operator type having member `T operator()(const T &a, const T &b)` - - -**Parameters** - - -Calling thread's input - - - -Binary reduction operator - - @@ -496,7 +505,7 @@ Computes a partially-full warp-wide reduction in the calling warp using the spec All threads across the calling warp must agree on the same value for `valid_items`. Otherwise the result is undefined. -Supports non-commutative reduction operators. +Supports non-commutative reduction operators ```cpp showLineNumbers={false} @@ -511,7 +520,9 @@ T cub::WarpReduce::Reduce( *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -537,6 +548,8 @@ Total number of valid items in the calling thread's logical warp (may be less th The code snippet below illustrates a max reduction within a single, partially-full block of 32 threads (one warp). +is `4`. The corresponding output `aggregate` in thread0 is `3` (and is undefined in other threads). + ```cpp showLineNumbers={false} #include @@ -556,131 +569,16 @@ __global__ void ExampleKernel(int *d_data, int valid_items) // Return the warp-wide reductions to each lane0 int aggregate = WarpReduce(temp_storage).Reduce( thread_data, cuda::maximum<>{}, valid_items); -} ``` ---- - -## Segmented reductions - -### HeadSegmentedSum inline - -Computes a segmented sum in the calling warp where segments are defined by head-flags. The sum of each segment is returned to the first lane in that segment (which always includes *lane*0). - - -```cpp showLineNumbers={false} -template -T cub::WarpReduce::HeadSegmentedSum( - T input, - FlagT head_flag -) -``` - - -*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* - -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. - -**Parameters** - - -Calling thread's input - - - -Head flag denoting whether or not `input` is the start of a new segment - - -**Example** - -The code snippet below illustrates a head-segmented warp sum reduction within a block of 32 threads (one warp). - -```cpp showLineNumbers={false} -#include - -__global__ void ExampleKernel(...) -{ - // Specialize WarpReduce for type int - using WarpReduce = cub::WarpReduce; - - // Allocate WarpReduce shared memory for one warp - __shared__ typename WarpReduce::TempStorage temp_storage; - - // Obtain one input item and flag per thread - int thread_data = ... - int head_flag = ... - - // Return the warp-wide sums to each lane0 - int aggregate = WarpReduce(temp_storage).HeadSegmentedSum( - thread_data, head_flag); -} -``` - ---- - -### TailSegmentedSum inline - -Computes a segmented sum in the calling warp where segments are defined by tail-flags. The sum of each segment is returned to the first lane in that segment (which always includes *lane*0). - - -```cpp showLineNumbers={false} -template -T cub::WarpReduce::TailSegmentedSum( - T input, - FlagT tail_flag -) -``` - - -*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* - -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. - -**Parameters** - - -Calling thread's input - - - -Tail flag denoting whether or not `input` is the end of the current segment - - -**Example** - -The code snippet below illustrates a tail-segmented warp sum reduction within a block of 32 threads (one warp). - -```cpp showLineNumbers={false} -#include - -__global__ void ExampleKernel(...) -{ - // Specialize WarpReduce for type int - using WarpReduce = cub::WarpReduce; - - // Allocate WarpReduce shared memory for one warp - __shared__ typename WarpReduce::TempStorage temp_storage; - - // Obtain one input item and flag per thread - int thread_data = ... - int tail_flag = ... - - // Return the warp-wide sums to each lane0 - int aggregate = WarpReduce(temp_storage).TailSegmentedSum( - thread_data, tail_flag); -} -``` - ---- - -### HeadSegmentedReduce inline +### HeadSegmentedReduce inline nodiscard Computes a segmented reduction in the calling warp where segments are defined by head-flags. The reduction of each segment is returned to the first lane in that segment (which always includes *lane*0). -Supports non-commutative reduction operators. +Supports non-commutative reduction operators ```cpp showLineNumbers={false} @@ -695,7 +593,9 @@ T cub::WarpReduce::HeadSegmentedReduce( *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -714,13 +614,15 @@ Head flag denoting whether or not `input` is the start of a new segment -Binary reduction operator +Reduction operator **Example** The code snippet below illustrates a head-segmented warp max reduction within a block of 32 threads (one warp). +is `{0, 1, 2, 3, ..., 31}` and is `{1, 0, 0, 0, 1, 0, 0, 0, ..., 1, 0, 0, 0}`, respectively. The corresponding output `aggregate` in threads 0, 4, 8, etc. will be `3`, `7`, `11`, etc. (and is undefined in other threads). + ```cpp showLineNumbers={false} #include @@ -739,16 +641,13 @@ __global__ void ExampleKernel(...) // Return the warp-wide reductions to each lane0 int aggregate = WarpReduce(temp_storage).HeadSegmentedReduce( thread_data, head_flag, cuda::maximum<>{}); -} ``` ---- - -### TailSegmentedReduce inline +### TailSegmentedReduce inline nodiscard Computes a segmented reduction in the calling warp where segments are defined by tail-flags. The reduction of each segment is returned to the first lane in that segment (which always includes *lane*0). -Supports non-commutative reduction operators. +Supports non-commutative reduction operators ```cpp showLineNumbers={false} @@ -763,7 +662,9 @@ T cub::WarpReduce::TailSegmentedReduce( *Added in v2.2.0. First appears in CUDA Toolkit 12.3.* -- The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + +The block-wide aggregate of `temp_storage` is undefined after calling this method and should not be used. To preserve the aggregate, use a separate `TempStorage` for each method call. + **Template parameters** @@ -782,13 +683,15 @@ Tail flag denoting whether or not `input` is the end of the current segment -Binary reduction operator +Reduction operator **Example** The code snippet below illustrates a tail-segmented warp max reduction within a block of 32 threads (one warp). +is `{0, 1, 2, 3, ..., 31}` and is `{0, 0, 0, 1, 0, 0, 0, 1, ..., 0, 0, 0, 1}`, respectively. The corresponding output `aggregate` in threads 0, 4, 8, etc. will be `3`, `7`, `11`, etc. (and is undefined in other threads). + ```cpp showLineNumbers={false} #include @@ -807,7 +710,6 @@ __global__ void ExampleKernel(...) // Return the warp-wide reductions to each lane0 int aggregate = WarpReduce(temp_storage).TailSegmentedReduce( thread_data, tail_flag, cuda::maximum<>{}); -} ``` --- @@ -818,7 +720,17 @@ __global__ void ExampleKernel(...) | Name | Definition | Description | |---|---|---| -| `_TempStorage` | `typename InternalWarpReduce::TempStorage` | Shared memory storage layout type for [WarpReduce](/library/api/cub::WarpReduce). | +| `_TempStorage` | `typename InternalWarpReduce::TempStorage` | Shared memory storage layout type for `WarpReduce`. | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `is_full_warp` static constexpr | `bool` | | +| `is_power_of_two` static constexpr | `bool` | | +| `temp_storage` | `_TempStorage &` | Shared storage reference. | --- @@ -832,6 +744,6 @@ struct cub::WarpReduce::TempStorage ``` -The operations exposed by [WarpReduce](/library/api/cub::WarpReduce) require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. +The operations exposed by `WarpReduce` require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the `__shared__` keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or `union`'d with other storage allocation types to facilitate memory reuse. **Inherits from:** `Uninitialized< _TempStorage >` (public) diff --git a/fern/pages/libcudacxx/concept_example.mdx b/fern/pages/libcudacxx/concept_example.mdx deleted file mode 100644 index 363fb12..0000000 --- a/fern/pages/libcudacxx/concept_example.mdx +++ /dev/null @@ -1,43 +0,0 @@ ---- -title: "cuda::mr::resource_with" ---- - -# resource_with - -The `resource_with` concept verifies that a type Resource satisfies the [`resource`](/library/api/cuda::mr::resource) concept and also satisfies all the provided Properties. - - - - -The resource type to check against the [`resource`](/library/api/cuda::mr::resource) concept. - - - -A variadic pack of property types that the resource must additionally satisfy. - - - - ---- - -## Description - -`resource_with` is a compound concept that combines two requirements: - -1. The type `_Resource` must satisfy [`cuda::mr::resource`](/library/api/cuda::mr::resource), meaning it supports both synchronous and stream-ordered allocation interfaces. -2. The type `_Resource` must satisfy [`cuda::has_property`](/library/api/cuda::has_property) for each property type in `_Properties`. - -This concept is useful when writing generic code that requires a memory resource with specific properties, such as device accessibility or a particular allocation strategy. - ---- - -## Related concepts - -| Concept | Description | -|---|---| -| [`cuda::mr::resource`](/library/api/cuda::mr::resource) | Verifies that a type satisfies the basic requirements of a memory resource with stream-ordered allocations. | -| [`cuda::mr::synchronous_resource`](/library/api/cuda::mr::synchronous_resource) | Verifies that a type satisfies the basic requirements of a synchronous memory resource. | -| [`cuda::mr::synchronous_resource_with`](/library/api/cuda::mr::synchronous_resource_with) | The synchronous counterpart: verifies a synchronous resource that also satisfies a set of properties. | -| [`cuda::has_property`](/library/api/cuda::has_property) | Verifies that a resource satisfies a given property. | -| [`cuda::has_property_with`](/library/api/cuda::has_property_with) | Verifies that a resource satisfies a given stateful property. | -| [`cuda::property_with_value`](/library/api/cuda::property_with_value) | Verifies that a property is stateful and exposes a `value_type` alias. | diff --git a/fern/pages/libcudacxx/concept_example_v3.mdx b/fern/pages/libcudacxx/concept_example_v5.mdx similarity index 99% rename from fern/pages/libcudacxx/concept_example_v3.mdx rename to fern/pages/libcudacxx/concept_example_v5.mdx index bc6ebb7..63d1965 100644 --- a/fern/pages/libcudacxx/concept_example_v3.mdx +++ b/fern/pages/libcudacxx/concept_example_v5.mdx @@ -50,4 +50,4 @@ This concept is useful when writing generic code that requires a memory resource | [`cuda::mr::synchronous_resource_with`](/library/api/cuda::mr::synchronous_resource_with) | The synchronous counterpart: verifies a synchronous resource that also satisfies a set of properties. | | [`cuda::has_property`](/library/api/cuda::has_property) | Verifies that a resource satisfies a given property. | | [`cuda::has_property_with`](/library/api/cuda::has_property_with) | Verifies that a resource satisfies a given stateful property. | -| [`cuda::property_with_value`](/library/api/cuda::property_with_value) | Verifies that a property is stateful and exposes a `value_type` alias. | \ No newline at end of file +| [`cuda::property_with_value`](/library/api/cuda::property_with_value) | Verifies that a property is stateful and exposes a `value_type` alias. | diff --git a/fern/pages/libcudacxx/concept_example_v6.mdx b/fern/pages/libcudacxx/concept_example_v6.mdx new file mode 100644 index 0000000..a694f5f --- /dev/null +++ b/fern/pages/libcudacxx/concept_example_v6.mdx @@ -0,0 +1,27 @@ +--- +title: "cuda::mr::resource_with" +description: "A concept that verifies a memory resource satisfies both the resource concept and a set of property requirements." +--- + +C++20 concept + +The `resource_with` concept verifies that a type Resource satisfies the [`resource`](/library/api/cuda::mr::resource) concept and also satisfies all the provided Properties. + + +```cpp showLineNumbers={false} +template +concept resource_with = /* see description */; +``` + + + + + + + + + + + + + diff --git a/fern/pages/libcudacxx/deep_template_class_v4.mdx b/fern/pages/libcudacxx/deep_template_class_v5.mdx similarity index 89% rename from fern/pages/libcudacxx/deep_template_class_v4.mdx rename to fern/pages/libcudacxx/deep_template_class_v5.mdx index 9126960..0443ff6 100644 --- a/fern/pages/libcudacxx/deep_template_class_v4.mdx +++ b/fern/pages/libcudacxx/deep_template_class_v5.mdx @@ -104,7 +104,7 @@ _Start cuda::counting_iterator<_Start,,>::operator*() const ``` -**Returns:** `_Start` +**Returns:** The current value stored in this iterator. ### operator[] inline constexpr const noexcept nodiscard @@ -118,7 +118,7 @@ _Start2 cuda::counting_iterator<_Start,,>::operator[]( ``` -**Returns:** `_Start2` +**Returns:** The value at offset `__n` from the current stored value. **Parameters** @@ -143,7 +143,7 @@ counting_iterator& cuda::counting_iterator<_Start,,>::operator++() ``` -**Returns:** `counting_iterator &` +**Returns:** A reference to this [`counting_iterator`](/libcudacxx/api/cuda::counting_iterator) after incrementing. @@ -154,7 +154,7 @@ auto cuda::counting_iterator<_Start,,>::operator++(int) ``` -**Returns:** `auto` +**Returns:** A copy of this [`counting_iterator`](/libcudacxx/api/cuda::counting_iterator) before incrementing. @@ -172,7 +172,7 @@ counting_iterator& cuda::counting_iterator<_Start,,>::operator--() ``` -**Returns:** `counting_iterator &` +**Returns:** A reference to this [`counting_iterator`](/libcudacxx/api/cuda::counting_iterator) after decrementing. @@ -183,7 +183,7 @@ counting_iterator cuda::counting_iterator<_Start,,>::operator--(int) ``` -**Returns:** `counting_iterator` +**Returns:** A copy of this [`counting_iterator`](/libcudacxx/api/cuda::counting_iterator) before decrementing. @@ -204,7 +204,7 @@ counting_iterator& cuda::counting_iterator<_Start,,>::operator+=( ``` -**Returns:** `counting_iterator &` +**Returns:** A reference to this [`counting_iterator`](/libcudacxx/api/cuda::counting_iterator) after advancing. **Parameters** @@ -224,7 +224,7 @@ counting_iterator& cuda::counting_iterator<_Start,,>::operator-=( ``` -**Returns:** `counting_iterator &` +**Returns:** A reference to this [`counting_iterator`](/libcudacxx/api/cuda::counting_iterator) after retreating. **Parameters** diff --git a/fern/pages/libcudacxx/deep_template_class.mdx b/fern/pages/libcudacxx/deep_template_class_v6.mdx similarity index 63% rename from fern/pages/libcudacxx/deep_template_class.mdx rename to fern/pages/libcudacxx/deep_template_class_v6.mdx index 58c2622..df321e5 100644 --- a/fern/pages/libcudacxx/deep_template_class.mdx +++ b/fern/pages/libcudacxx/deep_template_class_v6.mdx @@ -7,30 +7,14 @@ A `counting_iterator` represents an iterator into a range of sequentially increa This iterator is useful for creating a range filled with a sequence without explicitly storing it in memory. Using `counting_iterator` saves memory capacity and bandwidth. +The following code snippet demonstrates how to create a `counting_iterator` whose [`value_type`](/libcudacxx/api/cuda::counting_iterator::value_type) is `int` + ```cpp showLineNumbers={false} #include ``` - - - - -The value type of the [`counting_iterator`](/libcudacxx/api/cuda::counting_iterator). - - -The remaining template parameters are SFINAE constraints that require `_Start` to model `::cuda::std::weakly_incrementable` and `::cuda::std::copyable`. They are not intended to be specified directly. - - - - -**Inherits from:** `__counting_iterator_category<_Start>` (public) - ---- - ## Example -The code snippet below demonstrates how to create a `counting_iterator` whose `value_type` is `int`. - ```cpp showLineNumbers={false} #include ... @@ -51,6 +35,18 @@ std::vector vec(500); std::copy(iter, iter + vec.size(), vec.begin()); ``` + + + + +The value type of the `counting_iterator`. + + + + + +**Inherits from:** `__counting_iterator_category< _Start >` (public) + --- ## Constructors @@ -60,16 +56,15 @@ std::copy(iter, iter + vec.size(), vec.begin()); -Default-constructs the stored value. - ```cpp showLineNumbers={false} -cuda::counting_iterator<_Start,,>::counting_iterator() +template +cuda::counting_iterator<_Start,,>::counting_iterator() noexcept(::cuda::std::is_nothrow_default_constructible_v< _Start2 >) ``` - + explicit @@ -79,14 +74,14 @@ Creates a `counting_iterator` from an initial value. ```cpp showLineNumbers={false} cuda::counting_iterator<_Start,,>::counting_iterator( _Start __value -) +) noexcept(::cuda::std::is_nothrow_move_constructible_v< _Start >) ``` **Parameters** -The value to store in the [`counting_iterator`](/libcudacxx/api/cuda::counting_iterator). +The value to store in the `counting_iterator` @@ -94,72 +89,63 @@ The value to store in the [`counting_iterator`](/libcudacxx/api/cuda::counting_i --- -## Element access - -### operator* inline constexpr const noexcept +## Methods -nodiscard +### operator* inline constexpr const noexcept nodiscard -Returns the value currently stored in the [`counting_iterator`](/libcudacxx/api/cuda::counting_iterator). +Returns the value currently stored in the `counting_iterator`. ```cpp showLineNumbers={false} -_Start cuda::counting_iterator<_Start,,>::operator*() const +_Start cuda::counting_iterator<_Start,,>::operator*() const noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Start >) ``` -**Returns:** `_Start` +### operator[] inline constexpr const noexcept nodiscard -### operator[] inline constexpr const noexcept - -nodiscard - -Returns the value currently stored in the [`counting_iterator`](/libcudacxx/api/cuda::counting_iterator) advanced by a number of steps. +Returns the value currently stored in the `counting_iterator` advanced by a number of steps. ```cpp showLineNumbers={false} +template _Start2 cuda::counting_iterator<_Start,,>::operator[]( difference_type __n -) const +) const noexcept(::cuda::std::is_nothrow_copy_constructible_v< _Start2 > &&noexcept(::cuda::std::declval< const _Start2 & >()+__n)) ``` -**Returns:** `_Start2` - **Parameters** -The amount of elements to advance. +The amount of elements to advance ---- - -## Increment operators - ### operator++ inline constexpr noexcept +Increments the stored value. + ```cpp showLineNumbers={false} -counting_iterator& cuda::counting_iterator<_Start,,>::operator++() +counting_iterator & cuda::counting_iterator<_Start,,>::operator++() noexcept(++::cuda::std::declval< _Start & >()) ``` -**Returns:** `counting_iterator &` - +Increments the stored value. + ```cpp showLineNumbers={false} -auto cuda::counting_iterator<_Start,,>::operator++(int) +auto cuda::counting_iterator<_Start,,>::operator++( + int +) noexcept(noexcept(++::cuda::std::declval< _Start & >()) &&::cuda::std::is_nothrow_copy_constructible_v< _Start >) ``` -**Returns:** `auto` - @@ -168,66 +154,67 @@ auto cuda::counting_iterator<_Start,,>::operator++(int) +Decrements the stored value. + ```cpp showLineNumbers={false} -counting_iterator& cuda::counting_iterator<_Start,,>::operator--() +template +counting_iterator & cuda::counting_iterator<_Start,,>::operator--() noexcept(--::cuda::std::declval< _Start2 & >()) ``` -**Returns:** `counting_iterator &` - +Decrements the stored value. + ```cpp showLineNumbers={false} -counting_iterator cuda::counting_iterator<_Start,,>::operator--(int) +template +counting_iterator cuda::counting_iterator<_Start,,>::operator--( + int +) noexcept(noexcept(--::cuda::std::declval< _Start2 & >()) &&::cuda::std::is_nothrow_copy_constructible_v< _Start >) ``` -**Returns:** `counting_iterator` - ---- - -## Compound assignment operators - ### operator+= inline constexpr noexcept +Increments the stored value by a given number of elements. + ```cpp showLineNumbers={false} -counting_iterator& cuda::counting_iterator<_Start,,>::operator+=( +counting_iterator & cuda::counting_iterator<_Start,,>::operator+=( difference_type __n -) +) noexcept(::cuda::std::__integer_like< _Start >) ``` -**Returns:** `counting_iterator &` - **Parameters** -The number of positions to advance. +The number of elements to increment ### operator-= inline constexpr noexcept +Decrements the stored value by a given number of elements. + ```cpp showLineNumbers={false} -counting_iterator& cuda::counting_iterator<_Start,,>::operator-=( +template +counting_iterator & cuda::counting_iterator<_Start,,>::operator-=( difference_type __n -) +) noexcept(::cuda::std::__integer_like< _Start2 >) ``` -**Returns:** `counting_iterator &` - **Parameters** -The number of positions to retreat. +The amount of elements to decrement --- @@ -238,8 +225,8 @@ The number of positions to retreat. | Name | Definition | |---|---| -| `iterator_concept` | `::cuda::std::conditional_t<__advanceable<_Start>, ::cuda::std::random_access_iterator_tag, ::cuda::std::conditional_t<__decrementable<_Start>, ::cuda::std::bidirectional_iterator_tag, ::cuda::std::conditional_t<::cuda::std::incrementable<_Start>, ::cuda::std::forward_iterator_tag, ::cuda::std::input_iterator_tag>>>` | +| `iterator_concept` | `::cuda::std::conditional_t< __advanceable< _Start >, ::cuda::std::random_access_iterator_tag, ::cuda::std::conditional_t< __decrementable< _Start >, ::cuda::std::bidirectional_iterator_tag, ::cuda::std::conditional_t<::cuda::std::incrementable< _Start >, ::cuda::std::forward_iterator_tag, ::cuda::std::input_iterator_tag > > >` | | `value_type` | `_Start` | -| `difference_type` | `_IotaDiffT<_Start>` | +| `difference_type` | `_IotaDiffT< _Start >` | | `reference` | `_Start` | | `pointer` | `void` | diff --git a/fern/pages/libcudacxx/empty_docstring_class_v4.mdx b/fern/pages/libcudacxx/empty_docstring_class_v5.mdx similarity index 93% rename from fern/pages/libcudacxx/empty_docstring_class_v4.mdx rename to fern/pages/libcudacxx/empty_docstring_class_v5.mdx index c169478..bfb8241 100644 --- a/fern/pages/libcudacxx/empty_docstring_class_v4.mdx +++ b/fern/pages/libcudacxx/empty_docstring_class_v5.mdx @@ -25,10 +25,10 @@ The properties the allocated memory satisfies. ## Constructors -### Copy and move constructors +### buffer - + inline explicit @@ -49,7 +49,7 @@ The other buffer. - + inline noexcept @@ -70,7 +70,7 @@ The other buffer. After move construction, the other buffer can only be assigned - + inline explicit @@ -92,7 +92,7 @@ The other buffer. - + inline noexcept @@ -114,11 +114,6 @@ The other buffer. After move construction, the other buffer can only be assigned - - -### Resource constructors - - inline @@ -330,7 +325,7 @@ reference cuda::buffer<_Tp, _Properties>::get_unsynchronized( ``` -**Returns:** `reference` +**Returns:** A mutable reference to the element at position `__n`. **Parameters** @@ -353,7 +348,7 @@ const_reference cuda::buffer<_Tp, _Properties>::get_unsynchronized( ``` -**Returns:** `const_reference` +**Returns:** A const reference to the element at position `__n`. **Parameters** @@ -379,7 +374,7 @@ pointer cuda::buffer<_Tp, _Properties>::data() noexcept ``` -**Returns:** `pointer` +**Returns:** A mutable pointer to the first element. @@ -394,7 +389,7 @@ const_pointer cuda::buffer<_Tp, _Properties>::data() const noexcept ``` -**Returns:** `const_pointer` +**Returns:** A const pointer to the first element. @@ -418,7 +413,7 @@ iterator cuda::buffer<_Tp, _Properties>::begin() noexcept ``` -**Returns:** `iterator` +**Returns:** A mutable iterator to the first element. @@ -433,7 +428,7 @@ const_iterator cuda::buffer<_Tp, _Properties>::begin() const noexcept ``` -**Returns:** `const_iterator` +**Returns:** An immutable iterator to the first element. @@ -448,7 +443,7 @@ const_iterator cuda::buffer<_Tp, _Properties>::cbegin() const noexcept ``` -**Returns:** `const_iterator` +**Returns:** An immutable iterator to the first element. ### end inline noexcept @@ -465,7 +460,7 @@ iterator cuda::buffer<_Tp, _Properties>::end() noexcept ``` -**Returns:** `iterator` +**Returns:** A mutable iterator to one past the last element. @@ -480,7 +475,7 @@ const_iterator cuda::buffer<_Tp, _Properties>::end() const noexcept ``` -**Returns:** `const_iterator` +**Returns:** An immutable iterator to one past the last element. @@ -495,7 +490,7 @@ const_iterator cuda::buffer<_Tp, _Properties>::cend() const noexcept ``` -**Returns:** `const_iterator` +**Returns:** An immutable iterator to one past the last element. ### rbegin inline noexcept @@ -512,7 +507,7 @@ reverse_iterator cuda::buffer<_Tp, _Properties>::rbegin() noexcept ``` -**Returns:** `reverse_iterator` +**Returns:** A mutable reverse iterator to the first element of the reversed buffer. @@ -527,7 +522,7 @@ const_reverse_iterator cuda::buffer<_Tp, _Properties>::rbegin() const noexcept ``` -**Returns:** `const_reverse_iterator` +**Returns:** An immutable reverse iterator to the first element of the reversed buffer. @@ -542,7 +537,7 @@ const_reverse_iterator cuda::buffer<_Tp, _Properties>::crbegin() const noexcept ``` -**Returns:** `const_reverse_iterator` +**Returns:** An immutable reverse iterator to the first element of the reversed buffer. ### rend inline noexcept @@ -559,7 +554,7 @@ reverse_iterator cuda::buffer<_Tp, _Properties>::rend() noexcept ``` -**Returns:** `reverse_iterator` +**Returns:** A mutable reverse iterator to one past the last element of the reversed buffer. @@ -574,7 +569,7 @@ const_reverse_iterator cuda::buffer<_Tp, _Properties>::rend() const noexcept ``` -**Returns:** `const_reverse_iterator` +**Returns:** An immutable reverse iterator to one past the last element of the reversed buffer. @@ -589,7 +584,7 @@ const_reverse_iterator cuda::buffer<_Tp, _Properties>::crend() const noexcept ``` -**Returns:** `const_reverse_iterator` +**Returns:** An immutable reverse iterator to one past the last element of the reversed buffer. --- @@ -605,7 +600,7 @@ size_type cuda::buffer<_Tp, _Properties>::size() const noexcept ``` -**Returns:** `size_type` +**Returns:** `size_type` -- the number of elements in the buffer. ### empty inline const noexcept nodiscard @@ -617,7 +612,7 @@ bool cuda::buffer<_Tp, _Properties>::empty() const noexcept ``` -**Returns:** `bool` +**Returns:** `true` if the buffer contains no elements, `false` otherwise. --- @@ -631,7 +626,7 @@ const __resource_t& cuda::buffer<_Tp, _Properties>::memory_resource() const noex ``` -**Returns:** `const __resource_t &` +**Returns:** A const reference to the memory resource used by this buffer. ### stream inline const constexpr noexcept nodiscard diff --git a/fern/pages/libcudacxx/empty_docstring_class_v6.mdx b/fern/pages/libcudacxx/empty_docstring_class_v6.mdx new file mode 100644 index 0000000..a967d86 --- /dev/null +++ b/fern/pages/libcudacxx/empty_docstring_class_v6.mdx @@ -0,0 +1,723 @@ +--- +title: "cuda::buffer" +description: "" +--- + +`buffer` is a container that provides resizable typed storage allocated from a given memory resource. It handles alignment, release and growth of the allocation. The elements are initialized during construction, which may require a kernel launch. + +In addition to being type-safe, `buffer` also takes a set of properties to ensure that e.g. execution space constraints are checked at compile time. However, only stateless properties can be forwarded. To use a stateful property, implement get_property(const buffer&, Property). + +```cpp showLineNumbers={false} +#include +``` + + + + + +The type to be stored in the buffer + + + +The properties the allocated memory satisfies + + + + + +--- + +## Constructors + +### buffer inline + + + + +explicit + +Copy-constructs from a buffer. + + +```cpp showLineNumbers={false} +cuda::buffer<_Tp, _Properties>::buffer( + const buffer &__other +) +``` + + +**Parameters** + + +The other buffer. + + + + + +noexcept + +Move-constructs from a buffer. + + +```cpp showLineNumbers={false} +cuda::buffer<_Tp, _Properties>::buffer( + buffer &&__other +) noexcept +``` + + +**Parameters** + + +The other buffer. After move construction, the other buffer can only be assigned to or destroyed. + + + + + +explicit + +Copy-constructs from a buffer with matching properties. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + const buffer<_Tp, _OtherProperties...> &__other +) +``` + + +**Parameters** + + +The other buffer. + + + + + +noexcept + +Move-constructs from a buffer with matching properties. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + buffer<_Tp, _OtherProperties...> &&__other +) noexcept +``` + + +**Parameters** + + +The other buffer. After move construction, the other buffer can only be assigned to or destroyed. + + + + + +Constructs an empty buffer using an environment. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + ::cuda::stream_ref __stream, + _Resource &&__resource, + const _Env &__env = {} +) +``` + + + +No memory is allocated. + + +**Parameters** + + +The environment providing the needed information + + + + + +explicit + +Constructs a buffer of size `__size` using a memory and leaves all elements uninitialized. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + ::cuda::stream_ref __stream, + _Resource &&__resource, + const size_type __size, + ::cuda::no_init_t, + const _Env &__env = {} +) +``` + + + +This constructor does *NOT* initialize any elements. It is the user's responsibility to ensure that the elements within `[vec.begin(), vec.end())` are properly initialized, e.g with `cuda::std::uninitialized_copy`. At the destruction of the `buffer` all elements in the range `[vec.begin(), vec.end())` will be destroyed. + + +**Parameters** + + +The size of the buffer. + + + +The environment used to query the memory resource. + + + + + +Constructs a buffer using a memory resource and copy-constructs all elements from the forward range `[__first, __last)`. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + ::cuda::stream_ref __stream, + _Resource &&__resource, + _Iter __first, + _Iter __last, + const _Env &__env = {} +) +``` + + + +If `__first == __last` then no memory is allocated + + +**Parameters** + + +The start of the input sequence. + + + +The end of the input sequence. + + + +The environment used to query the memory resource. + + + + + +Constructs a buffer using a memory resource and copy-constructs all elements from `__ilist`. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + ::cuda::stream_ref __stream, + _Resource &&__resource, + ::cuda::std::initializer_list<_Tp> __ilist, + const _Env &__env = {} +) +``` + + + +If `__ilist.size() == 0` then no memory is allocated + + +**Parameters** + + +The initializer_list being copied into the buffer. + + + +The environment used to query the memory resource. + + + + + +Constructs a buffer using a memory resource and an input range. + + +```cpp showLineNumbers={false} +template +cuda::buffer<_Tp, _Properties>::buffer( + ::cuda::stream_ref __stream, + _Resource &&__resource, + _Range &&__range, + const _Env &__env = {} +) +``` + + + +If `__range.size() == 0` then no memory is allocated. + + +**Parameters** + + +The input range to be moved into the buffer. + + + +The environment used to query the memory resource. + + + + + +--- + +## Assignment operators + +### operator= inline + +Move assignment operator. + + +```cpp showLineNumbers={false} +void cuda::buffer<_Tp, _Properties>::operator=( + buffer &&__other +) +``` + + +**Parameters** + + +The other buffer. After move assignment, the other buffer can only be assigned to or destroyed. + + +--- + +## Methods + +### begin inline noexcept nodiscard + + + + +Returns an iterator to the first element of the buffer. + +If the buffer is empty, the returned iterator will be equal to [end()](/libcudacxx/api/cuda::buffer::end()). + + +```cpp showLineNumbers={false} +iterator cuda::buffer<_Tp, _Properties>::begin() noexcept +``` + + + + + +const + +Returns an immutable iterator to the first element of the buffer. + +If the buffer is empty, the returned iterator will be equal to [end()](/libcudacxx/api/cuda::buffer::end()). + + +```cpp showLineNumbers={false} +const_iterator cuda::buffer<_Tp, _Properties>::begin() const noexcept +``` + + + + + +### cbegin inline const noexcept nodiscard + +Returns an immutable iterator to the first element of the buffer. + +If the buffer is empty, the returned iterator will be equal to [end()](/libcudacxx/api/cuda::buffer::end()). + + +```cpp showLineNumbers={false} +const_iterator cuda::buffer<_Tp, _Properties>::cbegin() const noexcept +``` + + +### end inline noexcept nodiscard + + + + +Returns an iterator to the element following the last element of the buffer. + +This element acts as a placeholder; attempting to access it results in undefined behavior. + + +```cpp showLineNumbers={false} +iterator cuda::buffer<_Tp, _Properties>::end() noexcept +``` + + + + + +const + +Returns an immutable iterator to the element following the last element of the buffer. + +This element acts as a placeholder; attempting to access it results in undefined behavior. + + +```cpp showLineNumbers={false} +const_iterator cuda::buffer<_Tp, _Properties>::end() const noexcept +``` + + + + + +### cend inline const noexcept nodiscard + +Returns an immutable iterator to the element following the last element of the buffer. + +This element acts as a placeholder; attempting to access it results in undefined behavior. + + +```cpp showLineNumbers={false} +const_iterator cuda::buffer<_Tp, _Properties>::cend() const noexcept +``` + + +### rbegin inline noexcept nodiscard + + + + +Returns a reverse iterator to the first element of the reversed buffer. + +It corresponds to the last element of the non-reversed buffer. If the buffer is empty, the returned iterator is equal to [rend()](/libcudacxx/api/cuda::buffer::rend()). + + +```cpp showLineNumbers={false} +reverse_iterator cuda::buffer<_Tp, _Properties>::rbegin() noexcept +``` + + + + + +const + +Returns an immutable reverse iterator to the first element of the reversed buffer. + +It corresponds to the last element of the non-reversed buffer. If the buffer is empty, the returned iterator is equal to [rend()](/libcudacxx/api/cuda::buffer::rend()). + + +```cpp showLineNumbers={false} +const_reverse_iterator cuda::buffer<_Tp, _Properties>::rbegin() const noexcept +``` + + + + + +### crbegin inline const noexcept nodiscard + +Returns an immutable reverse iterator to the first element of the reversed buffer. + +It corresponds to the last element of the non-reversed buffer. If the buffer is empty, the returned iterator is equal to [rend()](/libcudacxx/api/cuda::buffer::rend()). + + +```cpp showLineNumbers={false} +const_reverse_iterator cuda::buffer<_Tp, _Properties>::crbegin() const noexcept +``` + + +### rend inline noexcept nodiscard + + + + +Returns a reverse iterator to the element following the last element of the reversed buffer. + +It corresponds to the element preceding the first element of the non-reversed buffer. This element acts as a placeholder, attempting to access it results in undefined behavior. + + +```cpp showLineNumbers={false} +reverse_iterator cuda::buffer<_Tp, _Properties>::rend() noexcept +``` + + + + + +const + +Returns an immutable reverse iterator to the element following the last element of the reversed buffer. + +It corresponds to the element preceding the first element of the non-reversed buffer. This element acts as a placeholder, attempting to access it results in undefined behavior. + + +```cpp showLineNumbers={false} +const_reverse_iterator cuda::buffer<_Tp, _Properties>::rend() const noexcept +``` + + + + + +### crend inline const noexcept nodiscard + +Returns an immutable reverse iterator to the element following the last element of the reversed buffer. + +It corresponds to the element preceding the first element of the non-reversed buffer. This element acts as a placeholder, attempting to access it results in undefined behavior. + + +```cpp showLineNumbers={false} +const_reverse_iterator cuda::buffer<_Tp, _Properties>::crend() const noexcept +``` + + +### data inline noexcept nodiscard + + + + +Returns a pointer to the first element of the buffer. + +If the buffer has not allocated memory the pointer will be null. + + +```cpp showLineNumbers={false} +pointer cuda::buffer<_Tp, _Properties>::data() noexcept +``` + + + + + +const + +Returns a pointer to the first element of the buffer. + +If the buffer has not allocated memory the pointer will be null. + + +```cpp showLineNumbers={false} +const_pointer cuda::buffer<_Tp, _Properties>::data() const noexcept +``` + + + + + +### get_unsynchronized inline noexcept nodiscard + + + + +Returns a reference to the `__n` 'th element of the async_vector. + + +```cpp showLineNumbers={false} +reference cuda::buffer<_Tp, _Properties>::get_unsynchronized( + const size_type __n +) noexcept +``` + + + +Does not synchronize with the stored stream + + +**Parameters** + + +The index of the element we want to access + + + + + +const + +Returns a reference to the `__n` 'th element of the async_vector. + + +```cpp showLineNumbers={false} +const_reference cuda::buffer<_Tp, _Properties>::get_unsynchronized( + const size_type __n +) const noexcept +``` + + + +Does not synchronize with the stored stream + + +**Parameters** + + +The index of the element we want to access + + + + + +### size inline const noexcept nodiscard + +Returns the current number of elements stored in the buffer. + + +```cpp showLineNumbers={false} +size_type cuda::buffer<_Tp, _Properties>::size() const noexcept +``` + + +### empty inline const noexcept nodiscard + +Returns true if the buffer is empty. + + +```cpp showLineNumbers={false} +bool cuda::buffer<_Tp, _Properties>::empty() const noexcept +``` + + +### memory_resource inline const noexcept nodiscard + +Returns a \c const reference to the any_resource that holds the memory resource used to allocate the buffer + + +```cpp showLineNumbers={false} +const __resource_t & cuda::buffer<_Tp, _Properties>::memory_resource() const noexcept +``` + + +### stream inline constexpr const noexcept nodiscard + +Returns the stored stream. + + +```cpp showLineNumbers={false} +stream_ref cuda::buffer<_Tp, _Properties>::stream() const noexcept +``` + + + +Stream used to allocate the buffer is initially stored in the buffer, but can be changed with [`set_stream`](/libcudacxx/api/cuda::buffer::set_stream) + + +### set_stream inline constexpr + +Replaces the stored stream. + + +```cpp showLineNumbers={false} +void cuda::buffer<_Tp, _Properties>::set_stream( + stream_ref __new_stream +) +``` + + + +Always synchronizes with the old stream + + +**Parameters** + + +The new stream + + +### swap inline noexcept + +Swaps the contents of a buffer with those of `__other`. + + +```cpp showLineNumbers={false} +void cuda::buffer<_Tp, _Properties>::swap( + buffer &__other +) noexcept +``` + + +**Parameters** + + +The other buffer. + + +### destroy inline + + + + +Destroys the buffer, deallocates the buffer and destroys the memory resource. + + +```cpp showLineNumbers={false} +void cuda::buffer<_Tp, _Properties>::destroy( + ::cuda::stream_ref __stream +) +``` + + + +After this explicit destroy call, the buffer can only be assigned to or destroyed. + + +**Parameters** + + +The stream to deallocate the buffer on. + + + + + +Destroys the buffer, deallocates the buffer and destroys the memory resource. + + +```cpp showLineNumbers={false} +void cuda::buffer<_Tp, _Properties>::destroy() +``` + + + +Uses the stored stream to deallocate the buffer, equivalent to calling [buffer.destroy](/libcudacxx/api/cuda::buffer::buffer.destroy)([buffer.stream()](/libcudacxx/api/cuda::buffer::buffer.stream())) + + + +After this explicit destroy call, the buffer can only be assigned to or destroyed. + + + + + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `value_type` | `_Tp` | +| `reference` | `_Tp &` | +| `const_reference` | `const _Tp &` | +| `pointer` | `_Tp *` | +| `const_pointer` | `const _Tp *` | +| `iterator` | `::cuda::heterogeneous_iterator< _Tp, _Properties... >` | +| `const_iterator` | `::cuda::heterogeneous_iterator< const _Tp, _Properties... >` | +| `reverse_iterator` | `::cuda::std::reverse_iterator< iterator >` | +| `const_reverse_iterator` | `::cuda::std::reverse_iterator< const_iterator >` | +| `size_type` | `::cuda::std::size_t` | +| `difference_type` | `::cuda::std::ptrdiff_t` | +| `properties_list` | `::cuda::mr::properties_list< _Properties... >` | diff --git a/fern/pages/libcudacxx/raises_example_v4.mdx b/fern/pages/libcudacxx/raises_example_v5.mdx similarity index 98% rename from fern/pages/libcudacxx/raises_example_v4.mdx rename to fern/pages/libcudacxx/raises_example_v5.mdx index ceda85a..9afd88f 100644 --- a/fern/pages/libcudacxx/raises_example_v4.mdx +++ b/fern/pages/libcudacxx/raises_example_v5.mdx @@ -118,6 +118,8 @@ cuda::stream::stream( +### Destructor + ### ~stream inline Destroy the [`stream`](/libcudacxx/api/cuda::stream) object. @@ -153,7 +155,7 @@ stream& cuda::stream::operator=( ``` -**Returns:** `stream &` +**Returns:** A reference to this [`stream`](/libcudacxx/api/cuda::stream) after assignment. `__other` is in a moved-from state. @@ -452,7 +454,7 @@ static stream cuda::stream::from_native_handle( ``` -**Returns:** `stream` +**Returns:** A new [`stream`](/libcudacxx/api/cuda::stream) object that owns the given handle. **Parameters** diff --git a/fern/pages/libcudacxx/raises_example.mdx b/fern/pages/libcudacxx/raises_example_v6.mdx similarity index 52% rename from fern/pages/libcudacxx/raises_example.mdx rename to fern/pages/libcudacxx/raises_example_v6.mdx index 5d77439..66542eb 100644 --- a/fern/pages/libcudacxx/raises_example.mdx +++ b/fern/pages/libcudacxx/raises_example_v6.mdx @@ -3,13 +3,13 @@ title: "cuda::stream" description: "An owning wrapper for cudaStream_t providing RAII-based stream lifecycle management." --- -An owning wrapper for `cudaStream_t`. +An owning wrapper for cudaStream_t. ```cpp showLineNumbers={false} #include ``` -**Inherits from:** [`cuda::stream_ref`](/libcudacxx/api/cuda::stream_ref) (public) +**Inherits from:** `cuda::stream_ref` (public) --- @@ -18,13 +18,13 @@ An owning wrapper for `cudaStream_t`. ### stream inline - + explicit Constructs a stream on a specified device and with specified priority. -Priority is defaulted to [`stream::default_priority`](/libcudacxx/api/cuda::stream::default_priority). +Priority is defaulted to [stream::default_priority](/libcudacxx/api/cuda::stream::default_priority) ```cpp showLineNumbers={false} @@ -35,24 +35,14 @@ cuda::stream::stream( ``` -**Throws:** `cuda_error` if stream creation fails. - -**Parameters** - - -The device on which to create the stream. - - - -The priority of the stream. - +**Throws:** `cuda_error` if stream creation fails explicit noexcept -Construct a new [`stream`](/libcudacxx/api/cuda::stream) object into the moved-from state. +Construct a new `stream` object into the moved-from state. ```cpp showLineNumbers={false} @@ -62,16 +52,16 @@ cuda::stream::stream( ``` - -[`stream()`](/libcudacxx/api/cuda::stream::stream) returns an invalid stream handle. - + +[`stream()`](/libcudacxx/api/cuda::stream::stream()) returns an invalid stream handle + noexcept -Move-construct a new [`stream`](/libcudacxx/api/cuda::stream) object. +Move-construct a new `stream` object. ```cpp showLineNumbers={false} @@ -81,15 +71,9 @@ cuda::stream::stream( ``` - + `__other` is in moved-from state. - - -**Parameters** - - -The stream to move from. - + @@ -118,9 +102,11 @@ cuda::stream::stream( +### Destructor + ### ~stream inline -Destroy the [`stream`](/libcudacxx/api/cuda::stream) object. +Destroy the `stream` object. ```cpp showLineNumbers={false} @@ -128,49 +114,39 @@ cuda::stream::~stream() ``` - + If the stream fails to be destroyed, the error is silently ignored. - + --- ## Assignment operators -### operator= inline +### operator= inline noexcept -noexcept - -Move-assign a [`stream`](/libcudacxx/api/cuda::stream) object. +Move-assign a `stream` object. ```cpp showLineNumbers={false} -stream& cuda::stream::operator=( +stream & cuda::stream::operator=( stream &&__other ) noexcept ``` -**Returns:** `stream &` - - + `__other` is in a moved-from state. - - -**Parameters** - - -The stream to move from. - + ```cpp showLineNumbers={false} -stream& cuda::stream::operator=( +stream & cuda::stream::operator=( const stream & ) = delete ``` @@ -181,11 +157,9 @@ stream& cuda::stream::operator=( --- -## Ownership - -### release inline +## Methods -nodiscard +### release inline nodiscard Retrieve the native `cudaStream_t` handle and give up ownership. @@ -195,41 +169,29 @@ Retrieve the native `cudaStream_t` handle and give up ownership. ``` -**Returns:** `cudaStream_t` -- the native handle being held by the [`stream`](/libcudacxx/api/cuda::stream) object. - - + The stream object is in a moved-from state. - - ---- + -## Accessors +**Returns:** cudaStream_t The native handle being held by the `stream` object. -### get inline constexpr const noexcept - -nodiscard +### get inline constexpr const noexcept nodiscard Returns the wrapped `cudaStream_t` handle. - + ```cpp showLineNumbers={false} -value_type cuda::stream_ref::get() const noexcept +value_type cuda::stream::get() const noexcept ``` -**Returns:** [`value_type`](/libcudacxx/api/cuda::stream_ref::value_type) - ---- - -## Synchronization - ### sync inline const Synchronizes the wrapped stream. - + ```cpp showLineNumbers={false} -void cuda::stream_ref::sync() const +void cuda::stream::sync() const ``` @@ -238,78 +200,74 @@ void cuda::stream_ref::sync() const ### wait inline const - + Deprecated. - +Use [sync()](/libcudacxx/api/cuda::stream_ref::sync()) instead. + + ```cpp showLineNumbers={false} -void cuda::stream_ref::wait() const +void cuda::stream::wait() const ``` - -Use [`sync()`](/libcudacxx/api/cuda::stream_ref::sync) instead. - + +Use [sync()](/libcudacxx/api/cuda::stream_ref::sync()) instead. + - + Make all future work submitted into this stream depend on completion of the specified event. - + ```cpp showLineNumbers={false} -void cuda::stream_ref::wait( +void cuda::stream::wait( event_ref __ev ) const ``` -**Throws:** `cuda_error` if inserting the dependency fails. +**Throws:** `cuda_error` if inserting the dependency fails **Parameters** -Event that this stream should wait for. +Event that this stream should wait for - + Make all future work submitted into this stream depend on completion of all work from the specified stream. - + ```cpp showLineNumbers={false} -void cuda::stream_ref::wait( +void cuda::stream::wait( stream_ref __other ) const ``` -**Throws:** `cuda_error` if inserting the dependency fails. +**Throws:** `cuda_error` if inserting the dependency fails **Parameters** -Stream that this stream should wait for. +Stream that this stream should wait for ---- - -## Query methods - -### is_done inline const - -nodiscard +### is_done inline const nodiscard Queries if all operations on the stream have completed. - + ```cpp showLineNumbers={false} -bool cuda::stream_ref::is_done() const +bool cuda::stream::is_done() const ``` @@ -317,15 +275,13 @@ bool cuda::stream_ref::is_done() const **Throws:** `cuda::cuda_error` if the query fails. -### ready inline const - -nodiscard +### ready inline const nodiscard Queries if all operations on the wrapped stream have completed. - + ```cpp showLineNumbers={false} -bool cuda::stream_ref::ready() const +bool cuda::stream::ready() const ``` @@ -333,135 +289,106 @@ bool cuda::stream_ref::ready() const **Throws:** `cuda::cuda_error` if the query fails. -### priority inline const - -nodiscard +### priority inline const nodiscard Queries the priority of the wrapped stream. - + ```cpp showLineNumbers={false} -int cuda::stream_ref::priority() const +int cuda::stream::priority() const ``` -**Returns:** Value representing the priority of the wrapped stream. +**Returns:** value representing the priority of the wrapped stream. **Throws:** `cuda::cuda_error` if the query fails. -### id inline const - -nodiscard +### id inline const nodiscard Get the unique ID of the stream. - -```cpp showLineNumbers={false} -stream_id cuda::stream_ref::id() const -``` - - -**Returns:** The unique ID of the stream. +Stream handles are sometimes reused, but ID is guaranteed to be unique. -**Throws:** `cuda_error` if the ID query fails. - -### query inline constexpr const noexcept - -nodiscard - -Queries the `stream_ref` for itself. - - + ```cpp showLineNumbers={false} -stream_ref cuda::stream_ref::query( - const ::cuda::get_stream_t & -) const noexcept +stream_id cuda::stream::id() const ``` -**Returns:** [`stream_ref`](/libcudacxx/api/cuda::stream_ref) - ---- - -## Event recording +**Returns:** The unique ID of the stream -### record_event inline const +**Throws:** `cuda_error` if the ID query fails -nodiscard +### record_event inline const nodiscard Create a new event and record it into this stream. - + ```cpp showLineNumbers={false} -event cuda::stream_ref::record_event( +event cuda::stream::record_event( event_flags __flags = event_flags::none ) const ``` -**Returns:** A new event that was recorded into this stream. +**Returns:** A new event that was recorded into this stream -**Throws:** `cuda_error` if event creation or record failed. +**Throws:** `cuda_error` if event creation or record failed -**Parameters** - - -Flags for event creation. - - -### record_timed_event inline const - -nodiscard +### record_timed_event inline const nodiscard Create a new timed event and record it into this stream. - + ```cpp showLineNumbers={false} -timed_event cuda::stream_ref::record_timed_event( +timed_event cuda::stream::record_timed_event( event_flags __flags = event_flags::none ) const ``` -**Returns:** A new timed event that was recorded into this stream. +**Returns:** A new timed event that was recorded into this stream -**Throws:** `cuda_error` if event creation or record failed. +**Throws:** `cuda_error` if event creation or record failed -**Parameters** +### device inline const nodiscard - -Flags for event creation. - +Get device under which this stream was created. ---- +Note: In case of a stream created under a `green_context` the device on which that `green_context` was created is returned -## Device information + +```cpp showLineNumbers={false} +device_ref cuda::stream::device() const +``` + -### device inline const +**Throws:** `cuda_error` if device check fails -nodiscard +### query inline constexpr const noexcept nodiscard -Get device under which this stream was created. +Queries the [`stream_ref`](/libcudacxx/api/cuda::stream_ref) for itself. - +This makes [`stream_ref`](/libcudacxx/api/cuda::stream_ref) usable in places where we expect an environment with a [`get_stream_t`](/libcudacxx/api/cuda::get_stream_t) query + + ```cpp showLineNumbers={false} -device_ref cuda::stream_ref::device() const +stream_ref cuda::stream::query( + const ::cuda::get_stream_t & +) const noexcept ``` -**Returns:** [`device_ref`](/libcudacxx/api/cuda::device_ref) - -**Throws:** `cuda_error` if device check fails. - --- ## Static methods -### from_native_handle inline static +### from_native_handle inline static nodiscard -nodiscard + + -Construct an [`stream`](/libcudacxx/api/cuda::stream) object from a native `cudaStream_t` handle and take ownership. +Construct an `stream` object from a native `cudaStream_t` handle. ```cpp showLineNumbers={false} @@ -471,29 +398,33 @@ static stream cuda::stream::from_native_handle( ``` -**Returns:** `stream` + +The constructed `stream` object takes ownership of the native handle. + + +**Returns:** stream The constructed `stream` object **Parameters** -The native handle. +The native handle + + + The following overloads are deleted to prevent misuse: + ```cpp showLineNumbers={false} static stream cuda::stream::from_native_handle(int) = delete; static stream cuda::stream::from_native_handle(::cuda::std::nullptr_t) = delete; static stream cuda::stream::from_native_handle(invalid_stream_t) = delete; ``` + ---- - -## Member variables - -| Name | Type | Description | -|---|---|---| -| `default_priority` static constexpr | `int` | The default stream priority. | + + --- @@ -504,3 +435,11 @@ static stream cuda::stream::from_native_handle(invalid_stream_t) = delete; | Name | Definition | |---|---| | `value_type` | `::cudaStream_t` | + +--- + +## Member variables + +| Name | Type | Description | +|---|---|---| +| `default_priority` static constexpr | `int` | | diff --git a/fern/pages/rendering_rules.md b/fern/pages/rendering_rules.md new file mode 100644 index 0000000..6b37693 --- /dev/null +++ b/fern/pages/rendering_rules.md @@ -0,0 +1,1113 @@ +# C++ Library Docs Rendering Rules + +This document is the formal rendering specification for generating MDX pages from the C++ Library Docs IR. Every rule is extracted from the 12 golden pages and the resolved inconsistency decisions. + +--- + +## Frontmatter + +Every page begins with YAML frontmatter containing exactly two fields: `title` and `description`. + +**Rule:** +- `title`: The fully-qualified C++ name (e.g., `cub::BlockReduce`, `thrust::device_vector`, `cuda::counting_iterator`). + - For names containing angle brackets or special characters, wrap in double quotes. +- `description`: A one-sentence summary of the entity, wrapped in double quotes. + +**Example** (from `block_reduce_v3.mdx`): +```yaml +--- +title: cub::BlockReduce +description: "Collective methods for computing parallel reductions across a CUDA thread block." +--- +``` + +**Example with special chars** (from `concept_example_v3.mdx`): +```yaml +--- +title: "cuda::mr::resource_with" +description: "A concept that verifies a memory resource satisfies both the resource concept and a set of property requirements." +--- +``` + +**Edge cases:** +- If the title contains `::` only (no angle brackets), quoting is optional but recommended for consistency. +- Always quote the description. + +--- + +## Page Layout -- Class/Struct Pages + +The canonical section order for class/struct pages is: + +### Preamble (before the first `---` separator) + +1. **Summary paragraph(s)** -- plain text describing the entity +2. **Include header** -- a bare `cpp` code block (no CodeBlock component) +3. **Callouts** -- ``, ``, ``, `` for deprecated/behavioral notes +4. **See also** -- bold label `**See also:**` followed by comma-separated links +5. **Example** -- class-level example with optional introductory sentence + code block +6. **Performance considerations** -- `## Performance considerations` section (only for CUB classes) +7. **Template parameters** -- wrapped in `` +8. **Inherits from** -- bold text `**Inherits from:**` followed by type and access specifier +9. Additional notes (e.g., "This class is marked `final`.") + +**Important:** The `---` horizontal rule separates the preamble from the body sections. + +### Body Sections (after first `---`) + +Sections appear as `## SectionLabel` headings, separated by `---` horizontal rules. The order follows the IR's `sections` array. Common section orderings observed: + +1. `## Constructors` (or `## Collective constructors`) +2. `## Assignment operators` +3. `## Element access` +4. `## Iterators` +5. `## Capacity` +6. `## Modifiers` +7. `## Static methods` +8. `## Friend functions` +9. `## Types` (typedefs) +10. `## Member variables` +11. `## Inner classes` + +**Example section order** (from `device_vector_v3.mdx`): +``` +## Constructors +--- +## Assignment operators +--- +## Element access +--- +## Iterators +--- +## Capacity +--- +## Modifiers +--- +## Allocator +--- +## Types +``` + +**Edge cases:** +- If a class has no members in a category, that section is omitted entirely. +- The `Performance considerations` and `Example` sections in the preamble are H2 headings (part of the preamble, before the first `---`). + +--- + +## Page Layout -- Concept Pages + +Concept pages have a different structure since concepts have no constructors, methods, or member variables. + +### Preamble + +1. **Entity-kind Badge** -- `C++20 concept` on its own line +2. **Summary paragraph(s)** +3. **Concept definition signature** -- in a `` with links +4. **Template parameters** -- in `` + +### Body Sections + +1. `## Description` -- detailed explanation +2. `## Related concepts` -- table of related concepts + +**Example** (from `concept_example_v3.mdx`): +```mdx +C++20 concept + +The `resource_with` concept verifies that ... + + +```cpp showLineNumbers={false} +template +concept resource_with = /* see description */; +``` + + + + +... + + + +--- + +## Description + +... + +--- + +## Related concepts + +| Concept | Description | +|---|---| +| [`cuda::mr::resource`](/library/api/cuda::mr::resource) | ... | +``` + +--- + +## Page Layout -- Simple Struct Pages + +Simple structs (like `cub::ArgMax` from `simple_struct_v4.mdx`) omit sections that don't apply. They follow the same canonical order but only include what's present. + +**Example** (from `simple_struct_v4.mdx`): +``` +Preamble: summary only (no include header, no template params) +--- +## Methods +``` + +--- + +## Preamble Sections + +### Summary + +**Rule:** One or more paragraphs of plain text. The first paragraph should be a concise description. Subsequent paragraphs provide additional detail. Inline code uses backticks. Cross-references use markdown links. + +**Example** (from `device_vector_v3.mdx`): +```mdx +A `device_vector` is a container that supports random access to elements, constant time removal of elements at the end, and linear time insertion and removal of elements at the beginning or in the middle. + +The number of elements in a `device_vector` may vary dynamically; memory management is automatic. The memory associated with a `device_vector` resides in the memory accessible to devices. +``` + +**Example with numbered list** (from `block_reduce_v3.mdx`): +```mdx +`BlockReduce` can be optionally specialized by algorithm to accommodate different latency/throughput workload profiles: + +1. [`cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY`](/library/api/cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY): + An efficient "raking" reduction algorithm that only supports commutative reduction operators. +2. [`cub::BLOCK_REDUCE_RAKING`](/library/api/cub::BLOCK_REDUCE_RAKING): + ... +``` + +### Include Header + +**Rule:** Render as a bare fenced code block with `cpp showLineNumbers={false}`. No `` component. Appears after the summary paragraphs. + +**Example** (from `device_vector_v3.mdx`): +```mdx +```cpp showLineNumbers={false} +#include +``` +``` + +**Edge cases:** +- Not all entities have include headers. If absent, omit entirely. +- Simple structs (like `cub::ArgMax`) may not have include headers. + +### Callouts (Deprecated, Warnings, Notes) + +**Rule:** Always use callout components. Never use plain bullet points for behavioral notes, warnings, or deprecation notices. + +| Callout type | Usage | +|---|---| +| `` | Deprecation notices | +| `` | Dangerous behavior, user responsibility warnings | +| `` | Important behavioral notes, clarifications | +| `` | Postcondition descriptions | +| `` | Other informational callouts with custom titles | + +**Deprecated example** (from `deprecated_example_v4.mdx`): +```mdx + +Use `cuda::strided_iterator` instead. + +``` + +**Note example** (from `pointer_v4.mdx`): +```mdx + +`pointer` is not a smart pointer; it is the client's responsibility to deallocate memory pointer to by `pointer`. + +``` + +**Warning example** (from `empty_docstring_class_v4.mdx` / `raises_example_v4.mdx`): +```mdx + +This constructor does **NOT** initialize any elements. It is the user's responsibility to ensure that the elements within `[vec.begin(), vec.end())` are properly initialized. + +``` + +**Info/Postconditions example** (from `raises_example_v4.mdx`): +```mdx + +`__other` is in moved-from state. + +``` + +**Edge cases:** +- Callouts appear in the preamble AND inside method/overload tabs. +- Within tabs, callouts appear after the CodeBlock and before **Parameters**. + +### See Also + +**Rule:** Render as `**See also:**` (bold) followed by a line of comma-separated links. External URLs use full markdown links. Internal refs use relative link paths. + +**Example** (from `device_vector_v3.mdx`): +```mdx +**See also:** +[https://en.cppreference.com/w/cpp/container/vector](https://en.cppreference.com/w/cpp/container/vector), +[device_allocator](/library/api/thrust::device_allocator), +[host_vector](/library/api/thrust::host_vector), +[universal_vector](/library/api/thrust::universal_vector) +``` + +**Example** (from `pointer_v4.mdx`): +```mdx +**See also:** +[device_ptr](/library/api/thrust::device_ptr), +reference, +[raw_pointer_cast](/library/api/thrust::raw_pointer_cast) +``` + +**Edge cases:** +- If a see-also item has no link target, render as plain text (e.g., `reference` in the pointer example). +- Each link on its own line, comma-separated. + +### Example (class-level) + +**Rule:** Render as an `## Example` H2 heading followed by an optional introductory sentence and a fenced code block (`cpp showLineNumbers={false}`). Class-level examples use a bare code block (no `` component). + +**Example** (from `block_reduce_v3.mdx`): +```mdx +## Example + +The code snippet below illustrates a sum reduction of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + ... +} +``` +``` + +**Edge cases:** +- Some pages place the example before template parameters, some after. The canonical order is: example comes before template parameters when it's a class-level example in the preamble. +- `## Performance considerations` if present comes before `## Example`. + +### Template Parameters + +**Rule:** Wrapped in ``. Each parameter uses a `` component. + +**Example** (from `block_reduce_v3.mdx`): +```mdx + + + + +Data type being reduced + + + +The thread block length in threads along the X dimension + + + +**[optional]** [cub::BlockReduceAlgorithm](/library/api/cub::BlockReduceAlgorithm) enumerator specifying the underlying algorithm to use (default: [cub::BLOCK_REDUCE_WARP_REDUCTIONS](/library/api/cub::BLOCK_REDUCE_WARP_REDUCTIONS)) + + + + +``` + +**Edge cases:** +- Optional parameters start their description with `**[optional]**`. +- Default values use the `default` prop on ``. + +### Base Classes / Inherits From + +**Rule:** Render as `**Inherits from:**` (bold) followed by the base class type and access specifier in parentheses. + +**Example** (from `device_vector_v3.mdx`): +```mdx +**Inherits from:** `detail::vector_base< T, thrust::device_allocator< T > >` (public) +``` + +**Example with link** (from `pointer_v4.mdx`): +```mdx +**Inherits from:** [`thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >`](/library/api/thrust::iterator_adaptor) (public) +``` + +**Example with multiple bases** (from `group_member_example_v4.mdx`): +```mdx +**Inherits from:** [`thrust::mr::memory_resource< Upstream::pointer >`](/library/api/thrust::mr::memory_resource) (public), [`thrust::mr::validator2< Upstream, Bookkeeper >`](/library/api/thrust::mr::validator2) (private) +``` + +**Edge cases:** +- If a base class has a known API page, link it. +- Multiple base classes are comma-separated. +- If the class is `final`, add a separate line: `This class is marked final.` + +--- + +## Method Rendering + +### Method Heading (H3 + Badges) + +**Rule:** Methods are rendered as H3 headings. Qualifier badges that apply to ALL overloads of the method go on the H3 heading line. Qualifiers specific to individual overloads go inside the corresponding Tab. + +**Format:** `### MethodName qualifier` + +**Example -- all-overload qualifiers on H3** (from `block_reduce_v3.mdx`): +```mdx +### Reduce inline +``` + +**Example -- multiple qualifiers** (from `pointer_v4.mdx`): +```mdx +### get inline const +``` + +**Example -- overload-specific qualifier inside Tab** (from `device_vector_v3.mdx`): +```mdx + + +const + +Subscript read access to the data contained in this vector. +``` + +**Example -- overload-specific qualifier inside Tab** (from `deprecated_example_v4.mdx`): +```mdx + + +inline + +Creates a [strided_iterator](/library/api/thrust::strided_iterator) from an existing iterator and a stride. +``` + +**Badge placement decision rule:** +- If a qualifier (e.g., `inline`, `const`, `static`, `virtual`, `noexcept`, `nodiscard`, `explicit`, `constexpr`) applies to every overload of the method, place it on the H3. +- If it applies only to some overloads, place it as a standalone Badge line at the top of the relevant Tab (before the description text). + +### Signature (CodeBlock with links) + +**Rule:** Method signatures are always wrapped in a `` component. The links prop maps identifier names to their API page paths. + +**Format:** +```mdx + +```cpp showLineNumbers={false} +ReturnType namespace::ClassName::MethodName( + ParamType1 param1, + ParamType2 param2 +) +``` + +``` + +**Example** (from `block_reduce_v3.mdx`): +```mdx + +```cpp showLineNumbers={false} +template +T cub::BlockReduce::Reduce( + T input, + ReductionOp reduction_op +) +``` + +``` + +**Rules for the links prop:** +- Include the owning class short name (e.g., `"BlockReduce"`, `"device_vector"`). +- Include any parameter types or return types that have known API pages. +- If a method is inherited, include both the current class and the base class in links. + +**Example -- inherited method with two link targets** (from `pointer_v4.mdx`): +```mdx + +```cpp showLineNumbers={false} +Base const & thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >::base() const +``` + +``` + +### Parameters + +**Rule:** Always use `**Parameters**` (bold text), never `#### Parameters` (H4). + +**Format:** +```mdx +**Parameters** + + +Description of the parameter. + +``` + +**Example** (from `block_reduce_v3.mdx`): +```mdx +**Parameters** + + +Calling thread's input + + + +Binary reduction functor + +``` + +**Edge cases:** +- If a parameter has a default value, use the `default` prop: ``. +- Empty ParamFields (no description text) should be omitted entirely. However, if the parameter must be listed but genuinely has no description, render with empty content: `\n`. + +### Template Parameters (method-level) + +**Rule:** Same format as Parameters but with `**Template parameters**` heading. + +**Example** (from `block_reduce_v3.mdx`): +```mdx +**Template parameters** + + +**[inferred]** Binary reduction functor type having member `T operator()(const T &a, const T &b)` + +``` + +**Edge cases:** +- Inferred template parameters start with `**[inferred]**`. + +### Returns + +**Rule:** Always use prose description format: `**Returns:** Description text`. Include type in backticks or as a link within the prose. Never use bare type-only returns. + +**Examples:** + +Prose with type (from `block_reduce_v3.mdx`): +```mdx +**Returns:** Reference to [_TempStorage](/library/api/cub::BlockReduce::_TempStorage) +``` + +Prose with inline type (from `device_vector_v3.mdx`): +```mdx +**Returns:** Read/write reference to data. +``` + +Type with dash separator (from `device_vector_v3.mdx`): +```mdx +**Returns:** `size_type` -- the number of elements. +``` + +Short type returns (from `raises_example_v4.mdx`): +```mdx +**Returns:** `stream &` +``` + +Link-based returns (from `raises_example_v4.mdx`): +```mdx +**Returns:** [`stream_ref`](/libcudacxx/api/cuda::stream_ref) +``` + +Boolean returns (from `device_vector_v3.mdx`): +```mdx +**Returns:** `true` if [size()](/library/api/thrust::device_vector::size) == 0; `false`, otherwise. +``` + +### Throws + +**Rule:** Render as `**Throws:**` followed by the exception type and condition. + +**Example** (from `device_vector_v3.mdx`): +```mdx +**Throws:** `std::length_error` if `n` exceeds [max_size()](/library/api/thrust::device_vector::max_size). +``` + +**Example** (from `raises_example_v4.mdx`): +```mdx +**Throws:** `cuda_error` if stream creation fails. +``` + +### Callouts (preconditions, postconditions, notes, warnings) + +**Rule:** Use callout components within method bodies/tabs. These appear AFTER the CodeBlock and version annotation, BEFORE **Parameters**. + +**Behavioral notes as bullet list with callout components:** + +Actually, looking at the golden pages, behavioral notes within method tabs are rendered as plain bullet lists (not callouts). The resolved decision says to use callout components. However, in many golden pages, behavioral notes within overload tabs use plain bullet lists: + +**Current golden page pattern** (from `block_reduce_v3.mdx`): +```mdx +- The return value is undefined in threads other than thread0. +- Assumes threads are in row-major order. +- The block-wide aggregate of `temp_storage` is undefined after calling this method ... +``` + +**Resolved decision:** Standardize on callout components everywhere. However, for method-level behavioral notes that are short lists, the golden pages use plain bullet lists. The resolved decision (#1) says to use callout components (``, ``, etc.) and never plain bullet points. Apply this decision: + +- Method-level behavioral notes --> wrap in `` callout +- Warnings --> `` +- Postconditions --> `` +- Deprecation --> `` + +**Postconditions example** (from `raises_example_v4.mdx`): +```mdx + +`__other` is in moved-from state. + +``` + +**Note within Tab** (from `empty_docstring_class_v4.mdx`): +```mdx + +No memory is allocated. + +``` + +### Examples (method-level) + +**Rule:** Render as `**Example**` (bold text) followed by an optional intro sentence and a bare fenced code block (`cpp showLineNumbers={false}`). Method-level examples do NOT use the `` component. + +**Example** (from `block_reduce_v3.mdx`): +```mdx +**Example** + +The code snippet below illustrates a max reduction of 128 integer items that are partitioned across 128 threads. + +```cpp showLineNumbers={false} +#include // or equivalently + +__global__ void ExampleKernel(...) +{ + ... +} +``` +``` + +--- + +## Overloads + +### Tab Structure + +**Rule:** When a method has multiple overloads, wrap them in `...`. Each Tab contains the full rendering of one overload: description, CodeBlock, version annotation, callouts, template parameters, parameters, returns, throws, example. + +**Format:** +```mdx + + + +Description of this overload. + + +```cpp showLineNumbers={false} +... +``` + + +*Version annotation* + +**Parameters** + + +... + + + + +... + + +``` + +### Tab Titles Convention + +**Rule:** Tab titles should be short, descriptive labels. Conventions: + +| Pattern | Title format | Examples | +|---|---|---| +| Default constructor | `"Default"` or `"Default constructor"` | `"Default"`, `"Default constructor"` | +| Copy constructors | `"From X"` where X describes the source | `"From device_vector"`, `"From std::vector"`, `"From other device_vector type"`, `"From initializer_list"` | +| Move constructors | `"Move"` or `"Move with allocator"` | `"Move"`, `"Move with allocator"` | +| Constructors with config | `"With X"` where X is the config | `"With TempStorage"`, `"With allocator"`, `"With upstream and bookkeeper"`, `"With default resources"` | +| Const/Mutable pairs | `"Mutable"` / `"Const"` | `"Mutable"`, `"Const"` | +| Size variants | `"Single item"` / `"Multiple items per thread"` / `"Partial tile"` | as shown | +| Functional description | Short phrase | `"Fill"`, `"Range"`, `"Value-initialized"`, `"Default-initialized"`, `"No-init"` | +| Pre/Post increment | `"Pre-increment"` / `"Post-increment"` | as shown | +| Assignment variants | `"Copy assign"` / `"Move assign"` | as shown | +| Deleted overloads | `"Copy (deleted)"` / `"Copy assign (deleted)"` / `"Deleted overloads"` | as shown | + +**Key rules:** +- Be concise but clear. +- Use noun phrases, not sentences. +- When there are Mutable/Const pairs, always use exactly `"Mutable"` and `"Const"`. + +--- + +## Constructors + +### Grouping + +**Rule:** Always use a single H3 with one `` group for all constructor overloads in a logical group. Do NOT subdivide into named H3 groups per overload. + +However, when a class has MANY constructors with clearly distinct categories (like `device_vector` with 20+ constructors), they MAY be split into multiple named H3 groups, each with its own ``: + +**Example of multiple constructor groups** (from `device_vector_v3.mdx`): +```mdx +### Default and allocator constructors +... + +### Size constructors +... + +### Fill constructors +... + +### Copy constructors +... + +### Move constructors +... + +### Initializer list constructors +... + +### Range constructors +... +``` + +**Resolution clarification:** The resolved decision (#9) says "Always use a single H3 with one Tabs group." However, the `device_vector` golden page uses multiple named H3 groups. The rule is: **prefer a single H3 with one Tabs group.** Only use multiple H3 groups when there are so many constructors (>6) that a single Tabs group would be unwieldy, AND the constructors fall into clearly distinct semantic categories. When using multiple H3 groups, the H3 title describes the category (e.g., "Copy constructors"), not individual overloads. + +**Example of single H3 with Tabs** (from `block_reduce_v3.mdx`): +```mdx +### BlockReduce inline + + + +... + + +... + + +``` + +### Destructor + +**Rule:** Always include `### Destructor` as a label heading before the `### ~ClassName` heading. + +**Example** (from `device_vector_v3.mdx`): +```mdx +### Destructor + +### ~device_vector inline + +The destructor erases the elements. + + +```cpp showLineNumbers={false} +thrust::device_vector::~device_vector() +``` + +``` + +**Example** (from `group_member_example_v4.mdx`): +```mdx +### Destructor + +### ~disjoint_unsynchronized_pool_resource inline + +Destructor. Releases all held memory to upstream. + + +```cpp showLineNumbers={false} +thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::~disjoint_unsynchronized_pool_resource() +``` + +``` + +**Edge cases:** +- The destructor appears at the end of the Constructors section, after all constructor tabs. +- If the destructor has no description, still include the `### Destructor` label and the `### ~ClassName` heading with the CodeBlock. +- Destructors without overloads do NOT use Tabs. +- The `### ~ClassName` heading on the destructor line (from `raises_example_v4.mdx`): +```mdx +### ~stream inline + +Destroy the [`stream`](/libcudacxx/api/cuda::stream) object. +``` +Note: In the `raises_example_v4.mdx` golden page, the destructor is rendered WITHOUT the `### Destructor` label heading. The resolved decision (#5) says to ALWAYS include it. So the canonical form is to always add `### Destructor` before `### ~ClassName`. + +--- + +## Components + +### Callout Types and When to Use Each + +| Component | When to use | +|---|---| +| `` | General informational notes, behavioral clarifications, non-critical operational details | +| `` | Dangerous behavior, user responsibility, potential misuse consequences | +| `` | Deprecation notices (always with `title="Deprecated"`) | +| `` | Postconditions, pre-conditions, general informational callouts | +| `` | Specifically for method postconditions | + +### Badge -- Qualifier Badges vs Entity-Kind Badges + +**Qualifier badges** use `intent="note" minimal`: +```mdx +inline +const +static +virtual +explicit +constexpr +noexcept +nodiscard +final +``` + +**Entity-kind badges** use `intent="info"` (NOT minimal): +```mdx +C++20 concept +``` + +**Placement:** +- Qualifier badges on H3 headings: separated by spaces, on the same line +- Qualifier badges inside Tabs: standalone line at the top of the Tab, before description +- Entity-kind badges: in the preamble, on their own line before the summary + +**Static/constexpr in member variable tables** (from `block_reduce_v3.mdx`): +```mdx +| `BLOCK_THREADS` static constexpr | `int` | The thread block size in threads. | +``` + +### AccordionGroup / Accordion + +**Rule:** Used exclusively for class-level template parameters. + +```mdx + + + + +Description + + + + +``` + +### CodeBlock with Links + +**Rule:** Used for all method/constructor/destructor signatures and concept definition signatures. + +**Format:** +```mdx + +```cpp showLineNumbers={false} +signature here +``` + +``` + +**Rules for the `links` prop:** +- Keys are short identifiers that appear in the signature +- Values are the API page paths +- Always include the owning class as a link +- Include parameter types and return types that have known pages +- For inherited methods, include both the declaring class and the owning class + +### ParamField + +**Rule:** Used for parameters and template parameters. + +**Props:** +- `path` (required): Parameter name +- `type` (required): Parameter type as string +- `default` (optional): Default value + +**Example:** +```mdx + +Pool options to use. + +``` + +**Empty ParamField rule (resolved decision #10):** Omit ParamField components entirely when there is no description text. Do not render empty ParamFields. + +However, some golden pages do include ParamFields with no description (from `pointer_v4.mdx`): +```mdx + + +``` +The resolved decision says to omit these. The canonical rule is: **omit ParamField components when there is no description text.** + +### Tabs + +**Rule:** Used for method overloads, const/mutable pairs, and any other multi-variant rendering. + +```mdx + + +Content for variant 1 + + +Content for variant 2 + + +``` + +--- + +## Tables + +### Typedef Tables + +**Rule:** Use either 2-column or 3-column format depending on whether descriptions are available. + +**3-column format** (Name | Definition | Description) -- when at least some typedefs have descriptions: + +**Example** (from `block_reduce_v3.mdx`): +```mdx +| Name | Definition | Description | +|---|---|---| +| `InternalBlockReduce` | `::cuda::std::_If< ... >` | Internal specialization type. | +| `_TempStorage` | `typename InternalBlockReduce::TempStorage` | Shared memory storage layout type for [BlockReduce](/library/api/cub::BlockReduce). | +| `WarpReductions` | `detail::BlockReduceWarpReductions< T, BlockDimX, BlockDimY, BlockDimZ >` | | +``` + +**2-column format** (Name | Definition) -- when NO typedefs have descriptions: + +**Example** (from `device_vector_v3.mdx`): +```mdx +| Name | Definition | +|---|---| +| `Parent` | `detail::vector_base< T, Alloc >` | +``` + +**Decision rule:** If ANY typedef in the table has a description, use 3-column. If NONE have descriptions, use 2-column. + +### Member Variable Tables + +**Rule:** Always 3-column: Name | Type | Description. + +**Example** (from `block_reduce_v3.mdx`): +```mdx +| Name | Type | Description | +|---|---|---| +| `BLOCK_THREADS` static constexpr | `int` | The thread block size in threads. | +| `temp_storage` | [`_TempStorage`](/library/api/cub::BlockReduce::_TempStorage) `&` | Shared storage reference. | +| `linear_tid` | `unsigned int` | Linear thread-id. | +``` + +**Edge cases:** +- Static/constexpr badges go in the Name column. +- Types that have API pages are linked. +- If a member has no description, leave the Description cell empty. + +### Inner Class Member Tables + +**Rule:** Inner classes that have only member variables (no methods) render a member variable table directly after the CodeBlock. + +**Example** (from `group_member_example_v4.mdx`): +```mdx +### chunk_descriptor + + +```cpp showLineNumbers={false} +struct thrust::mr::disjoint_unsynchronized_pool_resource::chunk_descriptor +``` + + +| Name | Type | Description | +|---|---|---| +| `size` | `std::size_t` | | +| `pointer` | `void_ptr` | | +| `pool_idx` | `std::size_t` | | +``` + +### Related Concepts Table (Concept Pages) + +**Rule:** 2-column: Concept | Description. Each concept name is a link. + +**Example** (from `concept_example_v3.mdx`): +```mdx +| Concept | Description | +|---|---| +| [`cuda::mr::resource`](/library/api/cuda::mr::resource) | Verifies that a type satisfies the basic requirements of a memory resource with stream-ordered allocations. | +``` + +--- + +## Cross-References + +### Link Path Convention per Namespace + +**Rule:** Cross-reference paths follow the pattern `/library/api/fully::qualified::name` for CUB and Thrust entities, and `/libcudacxx/api/fully::qualified::name` for libcudacxx (cuda::) entities. + +| Library/Namespace | Path prefix | Example | +|---|---|---| +| `cub::` | `/library/api/cub::` | `/library/api/cub::BlockReduce` | +| `thrust::` | `/library/api/thrust::` | `/library/api/thrust::device_vector` | +| `thrust::mr::` | `/library/api/thrust::mr::` | `/library/api/thrust::mr::memory_resource` | +| `cuda::` | `/libcudacxx/api/cuda::` | `/libcudacxx/api/cuda::counting_iterator` | +| `cuda::mr::` | `/library/api/cuda::mr::` | `/library/api/cuda::mr::resource` | +| `cuda::std::` | N/A (no API pages) | (not linked) | + +**Member references** use `::` separators: +- `/library/api/cub::BlockReduce::TempStorage` +- `/library/api/thrust::device_vector::size` + +**Method references:** +- `/library/api/thrust::device_vector::size` (links to the method on the class page) +- `/libcudacxx/api/cuda::stream::sync` +- `/libcudacxx/api/cuda::stream_ref::sync` +- `/libcudacxx/api/cuda::buffer::set_stream` + +### Same-Page Anchors vs Cross-Page Links + +**Rule:** Cross-page links always use the path format above. Same-page anchors are not used in the golden pages -- all references link to the full path even for items on the same page. + +--- + +## Empty State Handling + +### Missing Summary + +**Rule:** If the IR provides no summary/description for the class, omit the summary paragraph entirely. The page starts with the include header or template parameters. + +**Example** (from `empty_docstring_class_v4.mdx` -- `cuda::buffer`): +```mdx +--- +title: "cuda::buffer" +description: "A memory-safe buffer for managing typed, property-annotated device memory with stream-ordered allocation." +--- + +```cpp showLineNumbers={false} +#include +``` + + +... +``` +(Note: the frontmatter description is always present even if the body has no summary paragraph.) + +### Missing Parameters + +**Rule:** If a method has no parameters, omit the `**Parameters**` heading and all ParamField components entirely. + +### Missing Returns + +**Rule:** If a method returns `void`, omit the `**Returns:**` line entirely. Never write `**Returns:** void`. + +**Exception:** If the return type is meaningful (e.g., `void_ptr`), include it. + +### Missing Examples + +**Rule:** If a method has no example, simply omit the `**Example**` heading and code block. + +### Missing Member Variables + +**Rule:** If a class has no member variables, omit the `## Member variables` section entirely. + +### Missing Description on ParamField + +**Rule (resolved decision #10):** Omit ParamField components entirely when there is no description text. Do not render empty ParamFields. + +### Methods with No Description + +**Rule:** If a method has no description text, render just the signature CodeBlock (and any badges, parameters, returns). Omit the description paragraph. + +**Example** (from `pointer_v4.mdx`): +```mdx +### operator-> inline const + + +```cpp showLineNumbers={false} +Element * thrust::pointer< Element, Tag, Reference, Derived >::operator->() const +``` + +``` + +--- + +## Version Annotations + +**Rule:** Version annotations appear as italic text on their own line, immediately after the CodeBlock. + +**Format:** `*Added in vX.Y.Z. First appears in CUDA Toolkit X.Y.*` + +**Example** (from `block_reduce_v3.mdx`): +```mdx +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* +``` + +**Edge cases:** +- Not all methods have version annotations. If absent, omit. +- Always italic (wrapped in `*...*`). +- Appears between the CodeBlock and the callouts/parameters. + +--- + +## Separator Usage + +**Rule:** Horizontal rules (`---`) separate: +1. The preamble from the first body section +2. Each top-level H2 section from the next + +A `---` appears BEFORE each `## SectionHeading` (except the very first one after the preamble, where the `---` serves double duty). + +**Example structure:** +```mdx +(preamble content) + +--- + +## Constructors + +### ClassName +... + +--- + +## Methods + +### methodName +... + +--- + +## Types + +### Typedefs +... +``` + +**Edge cases:** +- Within a section, methods are NOT separated by `---`. Only H2 sections get separators. +- Exception: segmented methods in `warp_reduce_v4.mdx` -- individual non-overloaded methods within the same H2 section ARE separated by `---` when they are distinct methods (not overloads). + +**Example of method-level separators** (from `warp_reduce_v4.mdx`): +```mdx +### HeadSegmentedSum inline +... + +--- + +### TailSegmentedSum inline +... + +--- + +### HeadSegmentedReduce inline +``` + +The pattern: `---` between distinct H3 methods within the same H2 section when the methods are semantically separate (not overloads of the same function name). + +--- + +## Section Heading Rules + +| Level | Usage | +|---|---| +| H2 (`##`) | Top-level body sections: Constructors, Methods, Types, Member variables, etc. Also: Performance considerations, Example, Description (in preamble/concept pages) | +| H3 (`###`) | Individual methods, constructors, destructors, inner classes, "Typedefs" label, "Destructor" label | +| Never H1 | H1 is never used (the page title from frontmatter serves as H1) | +| Never H4 | H4 is never used. Use `**Bold text**` instead (e.g., `**Parameters**`) | diff --git a/fern/pages/style_reference.md b/fern/pages/style_reference.md new file mode 100644 index 0000000..b5e91e2 --- /dev/null +++ b/fern/pages/style_reference.md @@ -0,0 +1,874 @@ +# C++ Library Docs Style Reference + +This document captures micro-level formatting decisions extracted from the 12 golden pages. It complements the Rendering Rules document with exact spacing, prop patterns, and stylistic conventions. + +--- + +## Blank Line Conventions + +### Between top-level elements in the preamble + +One blank line between each preamble element (summary, include header, callouts, see also, template params, inherits from). + +**Example** (from `pointer_v4.mdx`): +```mdx +`pointer` stores a pointer to an object allocated in memory. + +Like [`device_ptr`](/library/api/thrust::device_ptr), this type ensures type safety ... + +```cpp showLineNumbers={false} +#include +``` + + +`pointer` is not a smart pointer; ... + + +**See also:** +[device_ptr](/library/api/thrust::device_ptr), +reference, +[raw_pointer_cast](/library/api/thrust::raw_pointer_cast) + + +``` + +### Inside Tabs + +One blank line after `` before content. One blank line before ``. + +**Example** (from `block_reduce_v3.mdx`): +```mdx + + +Computes a block-wide reduction for thread0 ... + + +```cpp showLineNumbers={false} +... +``` + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The return value is undefined ... + +**Template parameters** + + +... + + +**Parameters** + + +... + + +**Example** + +The code snippet below ... + +```cpp showLineNumbers={false} +... +``` + + +``` + +### Between ParamFields + +One blank line between consecutive `` components. + +**Example** (from `block_reduce_v3.mdx`): +```mdx + +Calling thread's input + + + +Binary reduction functor + +``` + +### Between heading and content + +One blank line between an H3 heading and the following content (description or Tabs). + +**Example** (from `device_vector_v3.mdx`): +```mdx +### size const + +Returns the number of elements in this vector. +``` + +### Between method sections (same H2) + +One blank line after the last element of one method, then the next H3 heading. When methods within the same H2 section are separated by `---`, the pattern is: + +```mdx +(end of method content) + +--- + +### NextMethodName +``` + +### Between CodeBlock and version annotation + +One blank line between `` and the italic version annotation. + +```mdx + + +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* +``` + +### Between version annotation and callouts/bullets + +One blank line between the version annotation and callout components or bullet lists. + +```mdx +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* + +- The return value is undefined ... +``` + +### Between callouts/bullets and Parameters heading + +One blank line between callout content and `**Parameters**`. + +### Between Returns and Parameters + +`**Returns:**` comes BEFORE `**Parameters**` when both are present (only seen in assignment operators). More commonly, `**Returns:**` comes AFTER `**Parameters**`. Actually, looking at the golden pages: + +Standard order within a method/tab: +1. Description text +2. CodeBlock (signature) +3. Version annotation (italic) +4. Callouts / behavioral notes +5. `**Template parameters**` + ParamFields +6. `**Parameters**` + ParamFields +7. `**Returns:**` line +8. `**Throws:**` line +9. `**Example**` + code block + +**Exception for assignment operators and short methods** (from `device_vector_v3.mdx`): +```mdx +**Returns:** `device_vector &` + +**Parameters** + + +``` +Here, Returns comes BEFORE Parameters. This pattern is used when the return value is short and the parameters are the main content. + +**The canonical order depends on the golden pages.** Analyzing all instances: +- In `device_vector_v3.mdx` (operator=, operator[], front, back, data, iterators, etc.): **Returns** comes BEFORE **Parameters** +- In `block_reduce_v3.mdx` (Reduce, Sum): **Returns** is absent (returns the value directly, described in prose), Parameters come after Template parameters +- In `simple_struct_v4.mdx` (operator()): **Returns** comes AFTER **Parameters** +- In `pointer_v4.mdx` (get, base, base_reference): **Returns** comes after CodeBlock, no Parameters +- In `group_member_example_v4.mdx` (do_allocate): **Returns** comes BEFORE **Parameters** + +**Rule:** `**Returns:**` comes immediately after the behavioral callouts/notes, BEFORE `**Parameters**`. If there are template parameters, the order is: Returns, then Template parameters, then Parameters. However, this is not consistent -- some pages put Returns after Parameters. The most common pattern across golden pages is: + +**Returns before Parameters** when the method is primarily characterized by its return value. +**Returns after Parameters** when the method is primarily characterized by what it does with its parameters. + +For consistency, adopt the most common golden page pattern: **Returns BEFORE Parameters**. + +--- + +## Component Prop Patterns + +### CodeBlock + +```mdx + +```cpp showLineNumbers={false} +signature +``` + +``` + +- Always `showLineNumbers={false}` on the inner code fence +- `links` prop is a JSON object mapping identifier strings to path strings +- Multiple links separated by commas: `{{"A": "/path/a", "B": "/path/b"}}` + +### ParamField + +```mdx + +Description text. + +``` + +With default: +```mdx + +Description text. + +``` + +- `path`: The parameter name exactly as it appears in the signature +- `type`: The full type string (may include `const`, `&`, `*`, template parameters) +- `default` (optional): The default value as a string + +### Badge + +Qualifier badge: +```mdx +qualifier_text +``` + +Entity-kind badge: +```mdx +C++20 concept +``` + +- Qualifier badges always use `intent="note"` and `minimal` +- Entity-kind badges use `intent="info"` without `minimal` +- Multiple badges separated by a single space on the same line + +### Accordion + +```mdx + + + +(ParamField components) + + + +``` + +- Always wrapped in `` +- Title is always `"Template parameters"` +- Blank line after `` and before `` + +### Callouts + +```mdx + +Content text. + + + +Content text with **bold** or `code`. + + + +Use `alternative` instead. + + + +State description. + +``` + +- `` always has `title="Deprecated"` for deprecation notices +- `` uses `title="Postconditions"` for postconditions, or other custom titles +- `` and `` typically have no `title` prop +- Blank line after opening tag, content, blank line before closing tag (or content directly between tags with no extra blank lines -- both patterns appear) + +--- + +## Tab Title Conventions + +### Constructor Tab Titles + +| Scenario | Title | +|---|---| +| No-argument default | `"Default"` or `"Default constructor"` | +| From nullptr | `"From nullptr"` | +| From raw pointer | `"From raw pointer"` | +| From other pointer/type | `"From other pointer"`, `"From other device_vector type"` | +| From specific type | `"From std::vector"`, `"From vector_base"`, `"From initializer_list"` | +| With an allocator | `"With allocator"` | +| With a config object | `"With TempStorage"`, `"With upstream and bookkeeper"`, `"With default resources"` | +| Move | `"Move"`, `"Move with allocator"` | +| Copy | `"Copy constructor"`, `"From device_vector"` | +| Copy with allocator | `"From device_vector with allocator"` | +| Size variants | `"Value-initialized"`, `"Default-initialized"`, `"No-init"` | +| Fill variants | `"n copies of value"`, `"n copies with allocator"` | +| Range | `"From iterator range"`, `"From iterator range with allocator"` | +| Deleted | `"Copy (deleted)"` | + +### Method Overload Tab Titles + +| Scenario | Title | +|---|---| +| Single item vs array | `"Single item"`, `"Multiple items per thread"` | +| Partial operation | `"Partial tile"`, `"Partial warp"` | +| Full vs partial | `"Full warp"`, `"Partial warp"` | +| With aggregate output | `"With aggregate"`, `"With prefix callback"` | +| Const/Mutable | `"Mutable"`, `"Const"` | +| Copy/Move assign | `"Copy assign"`, `"Move assign"` | +| Functional | `"Fill"`, `"Range"`, `"Single element"` | +| Pre/Post | `"Pre-increment"`, `"Post-increment"`, `"Pre-decrement"`, `"Post-decrement"` | +| With/without stream | `"With stream"`, `"Default stream"` | +| Deprecated variant | `"Deprecated (no args)"` | +| Deleted | `"Copy assign (deleted)"`, `"Deleted overloads"` | + +--- + +## Heading Hierarchy + +| Level | Usage | Examples | +|---|---|---| +| **H1** | Never used | -- | +| **H2** (`##`) | Top-level body sections | `## Constructors`, `## Methods`, `## Types`, `## Member variables`, `## Inner classes`, `## Performance considerations`, `## Example`, `## Description`, `## Related concepts`, `## Collective constructors`, `## Generic reductions`, `## Summation reductions`, `## Exclusive prefix sum operations`, `## Assignment operators`, `## Element access`, `## Iterators`, `## Capacity`, `## Modifiers`, `## Allocator`, `## Static methods`, `## Friend functions`, `## Pool management`, `## Allocation`, `## Comparison`, `## Ownership`, `## Accessors`, `## Synchronization`, `## Query methods`, `## Event recording`, `## Device information`, `## Increment operators`, `## Compound assignment operators`, `## Resource and stream management`, `## Segmented reductions`, `## Max reductions`, `## Min reductions`, `## Utility methods` | +| **H3** (`###`) | Individual members | Method names, constructor names, destructor label, inner class names, `### Typedefs`, `### Destructor` | +| **H4** (`####`) | Never used | Use `**Bold text**` instead | + +--- + +## Cross-Ref Path Format + +### Per-Library Path Patterns + +| Pattern | Path format | Example | +|---|---|---| +| CUB classes | `/library/api/cub::ClassName` | `/library/api/cub::BlockReduce` | +| CUB nested types | `/library/api/cub::ClassName::TypeName` | `/library/api/cub::BlockReduce::TempStorage` | +| CUB enums | `/library/api/cub::EnumValue` | `/library/api/cub::BLOCK_REDUCE_WARP_REDUCTIONS` | +| CUB methods | `/library/api/cub::ClassName::MethodName` | `/library/api/cub::BlockReduce::_TempStorage` | +| Thrust classes | `/library/api/thrust::ClassName` | `/library/api/thrust::device_vector` | +| Thrust nested types | `/library/api/thrust::ClassName::TypeName` | `/library/api/thrust::device_vector::size` | +| Thrust free functions | `/library/api/thrust::FunctionName` | `/library/api/thrust::raw_pointer_cast` | +| Thrust nested namespaces | `/library/api/thrust::mr::ClassName` | `/library/api/thrust::mr::memory_resource` | +| libcudacxx classes | `/libcudacxx/api/cuda::ClassName` | `/libcudacxx/api/cuda::counting_iterator` | +| libcudacxx nested | `/libcudacxx/api/cuda::ClassName::Member` | `/libcudacxx/api/cuda::stream::sync` | +| libcudacxx cross-ref from other pages | `/libcudacxx/api/cuda::ClassName` | `/libcudacxx/api/cuda::stream_ref` | +| cuda::mr concepts | `/library/api/cuda::mr::ConceptName` | `/library/api/cuda::mr::resource` | +| cuda:: concepts/traits | `/library/api/cuda::ConceptName` | `/library/api/cuda::has_property` | + +**Key distinction:** `cuda::` entities that are part of libcudacxx use `/libcudacxx/api/` prefix. `cuda::mr::` concepts referenced from the concept_example page use `/library/api/` prefix. This distinction maps to the library the entity belongs to. + +### In-Signature Links + +Links in `` map short identifiers to full paths: + +```mdx + +``` + +```mdx + +``` + +```mdx + +``` + +```mdx + +``` + +### In-Text Links + +Markdown link format: `[display text](/path/to/api)` + +**Example** (from `device_vector_v3.mdx`): +```mdx +[size()](/library/api/thrust::device_vector::size) +``` + +**Example** (from `concept_example_v3.mdx`): +```mdx +[`cuda::mr::resource`](/library/api/cuda::mr::resource) +``` + +--- + +## Badge Conventions + +### Which intent for which purpose + +| Badge text | Intent | Minimal | When used | +|---|---|---|---| +| `inline` | `note` | yes | Method/constructor qualifier | +| `const` | `note` | yes | Method qualifier | +| `static` | `note` | yes | Method qualifier, member variable qualifier | +| `virtual` | `note` | yes | Method qualifier | +| `explicit` | `note` | yes | Constructor/conversion qualifier | +| `constexpr` | `note` | yes | Method/variable qualifier | +| `noexcept` | `note` | yes | Method qualifier | +| `nodiscard` | `note` | yes | Method qualifier | +| `final` | `note` | yes | Class qualifier | +| `C++20 concept` | `info` | no | Entity-kind indicator | + +### Badge ordering on H3 headings + +Badges appear in a consistent order. The observed canonical order is: + +1. `inline` +2. `static` +3. `constexpr` +4. `explicit` +5. `const` +6. `noexcept` +7. `nodiscard` +8. `virtual` +9. `final` + +**Example** (from `deep_template_class_v4.mdx`): +```mdx +### operator* inline constexpr const noexcept nodiscard +``` + +**Example** (from `group_member_example_v4.mdx`): +```mdx +### do_allocate inline nodiscard virtual +``` + +**Example** (from `pointer_v4.mdx`): +```mdx +### operator bool inline explicit const +``` + +### Badge placement in member variable tables + +Badges for `static` and `constexpr` go in the Name cell, after the variable name: + +```mdx +| `BLOCK_THREADS` static constexpr | `int` | The thread block size in threads. | +``` + +--- + +## Signature Formatting + +### General rules + +- Signatures are always inside `` with `showLineNumbers={false}` +- Fully-qualified name: `namespace::ClassName::MethodName` +- Template prefix on its own line when present +- Parameters indented with 4 spaces when multi-line +- Closing `)` on same line as last parameter, or on its own line for `const`/`noexcept` qualifiers + +### Template prefix + +When a method has its own template parameters, the template line comes first: + +```cpp +template +T cub::BlockReduce::Reduce( + T input, + ReductionOp reduction_op +) +``` + +### Multi-parameter line breaking + +Parameters are each on their own line, indented with 4 spaces: + +```cpp +void cub::BlockScan::ExclusiveSum( + T input, + T &output, + T &block_aggregate +) +``` + +### Single-parameter or no-parameter signatures + +May be on one line: + +```cpp +T cub::WarpReduce::Sum( + T input +) +``` + +Or inline: + +```cpp +size_type thrust::device_vector::size() const +``` + +### Const/noexcept qualifiers on signatures + +Appear after the closing `)`: + +```cpp +const_reference thrust::device_vector::front() const +``` + +```cpp +void cuda::buffer<_Tp, _Properties>::swap( + buffer &__other +) noexcept +``` + +### SFINAE / enable_if in signatures + +Rendered as-is in the signature: + +```cpp +template , int> = 0> +T cub::WarpReduce::Sum( + const InputType &input +) +``` + +### Deleted functions + +```cpp +cuda::stream::stream( + const stream & +) = delete +``` + +### Default arguments in signatures + +```cpp +thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::disjoint_unsynchronized_pool_resource( + Upstream *upstream, + Bookkeeper *bookkeeper, + pool_options options = get_default_options() +) +``` + +--- + +## Description Segment Rendering + +The IR contains description segments of various types. Here is how each renders in MDX: + +### Plain text (DocTextSegment) + +Rendered as-is in markdown. + +### Code references (DocCodeSegment) + +Wrapped in backticks: `` `code_here` `` + +### Bold text (DocBoldSegment) + +Wrapped in `**bold**`. + +### Italic text (DocItalicSegment) + +Wrapped in `*italic*`. + +### Links/references (DocRefSegment) + +Rendered as markdown links: `[display text](/path/to/target)` + +### Subscript + +Using HTML: `thread0`, `lane0` + +### Superscript + +Using HTML: `*i*th` + +### Inline math / special formatting + +Rendered using HTML tags or markdown as appropriate. + +--- + +## Separator Usage + +### Where `---` goes + +1. After the preamble (before first H2 body section) +2. Between each H2 body section +3. Between distinct H3 methods within the same H2 section (when they are different method names, not overloads) + +### Where `---` does NOT go + +1. Between overloads of the same method (use Tabs instead) +2. Between the `### Destructor` label and `### ~ClassName` +3. After the very last section on the page + +### Example structure (from `warp_reduce_v4.mdx`) + +```mdx +## Segmented reductions + +### HeadSegmentedSum inline + +(content) + +--- + +### TailSegmentedSum inline + +(content) + +--- + +### HeadSegmentedReduce inline + +(content) + +--- + +### TailSegmentedReduce inline + +(content) + +--- + +## Types +``` + +--- + +## Version Annotation Format + +**Format:** Italic text, standalone line. + +**Full form:** +```mdx +*Added in v2.2.0. First appears in CUDA Toolkit 12.3.* +``` + +**Rules:** +- Always italic (wrapped in `*...*`) +- Period after each sentence +- Appears on its own line immediately after `` +- One blank line before and after +- If no version info is available, omit entirely + +--- + +## Table Formatting + +### Column headers and alignment + +Always use `|---|` (no alignment colons) between header and body: + +```mdx +| Name | Definition | Description | +|---|---|---| +| `TypeName` | `Definition` | Description text. | +``` + +### When to use 2 vs 3 columns for typedefs + +**3-column** (Name | Definition | Description): When ANY typedef in the table has a description. + +**Example** (from `block_reduce_v3.mdx`): +```mdx +| Name | Definition | Description | +|---|---|---| +| `InternalBlockReduce` | `...` | Internal specialization type. | +| `_TempStorage` | `...` | Shared memory storage layout type for ... | +| `WarpReductions` | `...` | | +``` + +**2-column** (Name | Definition): When NO typedefs have descriptions. + +**Example** (from `device_vector_v3.mdx`): +```mdx +| Name | Definition | +|---|---| +| `Parent` | `detail::vector_base< T, Alloc >` | +``` + +**Example** (from `deep_template_class_v4.mdx`): +```mdx +| Name | Definition | +|---|---| +| `value_type` | `_Start` | +| `difference_type` | `_IotaDiffT<_Start>` | +``` + +### Member variable table format + +Always 3-column: Name | Type | Description. + +```mdx +| Name | Type | Description | +|---|---|---| +| `variable_name` | `type` | Description. | +``` + +Badges in Name column: +```mdx +| `default_priority` static constexpr | `int` | The default stream priority. | +``` + +Types with links: +```mdx +| `m_options` | [`pool_options`](/library/api/thrust::mr::pool_options) | | +``` + +Types with reference qualifiers: +```mdx +| `temp_storage` | [`_TempStorage`](/library/api/cub::BlockReduce::_TempStorage) `&` | Shared storage reference. | +``` + +### Related concepts table format (concept pages only) + +2-column: Concept | Description. Concept names are linked. + +```mdx +| Concept | Description | +|---|---| +| [`cuda::mr::resource`](/library/api/cuda::mr::resource) | Verifies that a type satisfies ... | +``` + +### Inner class member tables + +3-column (same as member variable tables). Appear directly after the inner class CodeBlock. + +```mdx +### chunk_descriptor + + +```cpp showLineNumbers={false} +struct namespace::ClassName::chunk_descriptor +``` + + +| Name | Type | Description | +|---|---|---| +| `size` | `std::size_t` | | +| `pointer` | `void_ptr` | | +``` + +--- + +## Callout Usage + +### Which callout type for which IR field + +| IR field / concept | Callout component | +|---|---| +| Deprecated flag | `` | +| Behavioral notes / preconditions | `` (or inline bullet list per golden pages, but resolved to use ``) | +| Warnings about user responsibility | `` | +| Postconditions | `` | +| Memory allocation notes | `` | +| Thread safety notes | `` | +| Destroyed/moved-from state warnings | `` | + +### Callout placement in method tabs + +Callouts appear AFTER the CodeBlock and version annotation, BEFORE template parameters and regular parameters. + +**Example order within a Tab** (from `raises_example_v4.mdx`): +```mdx + + +explicit noexcept + +Construct a new [`stream`](/libcudacxx/api/cuda::stream) object into the moved-from state. + + +```cpp showLineNumbers={false} +... +``` + + + +[`stream()`](/libcudacxx/api/cuda::stream::stream) returns an invalid stream handle. + + + +``` + +### Callout placement in preamble + +Callouts in the preamble appear after the include header, before See Also: + +```mdx +```cpp showLineNumbers={false} +#include +``` + + +`pointer` is not a smart pointer; ... + + +**See also:** +... +``` + +For deprecated classes, the `` callout appears after the include header: + +```mdx +```cpp showLineNumbers={false} +#include +``` + + +Use `cuda::strided_iterator` instead. + +``` + +--- + +## Empty State Rules + +### What to omit + +| Missing data | Action | +|---|---| +| No summary text | Omit summary paragraph(s); frontmatter description is still present | +| No include header | Omit the include code block | +| No deprecation | Omit `` callout | +| No see-also refs | Omit `**See also:**` line | +| No class-level example | Omit `## Example` section | +| No performance considerations | Omit `## Performance considerations` section | +| No template parameters | Omit `` entirely | +| No base classes | Omit `**Inherits from:**` line | +| No methods in a category | Omit that H2 section entirely | +| No member variables | Omit `## Member variables` section | +| No inner classes | Omit `## Inner classes` section | +| No typedefs | Omit `### Typedefs` and its table | +| No version annotation | Omit the italic version line | +| No return value (void) | Omit `**Returns:**` line | +| No throws | Omit `**Throws:**` line | +| No method-level example | Omit `**Example**` and code block | +| No parameters | Omit `**Parameters**` heading and ParamFields | +| No template params (method) | Omit `**Template parameters**` heading and ParamFields | +| No description on ParamField | Omit the ParamField component entirely (resolved decision #10) | +| No description on method | Omit description paragraph; keep CodeBlock and other fields | + +### What to keep even when seemingly empty + +| Element | Always keep | +|---|---| +| Frontmatter `description` | Always present, even if body has no summary | +| `---` separators | Always present between H2 sections | +| CodeBlock for signature | Always present for every method/constructor/destructor | +| H3 heading for method | Always present even with no description | +| `### Destructor` label | Always present before `### ~ClassName` | +| `### Typedefs` heading | Present when there are typedefs to show | + +--- + +## Method Content Ordering Within a Tab + +The canonical order of elements inside a `` (or for a non-tabbed method): + +1. **Overload-specific badges** (standalone line, if any) +2. **Description text** (one or more paragraphs) +3. **CodeBlock** (signature) +4. **Version annotation** (italic, if present) +5. **Callouts** (``, ``, ``, ``) +6. **Returns** (`**Returns:**` line, if present) +7. **Template parameters** (`**Template parameters**` + ParamFields, if present) +8. **Parameters** (`**Parameters**` + ParamFields, if present) +9. **Throws** (`**Throws:**` line, if present) +10. **Example** (`**Example**` + code block, if present) + +**Note on Returns placement:** In many golden pages (especially `device_vector_v3.mdx`), **Returns** appears BEFORE **Parameters**. In others (like `simple_struct_v4.mdx`), it appears AFTER. The predominant pattern is Returns before Parameters. + +**Alternate observed order (from block_reduce_v3.mdx within Tabs):** +1. Description +2. CodeBlock +3. Version annotation +4. Behavioral bullet points / callouts +5. Template parameters +6. Parameters +7. Example + +(No explicit Returns line in block_reduce pages because the return is described in the method description.) diff --git a/fern/pages/thrust/deprecated_example_v4.mdx b/fern/pages/thrust/deprecated_example_v5.mdx similarity index 95% rename from fern/pages/thrust/deprecated_example_v4.mdx rename to fern/pages/thrust/deprecated_example_v5.mdx index 7a4302b..026618b 100644 --- a/fern/pages/thrust/deprecated_example_v4.mdx +++ b/fern/pages/thrust/deprecated_example_v5.mdx @@ -60,14 +60,6 @@ thrust::strided_iterator< RandomAccessIterator, StrideHolder >::strided_iterator ``` -**Parameters** - - - - - - - @@ -131,4 +123,4 @@ RandomAccessIterator const & thrust::iterator_adaptor< strided_iterator< RandomA | Name | Type | Description | |---|---|---| -| `has_static_stride` static constexpr | `bool` | | +| `has_static_stride` static constexpr | `bool` | | \ No newline at end of file diff --git a/fern/pages/thrust/deprecated_example.mdx b/fern/pages/thrust/deprecated_example_v6.mdx similarity index 62% rename from fern/pages/thrust/deprecated_example.mdx rename to fern/pages/thrust/deprecated_example_v6.mdx index 354aa6b..afb0670 100644 --- a/fern/pages/thrust/deprecated_example.mdx +++ b/fern/pages/thrust/deprecated_example_v6.mdx @@ -3,31 +3,31 @@ title: thrust::strided_iterator description: "An iterator adaptor that wraps another iterator and moves it by a specified stride each time it is incremented or decremented." --- -A [`strided_iterator`](/library/api/thrust::strided_iterator) wraps another iterator and moves it by a specified stride each time it is incremented or decremented. +A `strided_iterator` wraps another iterator and moves it by a specified stride each time it is incremented or decremented. ```cpp showLineNumbers={false} #include ``` - -Use `cuda::strided_iterator` instead. - + +Use `cuda::strided_iterator` instead + -A random access iterator. +A random access iterator -Either a [runtime_value](/library/api/thrust::runtime_value) or a [compile_time_value](/library/api/thrust::compile_time_value) specifying the stride. +Either a [runtime_value](/library/api/thrust::runtime_value) or a [compile_time_value](/library/api/thrust::compile_time_value) specifying the stride -**Inherits from:** [`thrust::iterator_adaptor< strided_iterator< RandomAccessIterator, StrideHolder >, RandomAccessIterator >`](/library/api/thrust::iterator_adaptor) (public), `StrideHolder` (private) +**Inherits from:** `thrust::iterator_adaptor< strided_iterator< RandomAccessIterator, StrideHolder >, RandomAccessIterator >` (public), `StrideHolder` (private) --- @@ -40,34 +40,26 @@ Either a [runtime_value](/library/api/thrust::runtime_value) or a [compile_time_ ```cpp showLineNumbers={false} -thrust::strided_iterator< RandomAccessIterator, StrideHolder >::strided_iterator() = default +thrust::strided_iterator::strided_iterator() = default ``` - + inline -Creates a [strided_iterator](/library/api/thrust::strided_iterator) from an existing iterator and a stride. +Creates a `strided_iterator` from an existing iterator and a stride. ```cpp showLineNumbers={false} -thrust::strided_iterator< RandomAccessIterator, StrideHolder >::strided_iterator( +thrust::strided_iterator::strided_iterator( RandomAccessIterator it, StrideHolder stride = {} ) ``` -**Parameters** - - - - - - - @@ -81,7 +73,7 @@ Returns either the [runtime_value](/library/api/thrust::runtime_value) or the [c ```cpp showLineNumbers={false} -const auto & thrust::strided_iterator< RandomAccessIterator, StrideHolder >::stride_holder() const +const auto & thrust::strided_iterator::stride_holder() const ``` @@ -91,15 +83,15 @@ Returns the stride's value. ```cpp showLineNumbers={false} -difference_type thrust::strided_iterator< RandomAccessIterator, StrideHolder >::stride() const +difference_type thrust::strided_iterator::stride() const ``` ### base inline const - + ```cpp showLineNumbers={false} -RandomAccessIterator const & thrust::iterator_adaptor< strided_iterator< RandomAccessIterator, StrideHolder >, RandomAccessIterator, use_default, use_default, use_default, use_default, use_default >::base() const +RandomAccessIterator const & thrust::iterator_adaptor, RandomAccessIterator, use_default, use_default, use_default, use_default, use_default>::base() const ``` @@ -107,9 +99,9 @@ RandomAccessIterator const & thrust::iterator_adaptor< strided_iterator< RandomA ### base_reference inline const - + ```cpp showLineNumbers={false} -RandomAccessIterator const & thrust::iterator_adaptor< strided_iterator< RandomAccessIterator, StrideHolder >, RandomAccessIterator, use_default, use_default, use_default, use_default, use_default >::base_reference() const +RandomAccessIterator const & thrust::iterator_adaptor, RandomAccessIterator, use_default, use_default, use_default, use_default, use_default>::base_reference() const ``` @@ -123,7 +115,7 @@ RandomAccessIterator const & thrust::iterator_adaptor< strided_iterator< RandomA | Name | Definition | Description | |---|---|---| -| `base_type` | `RandomAccessIterator` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)'s adapts. | +| `base_type` | `RandomAccessIterator` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)`'s` `adapts`. | --- diff --git a/fern/pages/thrust/device_vector.mdx b/fern/pages/thrust/device_vector.mdx deleted file mode 100644 index 8869d0d..0000000 --- a/fern/pages/thrust/device_vector.mdx +++ /dev/null @@ -1,1188 +0,0 @@ ---- -title: thrust::device_vector ---- - -# device_vector - -A `device_vector` is a container that supports random access to elements, constant time removal of elements at the end, and linear time insertion and removal of elements at the beginning or in the middle. - -The number of elements in a `device_vector` may vary dynamically; memory management is automatic. The memory associated with a `device_vector` resides in the memory accessible to devices. - -```cpp -#include -``` - -**See also:** -[https://en.cppreference.com/w/cpp/container/vector](https://en.cppreference.com/w/cpp/container/vector), -[device_allocator](/library/api/thrust::device_allocator), -[host_vector](/library/api/thrust::host_vector), -[universal_vector](/library/api/thrust::universal_vector) - - - - -The element type of the vector. - - - -**[optional]** The allocator type used for memory management (default: [thrust::device_allocator](/library/api/thrust::device_allocator)``). - - - - -**Inherits from:** `detail::vector_base< T, thrust::device_allocator< T > >` (public) - ---- - -## Constructors - -### Default and allocator constructors - - - - -This constructor creates an empty `device_vector`. - -```cpp -thrust::device_vector< T, Alloc >::device_vector() -``` - - - - -This constructor creates an empty `device_vector`. - -```cpp -thrust::device_vector< T, Alloc >::device_vector(const Alloc &alloc) -``` - -#### Parameters - - -The allocator to use by this `device_vector`. - - - - - -### Size constructors - - - - -explicit - -This constructor creates a `device_vector` with the given size. - -```cpp -thrust::device_vector< T, Alloc >::device_vector(size_type n) -``` - -#### Parameters - - -The number of elements to initially create. - - - - - -This constructor creates a `device_vector` with the given size, performing only default-initialization instead of value-initialization. - -```cpp -thrust::device_vector< T, Alloc >::device_vector(size_type n, default_init_t) -``` - -#### Parameters - - -The number of elements to initially create. - - - - - -This constructor creates a `device_vector` with the given size, without initializing elements. - -```cpp -thrust::device_vector< T, Alloc >::device_vector(size_type n, no_init_t) -``` - -#### Parameters - - -The number of elements to initially create. - - - - - -explicit - -This constructor creates a `device_vector` with the given size. - -```cpp -thrust::device_vector< T, Alloc >::device_vector(size_type n, const Alloc &alloc) -``` - -#### Parameters - - -The number of elements to initially create. - - - -The allocator to use by this `device_vector`. - - - - - -### Fill constructors - - - - -explicit - -This constructor creates a `device_vector` with copies of an exemplar element. - -```cpp -thrust::device_vector< T, Alloc >::device_vector(size_type n, const value_type &value) -``` - -#### Parameters - - -The number of elements to initially create. - - - -An element to copy. - - - - - -explicit - -This constructor creates a `device_vector` with copies of an exemplar element. - -```cpp -thrust::device_vector< T, Alloc >::device_vector(size_type n, const value_type &value, const Alloc &alloc) -``` - -#### Parameters - - -The number of elements to initially create. - - - -An element to copy. - - - -The allocator to use by this `device_vector`. - - - - - -### Copy constructors - - - - -Copy constructor copies from an exemplar `device_vector`. - -```cpp -thrust::device_vector< T, Alloc >::device_vector(const device_vector &v) -``` - -#### Parameters - - -The `device_vector` to copy. - - - - - -Copy constructor copies from an exemplar `device_vector`. - -```cpp -thrust::device_vector< T, Alloc >::device_vector(const device_vector &v, const Alloc &alloc) -``` - -#### Parameters - - -The `device_vector` to copy. - - - -The allocator to use by this `device_vector`. - - - - - -explicit - -Copy constructor copies from an exemplar `device_vector` with different type. - -```cpp -template -thrust::device_vector< T, Alloc >::device_vector(const device_vector< OtherT, OtherAlloc > &v) -``` - -#### Parameters - - -The `device_vector` to copy. - - - - - -Copy constructor copies from an exemplar `std::vector`. - -```cpp -template -thrust::device_vector< T, Alloc >::device_vector(const std::vector< OtherT, OtherAlloc > &v) -``` - -#### Parameters - - -The `std::vector` to copy. - - - - - -Copy construct from a `vector_base` whose element type is convertible to `T`. - -```cpp -template -thrust::device_vector< T, Alloc >::device_vector(const detail::vector_base< OtherT, OtherAlloc > &v) -``` - -#### Parameters - - -The `vector_base` to copy. - - - - - -### Move constructors - - - - -Move constructor moves from another `device_vector`. - -```cpp -thrust::device_vector< T, Alloc >::device_vector(device_vector &&v) -``` - -#### Parameters - - -The `device_vector` to move. - - - - - -Move constructor moves from another `device_vector`. - -```cpp -thrust::device_vector< T, Alloc >::device_vector(device_vector &&v, const Alloc &alloc) -``` - -#### Parameters - - -The `device_vector` to move. - - - -The allocator to use by this `device_vector`. - - - - - -### Initializer list constructors - - - - -This constructor builds a `device_vector` from an initializer_list. - -```cpp -thrust::device_vector< T, Alloc >::device_vector(::cuda::std::initializer_list< T > il) -``` - -#### Parameters - - -The initializer_list. - - - - - -This constructor builds a `device_vector` from an initializer_list. - -```cpp -thrust::device_vector< T, Alloc >::device_vector(::cuda::std::initializer_list< T > il, const Alloc &alloc) -``` - -#### Parameters - - -The initializer_list. - - - -The allocator to use by this `device_vector`. - - - - - -### Range constructors - - - - -This constructor builds a `device_vector` from a range. - -```cpp -template -thrust::device_vector< T, Alloc >::device_vector(InputIterator first, InputIterator last) -``` - -#### Parameters - - -The beginning of the range. - - - -The end of the range. - - - - - -This constructor builds a `device_vector` from a range. - -```cpp -template -thrust::device_vector< T, Alloc >::device_vector(InputIterator first, InputIterator last, const Alloc &alloc) -``` - -#### Parameters - - -The beginning of the range. - - - -The end of the range. - - - -The allocator to use by this `device_vector`. - - - - - -### Destructor - -#### ~device_vector inline - -The destructor erases the elements. - -```cpp -thrust::device_vector< T, Alloc >::~device_vector() -``` - ---- - -## Assignment operators - -### operator= inline - - - - -Copy assign operator copies another `device_vector` with the same type. - -```cpp -device_vector & thrust::device_vector< T, Alloc >::operator=(const device_vector &v) -``` - -**Returns:** `device_vector &` - -#### Parameters - - -The `device_vector` to copy. - - - - - -Move assign operator moves from another `device_vector`. - -```cpp -device_vector & thrust::device_vector< T, Alloc >::operator=(device_vector &&v) -``` - -**Returns:** `device_vector &` - -#### Parameters - - -The `device_vector` to move. - - - - - -Assign operator copies from an exemplar `device_vector` with different type. - -```cpp -template -device_vector & thrust::device_vector< T, Alloc >::operator=(const device_vector< OtherT, OtherAlloc > &v) -``` - -**Returns:** `device_vector &` - -#### Parameters - - -The `device_vector` to copy. - - - - - -Assign operator copies from an exemplar `std::vector`. - -```cpp -template -device_vector & thrust::device_vector< T, Alloc >::operator=(const std::vector< OtherT, OtherAlloc > &v) -``` - -**Returns:** `device_vector &` - -#### Parameters - - -The `std::vector` to copy. - - - - - -Assign a `vector_base` whose element type is convertible to `T`. - -```cpp -template -device_vector & thrust::device_vector< T, Alloc >::operator=(const detail::vector_base< OtherT, OtherAlloc > &v) -``` - -**Returns:** `device_vector &` - -#### Parameters - - -The `vector_base` to copy. - - - - - -Assign an `initializer_list` with a matching element type. - -```cpp -device_vector & thrust::device_vector< T, Alloc >::operator=(::cuda::std::initializer_list< T > il) -``` - -**Returns:** `device_vector &` - -#### Parameters - - -The initializer_list. - - - - - ---- - -## Element access - -### operator[] - - - - -Subscript access to the data contained in this vector. - -```cpp -reference thrust::device_vector< T, Alloc >::operator[](size_type n) -``` - -**Returns:** Read/write reference to data. - -#### Parameters - - -The index of the element for which data should be accessed. - - - - - -const - -Subscript read access to the data contained in this vector. - -```cpp -const_reference thrust::device_vector< T, Alloc >::operator[](size_type n) const -``` - -**Returns:** Read reference to data. - -#### Parameters - - -The index of the element for which data should be accessed. - - - - - -### front - - - - -This method returns a reference pointing to the first element of this vector. - -```cpp -reference thrust::device_vector< T, Alloc >::front() -``` - -**Returns:** The first element of this vector. - - - - -const - -This method returns a `const_reference` referring to the first element of this vector. - -```cpp -const_reference thrust::device_vector< T, Alloc >::front() const -``` - -**Returns:** The first element of this vector. - - - - -### back - - - - -This method returns a reference referring to the last element of this vector. - -```cpp -reference thrust::device_vector< T, Alloc >::back() -``` - -**Returns:** The last element of this vector. - - - - -const - -This method returns a const reference pointing to the last element of this vector. - -```cpp -const_reference thrust::device_vector< T, Alloc >::back() const -``` - -**Returns:** The last element of this vector. - - - - -### data - - - - -This method returns a pointer to this vector's first element. - -```cpp -pointer thrust::device_vector< T, Alloc >::data() -``` - -**Returns:** A pointer to the first element of this vector. - - - - -const - -This method returns a `const_pointer` to this vector's first element. - -```cpp -const_pointer thrust::device_vector< T, Alloc >::data() const -``` - -**Returns:** A `const_pointer` to the first element of this vector. - - - - ---- - -## Iterators - -### begin - - - - -This method returns an iterator pointing to the beginning of this vector. - -```cpp -iterator thrust::device_vector< T, Alloc >::begin() -``` - -**Returns:** `iterator` to the beginning. - - - - -const - -This method returns a `const_iterator` pointing to the beginning of this vector. - -```cpp -const_iterator thrust::device_vector< T, Alloc >::begin() const -``` - -**Returns:** `const_iterator` to the beginning. - - - - -### cbegin const - -This method returns a `const_iterator` pointing to the beginning of this vector. - -```cpp -const_iterator thrust::device_vector< T, Alloc >::cbegin() const -``` - -**Returns:** `const_iterator` to the beginning. - -### end - - - - -This method returns an iterator pointing to one element past the last of this vector. - -```cpp -iterator thrust::device_vector< T, Alloc >::end() -``` - -**Returns:** `iterator` past the end. - - - - -const - -This method returns a `const_iterator` pointing to one element past the last of this vector. - -```cpp -const_iterator thrust::device_vector< T, Alloc >::end() const -``` - -**Returns:** `const_iterator` past the end. - - - - -### cend const - -This method returns a `const_iterator` pointing to one element past the last of this vector. - -```cpp -const_iterator thrust::device_vector< T, Alloc >::cend() const -``` - -**Returns:** `const_iterator` past the end. - -### rbegin - - - - -This method returns a `reverse_iterator` pointing to the beginning of this vector's reversed sequence. - -```cpp -reverse_iterator thrust::device_vector< T, Alloc >::rbegin() -``` - -**Returns:** A `reverse_iterator` pointing to the beginning of this vector's reversed sequence. - - - - -const - -This method returns a `const_reverse_iterator` pointing to the beginning of this vector's reversed sequence. - -```cpp -const_reverse_iterator thrust::device_vector< T, Alloc >::rbegin() const -``` - -**Returns:** A `const_reverse_iterator` pointing to the beginning of this vector's reversed sequence. - - - - -### crbegin const - -This method returns a `const_reverse_iterator` pointing to the beginning of this vector's reversed sequence. - -```cpp -const_reverse_iterator thrust::device_vector< T, Alloc >::crbegin() const -``` - -**Returns:** A `const_reverse_iterator` pointing to the beginning of this vector's reversed sequence. - -### rend - - - - -This method returns a `reverse_iterator` pointing to one element past the last of this vector's reversed sequence. - -```cpp -reverse_iterator thrust::device_vector< T, Alloc >::rend() -``` - -**Returns:** A `reverse_iterator` past the end of the reversed sequence. - - - - -const - -This method returns a `const_reverse_iterator` pointing to one element past the last of this vector's reversed sequence. - -```cpp -const_reverse_iterator thrust::device_vector< T, Alloc >::rend() const -``` - -**Returns:** A `const_reverse_iterator` past the end of the reversed sequence. - - - - -### crend const - -This method returns a `const_reverse_iterator` pointing to one element past the last of this vector's reversed sequence. - -```cpp -const_reverse_iterator thrust::device_vector< T, Alloc >::crend() const -``` - -**Returns:** A `const_reverse_iterator` past the end of the reversed sequence. - ---- - -## Capacity - -### size const - -Returns the number of elements in this vector. - -```cpp -size_type thrust::device_vector< T, Alloc >::size() const -``` - -**Returns:** `size_type` -- the number of elements. - -### max_size const - -Returns the [size()](/library/api/thrust::device_vector::size) of the largest possible vector. - -```cpp -size_type thrust::device_vector< T, Alloc >::max_size() const -``` - -**Returns:** The largest possible return value of [size()](/library/api/thrust::device_vector::size). - -### capacity const - -Returns the number of elements which have been reserved in this vector. - -```cpp -size_type thrust::device_vector< T, Alloc >::capacity() const -``` - -**Returns:** `size_type` -- the number of elements reserved. - -### empty const - -This method returns true iff [size()](/library/api/thrust::device_vector::size) == 0. - -```cpp -bool thrust::device_vector< T, Alloc >::empty() const -``` - -**Returns:** `true` if [size()](/library/api/thrust::device_vector::size) == 0; `false`, otherwise. - -### reserve - -If `n` is less than or equal to [capacity()](/library/api/thrust::device_vector::capacity), this call has no effect. Otherwise, this method is a request for allocation of additional memory. If the request is successful, then [capacity()](/library/api/thrust::device_vector::capacity) is greater than or equal to `n`; otherwise, [capacity()](/library/api/thrust::device_vector::capacity) is unchanged. In either case, [size()](/library/api/thrust::device_vector::size) is unchanged. - -```cpp -void thrust::device_vector< T, Alloc >::reserve(size_type n) -``` - -**Throws:** `std::length_error` if `n` exceeds [max_size()](/library/api/thrust::device_vector::max_size). - -### shrink_to_fit - -This method shrinks the capacity of this vector to exactly fit its elements. - -```cpp -void thrust::device_vector< T, Alloc >::shrink_to_fit() -``` - -### resize - - - - -Resizes this vector to the specified number of elements. - -```cpp -void thrust::device_vector< T, Alloc >::resize(size_type new_size, const value_type &x=value_type()) -``` - -**Throws:** `std::length_error` if `n` exceeds [max_size()](/library/api/thrust::device_vector::max_size). - -#### Parameters - - -Number of elements this vector should contain. - - - -Data with which new elements should be populated. - - - - - -Resizes this vector to the specified number of elements, performing default-initialization instead of value-initialization. - -```cpp -void thrust::device_vector< T, Alloc >::resize(size_type new_size, default_init_t) -``` - -**Throws:** `std::length_error` if `n` exceeds [max_size()](/library/api/thrust::device_vector::max_size). - -#### Parameters - - -Number of elements this vector should contain. - - - - - -Resizes this vector to the specified number of elements, without initializing elements. - -```cpp -void thrust::device_vector< T, Alloc >::resize(size_type new_size, no_init_t) -``` - -**Throws:** `std::length_error` if `n` exceeds [max_size()](/library/api/thrust::device_vector::max_size). - -#### Parameters - - -Number of elements this vector should contain. - - - - - ---- - -## Modifiers - -### push_back - -This method appends the given element to the end of this vector. - -```cpp -void thrust::device_vector< T, Alloc >::push_back(const value_type &x) -``` - -#### Parameters - - -The element to append. - - -### pop_back - -This method erases the last element of this vector, invalidating all iterators and references to it. - -```cpp -void thrust::device_vector< T, Alloc >::pop_back() -``` - -### clear - -This method resizes this vector to 0. - -```cpp -void thrust::device_vector< T, Alloc >::clear() -``` - -### swap - -This method swaps the contents of this `device_vector` with another vector. - -```cpp -void thrust::device_vector< T, Alloc >::swap(device_vector &v) -``` - -#### Parameters - - -The vector with which to swap. - - -### insert - - - - -This method inserts a single copy of a given exemplar value at the specified position in this vector. - -```cpp -iterator thrust::device_vector< T, Alloc >::insert(iterator position, const T &x) -``` - -**Returns:** An iterator pointing to the newly inserted element. - -#### Parameters - - -The insertion position. - - - -The exemplar element to copy & insert. - - - - - -This method inserts a copy of an exemplar value to a range at the specified position in this vector. - -```cpp -void thrust::device_vector< T, Alloc >::insert(iterator position, size_type n, const T &x) -``` - -#### Parameters - - -The insertion position. - - - -The number of insertions to perform. - - - -The value to replicate and insert. - - - - - -This method inserts a copy of an input range at the specified position in this vector. - -```cpp -template -void thrust::device_vector< T, Alloc >::insert(iterator position, InputIterator first, InputIterator last) -``` - -#### Template parameters - - -A model of [Input Iterator](https://en.cppreference.com/w/cpp/iterator/input_iterator) whose `value_type` is a model of [Assignable](https://en.cppreference.com/w/cpp/named_req/CopyAssignable). - - -#### Parameters - - -The insertion position. - - - -The beginning of the range to copy. - - - -The end of the range to copy. - - - - - -### erase - - - - -This method removes the element at position pos. - -```cpp -iterator thrust::device_vector< T, Alloc >::erase(iterator pos) -``` - -**Returns:** An iterator pointing to the new location of the element that followed the element at position pos. - -#### Parameters - - -The position of the element of interest. - - - - - -This method removes the range of elements [first,last) from this vector. - -```cpp -iterator thrust::device_vector< T, Alloc >::erase(iterator first, iterator last) -``` - -**Returns:** An iterator pointing to the new location of the element that followed the last element in the sequence [first,last). - -#### Parameters - - -The beginning of the range of elements to remove. - - - -The end of the range of elements to remove. - - - - - -### assign - - - - -This version of `assign` replicates a given exemplar `n` times into this vector. - -```cpp -void thrust::device_vector< T, Alloc >::assign(size_type n, const T &x) -``` - -#### Parameters - - -The number of times to copy `x`. - - - -The exemplar element to replicate. - - - - - -This version of `assign` makes this vector a copy of a given input range. - -```cpp -template -void thrust::device_vector< T, Alloc >::assign(InputIterator first, InputIterator last) -``` - -#### Template parameters - - -A model of [Input Iterator](https://en.cppreference.com/w/cpp/iterator/input_iterator). - - -#### Parameters - - -The beginning of the range to copy. - - - -The end of the range to copy. - - - - - ---- - -## Allocator - -### get_allocator const - -This method returns a copy of this vector's allocator. - -```cpp -allocator_type thrust::device_vector< T, Alloc >::get_allocator() const -``` - -**Returns:** A copy of the allocator used by this vector. - ---- - -## Types - -### Typedefs - -| Name | Definition | -|---|---| -| `Parent` | `detail::vector_base< T, Alloc >` | diff --git a/fern/pages/thrust/device_vector_v3.mdx b/fern/pages/thrust/device_vector_v5.mdx similarity index 97% rename from fern/pages/thrust/device_vector_v3.mdx rename to fern/pages/thrust/device_vector_v5.mdx index cc89a73..04858b0 100644 --- a/fern/pages/thrust/device_vector_v3.mdx +++ b/fern/pages/thrust/device_vector_v5.mdx @@ -37,7 +37,7 @@ The element type of the vector. ## Constructors -### Default and allocator constructors +### device_vector @@ -70,12 +70,7 @@ The allocator to use by this `device_vector`. - - -### Size constructors - - - + explicit @@ -96,7 +91,7 @@ The number of elements to initially create. - + This constructor creates a `device_vector` with the given size, performing only default-initialization instead of value-initialization. @@ -116,7 +111,7 @@ The number of elements to initially create. - + This constructor creates a `device_vector` with the given size, without initializing elements. @@ -136,7 +131,7 @@ The number of elements to initially create. - + explicit @@ -162,12 +157,7 @@ The allocator to use by this `device_vector`. - - -### Fill constructors - - - + explicit @@ -193,7 +183,7 @@ An element to copy. - + explicit @@ -224,12 +214,7 @@ The allocator to use by this `device_vector`. - - -### Copy constructors - - - + Copy constructor copies from an exemplar `device_vector`. @@ -248,7 +233,7 @@ The `device_vector` to copy. - + Copy constructor copies from an exemplar `device_vector`. @@ -334,11 +319,6 @@ The `vector_base` to copy. - - -### Move constructors - - Move constructor moves from another `device_vector`. @@ -382,11 +362,6 @@ The allocator to use by this `device_vector`. - - -### Initializer list constructors - - This constructor builds a `device_vector` from an initializer_list. @@ -430,11 +405,6 @@ The allocator to use by this `device_vector`. - - -### Range constructors - - This constructor builds a `device_vector` from a range. @@ -523,7 +493,7 @@ device_vector& thrust::device_vector::operator=( ``` -**Returns:** `device_vector &` +**Returns:** A reference to this `device_vector`. **Parameters** @@ -544,7 +514,7 @@ device_vector& thrust::device_vector::operator=( ``` -**Returns:** `device_vector &` +**Returns:** A reference to this `device_vector`. **Parameters** @@ -566,7 +536,7 @@ device_vector& thrust::device_vector::operator=( ``` -**Returns:** `device_vector &` +**Returns:** A reference to this `device_vector`. **Parameters** @@ -588,7 +558,7 @@ device_vector& thrust::device_vector::operator=( ``` -**Returns:** `device_vector &` +**Returns:** A reference to this `device_vector`. **Parameters** @@ -610,7 +580,7 @@ device_vector& thrust::device_vector::operator=( ``` -**Returns:** `device_vector &` +**Returns:** A reference to this `device_vector`. **Parameters** @@ -631,7 +601,7 @@ device_vector& thrust::device_vector::operator=( ``` -**Returns:** `device_vector &` +**Returns:** A reference to this `device_vector`. **Parameters** @@ -1145,7 +1115,7 @@ void thrust::device_vector::push_back( ``` -#### Parameters +**Parameters** The element to append. @@ -1183,7 +1153,7 @@ void thrust::device_vector::swap( ``` -#### Parameters +**Parameters** The vector with which to swap. diff --git a/fern/pages/thrust/device_vector_v6.mdx b/fern/pages/thrust/device_vector_v6.mdx new file mode 100644 index 0000000..505ebc6 --- /dev/null +++ b/fern/pages/thrust/device_vector_v6.mdx @@ -0,0 +1,1368 @@ +--- +title: thrust::device_vector +description: "A dynamically-sized container for device memory with automatic memory management." +--- + +A `device_vector` is a container that supports random access to elements, constant time removal of elements at the end, and linear time insertion and removal of elements at the beginning or in the middle. + +The number of elements in a `device_vector` may vary dynamically; memory management is automatic. The memory associated with a `device_vector` resides in the memory accessible to devices. + +```cpp showLineNumbers={false} +#include +``` + +**See also:** +[https://en.cppreference.com/w/cpp/container/vector](https://en.cppreference.com/w/cpp/container/vector), +[device_allocator](/library/api/thrust::device_allocator), +[host_vector](/library/api/thrust::host_vector), +universal_vector + + + + + + + + + + + + + +**Inherits from:** `detail::vector_base< T, thrust::device_allocator< T > >` (public) + +--- + +## Constructors + +### device_vector inline + + + + +This constructor creates an empty `device_vector`. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector() +``` + + + + + +This constructor creates an empty `device_vector`. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + const Alloc &alloc +) +``` + + +**Parameters** + + +The allocator to use by this `device_vector`. + + + + + +explicit + +This constructor creates a `device_vector` with the given size. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + size_type n +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + + + +This constructor creates a `device_vector` with the given size, performing only default-initialization instead of value-initialization. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + size_type n, + default_init_t +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + + + +This constructor creates a `device_vector` with the given size, without initializing elements. + +It mandates that the element type is trivially default-constructible. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + size_type n, + no_init_t +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + + + +explicit + +This constructor creates a `device_vector` with the given size. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + size_type n, + const Alloc &alloc +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + +The allocator to use by this `device_vector`. + + + + + +explicit + +This constructor creates a `device_vector` with copies of an exemplar element. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + size_type n, + const value_type &value +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + +An element to copy. + + + + + +explicit + +This constructor creates a `device_vector` with copies of an exemplar element. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + size_type n, + const value_type &value, + const Alloc &alloc +) +``` + + +**Parameters** + + +The number of elements to initially create. + + + +An element to copy. + + + +The allocator to use by this `device_vector`. + + + + + +Copy constructor copies from an exemplar `device_vector`. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + const device_vector &v +) +``` + + +**Parameters** + + +The `device_vector` to copy. + + + + + +Copy constructor copies from an exemplar `device_vector`. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + const device_vector &v, + const Alloc &alloc +) +``` + + +**Parameters** + + +The `device_vector` to copy. + + + +The allocator to use by this `device_vector`. + + + + + +explicit + +Copy constructor copies from an exemplar `device_vector` with different type. + + +```cpp showLineNumbers={false} +template +thrust::device_vector::device_vector( + const device_vector &v +) +``` + + +**Parameters** + + +The `device_vector` to copy. + + + + + +Copy constructor copies from an exemplar `std::vector`. + + +```cpp showLineNumbers={false} +template +thrust::device_vector::device_vector( + const std::vector &v +) +``` + + +**Parameters** + + +The `std::vector` to copy. + + + + + +Copy construct from a `vector_base` whose element type is convertible to `T`. + + +```cpp showLineNumbers={false} +template +thrust::device_vector::device_vector( + const detail::vector_base &v +) +``` + + +**Parameters** + + +The `vector_base` to copy. + + + + + +Move constructor moves from another `device_vector`. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + device_vector &&v +) +``` + + +**Parameters** + + +The `device_vector` to move. + + + + + +Move constructor moves from another `device_vector`. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + device_vector &&v, + const Alloc &alloc +) +``` + + +**Parameters** + + +The `device_vector` to move. + + + +The allocator to use by this `device_vector`. + + + + + +This constructor builds a `device_vector` from an intializer_list. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + ::cuda::std::initializer_list il +) +``` + + +**Parameters** + + +The intializer_list. + + + + + +This constructor builds a `device_vector` from an intializer_list. + + +```cpp showLineNumbers={false} +thrust::device_vector::device_vector( + ::cuda::std::initializer_list il, + const Alloc &alloc +) +``` + + +**Parameters** + + +The intializer_list. + + + +The allocator to use by this `device_vector`. + + + + + +This constructor builds a `device_vector` from a range. + + +```cpp showLineNumbers={false} +template +thrust::device_vector::device_vector( + InputIterator first, + InputIterator last +) +``` + + +**Parameters** + + +The beginning of the range. + + + +The end of the range. + + + + + +This constructor builds a `device_vector` from a range. + + +```cpp showLineNumbers={false} +template +thrust::device_vector::device_vector( + InputIterator first, + InputIterator last, + const Alloc &alloc +) +``` + + +**Parameters** + + +The beginning of the range. + + + +The end of the range. + + + +The allocator to use by this `device_vector`. + + + + + +### Destructor + +### ~device_vector inline + +The destructor erases the elements. + + +```cpp showLineNumbers={false} +thrust::device_vector::~device_vector() +``` + + +--- + +## Assignment operators + +### operator= inline + + + + +Copy assign operator copies another `device_vector` with the same type. + + +```cpp showLineNumbers={false} +device_vector & thrust::device_vector::operator=( + const device_vector &v +) +``` + + +**Parameters** + + +The `device_vector` to copy. + + + + + +Move assign operator moves from another `device_vector`. + + +```cpp showLineNumbers={false} +device_vector & thrust::device_vector::operator=( + device_vector &&v +) +``` + + +**Parameters** + + +The `device_vector` to move. + + + + + +Assign operator copies from an exemplar `device_vector` with different type. + + +```cpp showLineNumbers={false} +template +device_vector & thrust::device_vector::operator=( + const device_vector &v +) +``` + + +**Parameters** + + +The `device_vector` to copy. + + + + + +Assign operator copies from an exemplar `std::vector`. + + +```cpp showLineNumbers={false} +template +device_vector & thrust::device_vector::operator=( + const std::vector &v +) +``` + + +**Parameters** + + +The `std::vector` to copy. + + + + + +Assign a `vector_base` whose element type is convertible to `T`. + + +```cpp showLineNumbers={false} +template +device_vector & thrust::device_vector::operator=( + const detail::vector_base &v +) +``` + + +**Parameters** + + +The `vector_base` to copy. + + + + + +Assign an `intializer_list` with a matching element type. + + +```cpp showLineNumbers={false} +device_vector & thrust::device_vector::operator=( + ::cuda::std::initializer_list il +) +``` + + +**Parameters** + + +The intializer_list. + + + + + +--- + +## Methods + +### resize + + + + +Resizes this vector to the specified number of elements. + + +```cpp showLineNumbers={false} +void thrust::device_vector::resize( + size_type new_size, + const value_type &x = value_type() +) +``` + + +**Throws:** `std::length_error` If n exceeds [max_size()](/library/api/thrust::device_vector::max_size()). + +**Parameters** + + +Number of elements this vector should contain. + + + +Data with which new elements should be populated. + + + + + +Resizes this vector to the specified number of elements, performing default-initialization instead of value-initialization. + + +```cpp showLineNumbers={false} +void thrust::device_vector::resize( + size_type new_size, + default_init_t +) +``` + + +**Throws:** `std::length_error` If n exceeds [max_size()](/library/api/thrust::device_vector::max_size()). + +**Parameters** + + +Number of elements this vector should contain. + + + + + +Resizes this vector_base to the specified number of elements, without initializing elements. + +It mandates that the element type is trivially default-constructible. + + +```cpp showLineNumbers={false} +void thrust::device_vector::resize( + size_type new_size, + no_init_t +) +``` + + +**Throws:** `std::length_error` If n exceeds [max_size()](/library/api/thrust::device_vector::max_size()). + +**Parameters** + + +Number of elements this vector_base should contain. + + + + + +### size const + +Returns the number of elements in this vector. + + +```cpp showLineNumbers={false} +size_type thrust::device_vector::size() const +``` + + +### max_size const + +Returns the [size()](/library/api/thrust::device_vector::size()) of the largest possible vector. + + +```cpp showLineNumbers={false} +size_type thrust::device_vector::max_size() const +``` + + +**Returns:** The largest possible return value of [size()](/library/api/thrust::device_vector::size()). + +### reserve + +If n is less than or equal to [capacity()](/library/api/thrust::device_vector::capacity()), this call has no effect. + + +```cpp showLineNumbers={false} +void thrust::device_vector::reserve( + size_type n +) +``` + + +**Throws:** `std::length_error` If n exceeds [max_size()](/library/api/thrust::device_vector::max_size()). + +### capacity const + +Returns the number of elements which have been reserved in this vector. + + +```cpp showLineNumbers={false} +size_type thrust::device_vector::capacity() const +``` + + +### shrink_to_fit + +This method shrinks the capacity of this vector to exactly fit its elements. + + +```cpp showLineNumbers={false} +void thrust::device_vector::shrink_to_fit() +``` + + +### operator[] + + + + +Subscript access to the data contained in this vector_dev. + +This operator allows for easy, array-style, data access. Note that data access with this operator is unchecked and out_of_range lookups are not defined. + + +```cpp showLineNumbers={false} +reference thrust::device_vector::operator[]( + size_type n +) +``` + + +**Returns:** Read/write reference to data. + +**Parameters** + + +The index of the element for which data should be accessed. + + + + + +const + +Subscript read access to the data contained in this vector_dev. + +This operator allows for easy, array-style, data access. Note that data access with this operator is unchecked and out_of_range lookups are not defined. + + +```cpp showLineNumbers={false} +const_reference thrust::device_vector::operator[]( + size_type n +) const +``` + + +**Returns:** Read reference to data. + +**Parameters** + + +The index of the element for which data should be accessed. + + + + + +### begin + + + + +This method returns an iterator pointing to the beginning of this vector. + + +```cpp showLineNumbers={false} +iterator thrust::device_vector::begin() +``` + + +**Returns:** mStart + + + + +const + +This method returns a const_iterator pointing to the beginning of this vector. + + +```cpp showLineNumbers={false} +const_iterator thrust::device_vector::begin() const +``` + + +**Returns:** mStart + + + + +### cbegin const + +This method returns a const_iterator pointing to the beginning of this vector. + + +```cpp showLineNumbers={false} +const_iterator thrust::device_vector::cbegin() const +``` + + +**Returns:** mStart + +### rbegin + + + + +This method returns a reverse_iterator pointing to the beginning of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +reverse_iterator thrust::device_vector::rbegin() +``` + + +**Returns:** A reverse_iterator pointing to the beginning of this vector's reversed sequence. + + + + +const + +This method returns a const_reverse_iterator pointing to the beginning of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +const_reverse_iterator thrust::device_vector::rbegin() const +``` + + +**Returns:** A const_reverse_iterator pointing to the beginning of this vector's reversed sequence. + + + + +### crbegin const + +This method returns a const_reverse_iterator pointing to the beginning of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +const_reverse_iterator thrust::device_vector::crbegin() const +``` + + +**Returns:** A const_reverse_iterator pointing to the beginning of this vector's reversed sequence. + +### end + + + + +This method returns an iterator pointing to one element past the last of this vector. + + +```cpp showLineNumbers={false} +iterator thrust::device_vector::end() +``` + + +**Returns:** [begin()](/library/api/thrust::device_vector::begin()) + [size()](/library/api/thrust::device_vector::size()). + + + + +const + +This method returns a const_iterator pointing to one element past the last of this vector. + + +```cpp showLineNumbers={false} +const_iterator thrust::device_vector::end() const +``` + + +**Returns:** [begin()](/library/api/thrust::device_vector::begin()) + [size()](/library/api/thrust::device_vector::size()). + + + + +### cend const + +This method returns a const_iterator pointing to one element past the last of this vector. + + +```cpp showLineNumbers={false} +const_iterator thrust::device_vector::cend() const +``` + + +**Returns:** [begin()](/library/api/thrust::device_vector::begin()) + [size()](/library/api/thrust::device_vector::size()). + +### rend + + + + +This method returns a reverse_iterator pointing to one element past the last of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +reverse_iterator thrust::device_vector::rend() +``` + + +**Returns:** [rbegin()](/library/api/thrust::device_vector::rbegin()) + [size()](/library/api/thrust::device_vector::size()). + + + + +const + +This method returns a const_reverse_iterator pointing to one element past the last of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +const_reverse_iterator thrust::device_vector::rend() const +``` + + +**Returns:** [rbegin()](/library/api/thrust::device_vector::rbegin()) + [size()](/library/api/thrust::device_vector::size()). + + + + +### crend const + +This method returns a const_reverse_iterator pointing to one element past the last of this vector's reversed sequence. + + +```cpp showLineNumbers={false} +const_reverse_iterator thrust::device_vector::crend() const +``` + + +**Returns:** [rbegin()](/library/api/thrust::device_vector::rbegin()) + [size()](/library/api/thrust::device_vector::size()). + +### front + + + + +This method returns a reference pointing to the first element of this vector. + + +```cpp showLineNumbers={false} +reference thrust::device_vector::front() +``` + + +**Returns:** The first element of this vector. + + + + +const + +This method returns a const_reference referring to the first element of this vector. + + +```cpp showLineNumbers={false} +const_reference thrust::device_vector::front() const +``` + + +**Returns:** The first element of this vector. + + + + +### back + + + + +This method returns a reference referring to the last element of this vector_dev. + + +```cpp showLineNumbers={false} +reference thrust::device_vector::back() +``` + + +**Returns:** The last element of this vector. + + + + +const + +This method returns a const reference pointing to the last element of this vector. + + +```cpp showLineNumbers={false} +const_reference thrust::device_vector::back() const +``` + + +**Returns:** The last element of this vector. + + + + +### data + + + + +This method returns a pointer to this vector's first element. + + +```cpp showLineNumbers={false} +pointer thrust::device_vector::data() +``` + + +**Returns:** A pointer to the first element of this vector. + + + + +const + +This method returns a const_pointer to this vector's first element. + + +```cpp showLineNumbers={false} +const_pointer thrust::device_vector::data() const +``` + + +**Returns:** a const_pointer to the first element of this vector. + + + + +### clear + +This method resizes this vector to 0. + + +```cpp showLineNumbers={false} +void thrust::device_vector::clear() +``` + + +### empty const + +This method returns true iff [size()](/library/api/thrust::device_vector::size()) == 0. + + +```cpp showLineNumbers={false} +bool thrust::device_vector::empty() const +``` + + +**Returns:** true if [size()](/library/api/thrust::device_vector::size()) == 0; false, otherwise. + +### push_back + +This method appends the given element to the end of this vector. + + +```cpp showLineNumbers={false} +void thrust::device_vector::push_back( + const value_type &x +) +``` + + +**Parameters** + + +The element to append. + + +### pop_back + +This method erases the last element of this vector, invalidating all iterators and references to it. + + +```cpp showLineNumbers={false} +void thrust::device_vector::pop_back() +``` + + +### swap + +This method swaps the contents of this `device_vector` with another vector. + + +```cpp showLineNumbers={false} +void thrust::device_vector::swap( + device_vector &v +) +``` + + +**Parameters** + + +The vector with which to swap. + + +### erase + + + + +This method removes the element at position pos. + + +```cpp showLineNumbers={false} +iterator thrust::device_vector::erase( + iterator pos +) +``` + + +**Returns:** An iterator pointing to the new location of the element that followed the element at position pos. + +**Parameters** + + +The position of the element of interest. + + + + + +This method removes the range of elements [first,last) from this vector. + + +```cpp showLineNumbers={false} +iterator thrust::device_vector::erase( + iterator first, + iterator last +) +``` + + +**Returns:** An iterator pointing to the new location of the element that followed the last element in the sequence [first,last). + +**Parameters** + + +The beginning of the range of elements to remove. + + + +The end of the range of elements to remove. + + + + + +### insert + + + + +This method inserts a single copy of a given exemplar value at the specified position in this vector. + + +```cpp showLineNumbers={false} +iterator thrust::device_vector::insert( + iterator position, + const T &x +) +``` + + +**Returns:** An iterator pointing to the newly inserted element. + +**Parameters** + + +The insertion position. + + + +The exemplar element to copy & insert. + + + + + +This method inserts a copy of an exemplar value to a range at the specified position in this vector. + + +```cpp showLineNumbers={false} +void thrust::device_vector::insert( + iterator position, + size_type n, + const T &x +) +``` + + +**Parameters** + + +The insertion position + + + +The number of insertions to perform. + + + +The value to replicate and insert. + + + + + +This method inserts a copy of an input range at the specified position in this vector. + + +```cpp showLineNumbers={false} +template +void thrust::device_vector::insert( + iterator position, + InputIterator first, + InputIterator last +) +``` + + +**Template parameters** + + +Is a model of [Input Iterator](https://en.cppreference.com/w/cpp/iterator/input_iterator), and `InputIterator's` `value_type` is a model of [Assignable.](https://en.cppreference.com/w/cpp/named_req/CopyAssignable) + + +**Parameters** + + +The insertion position. + + + +The beginning of the range to copy. + + + +The end of the range to copy. + + + + + +### assign + + + + +This version of `assign` replicates a given exemplar `n` times into this vector. + + +```cpp showLineNumbers={false} +void thrust::device_vector::assign( + size_type n, + const T &x +) +``` + + +**Parameters** + + +The number of times to copy `x`. + + + +The exemplar element to replicate. + + + + + +This version of `assign` makes this vector a copy of a given input range. + + +```cpp showLineNumbers={false} +template +void thrust::device_vector::assign( + InputIterator first, + InputIterator last +) +``` + + +**Template parameters** + + +Is a model of [Input Iterator](https://en.cppreference.com/w/cpp/named_req/InputIterator). + + +**Parameters** + + +The beginning of the range to copy. + + + +The end of the range to copy. + + + + + +### get_allocator const + +This method returns a copy of this vector's allocator. + + +```cpp showLineNumbers={false} +allocator_type thrust::device_vector::get_allocator() const +``` + + +**Returns:** A copy of the allocator used by this vector. + +--- + +## Types + +### Typedefs + +| Name | Definition | +|---|---| +| `Parent` | `detail::vector_base< T, Alloc >` | diff --git a/fern/pages/thrust/group_member_example_v4.mdx b/fern/pages/thrust/group_member_example_v5.mdx similarity index 99% rename from fern/pages/thrust/group_member_example_v4.mdx rename to fern/pages/thrust/group_member_example_v5.mdx index df1192b..e755c63 100644 --- a/fern/pages/thrust/group_member_example_v4.mdx +++ b/fern/pages/thrust/group_member_example_v5.mdx @@ -405,4 +405,4 @@ struct thrust::mr::disjoint_unsynchronized_pool_resource::pool | Name | Type | Description | |---|---|---| | `free_blocks` | `pointer_vector` | | -| `previous_allocated_count` | `std::size_t` | | +| `previous_allocated_count` | `std::size_t` | | \ No newline at end of file diff --git a/fern/pages/thrust/group_member_example.mdx b/fern/pages/thrust/group_member_example_v6.mdx similarity index 75% rename from fern/pages/thrust/group_member_example.mdx rename to fern/pages/thrust/group_member_example_v6.mdx index bb46084..86dd2ad 100644 --- a/fern/pages/thrust/group_member_example.mdx +++ b/fern/pages/thrust/group_member_example_v6.mdx @@ -19,17 +19,17 @@ This is not the only case where it makes sense to use a disjoint pool resource, -The type of memory resources that will be used for allocating memory blocks to be handed off to the user. +The type of memory resources that will be used for allocating memory blocks to be handed off to the user -The type of memory resources that will be used for allocating bookkeeping memory. +The type of memory resources that will be used for allocating bookkeeping memory -**Inherits from:** [`thrust::mr::memory_resource< Upstream::pointer >`](/library/api/thrust::mr::memory_resource) (public), [`thrust::mr::validator2< Upstream, Bookkeeper >`](/library/api/thrust::mr::validator2) (private) +**Inherits from:** `thrust::mr::memory_resource< Upstream::pointer >` (public), `thrust::mr::validator2< Upstream, Bookkeeper >` (private) This class is marked final. @@ -40,13 +40,13 @@ This class is marked final. ### disjoint_unsynchronized_pool_resource inline - + Constructor. ```cpp showLineNumbers={false} -thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::disjoint_unsynchronized_pool_resource( +thrust::mr::disjoint_unsynchronized_pool_resource::disjoint_unsynchronized_pool_resource( Upstream *upstream, Bookkeeper *bookkeeper, pool_options options = get_default_options() @@ -57,25 +57,27 @@ thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::disjo **Parameters** -The upstream memory resource for allocations. +The upstream memory resource for allocations -The upstream memory resource for bookkeeping. +The upstream memory resource for bookkeeping -Pool options to use. +Pool options to use - + -Constructor. Upstream and bookkeeping resources are obtained by calling `get_global_resource` for their types. +Constructor. + +Upstream and bookkeeping resources are obtained by calling `get_global_resource` for their types. ```cpp showLineNumbers={false} -thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::disjoint_unsynchronized_pool_resource( +thrust::mr::disjoint_unsynchronized_pool_resource::disjoint_unsynchronized_pool_resource( pool_options options = get_default_options() ) ``` @@ -84,7 +86,7 @@ thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::disjo **Parameters** -Pool options to use. +Pool options to use @@ -94,17 +96,19 @@ Pool options to use. ### ~disjoint_unsynchronized_pool_resource inline -Destructor. Releases all held memory to upstream. +Destructor. + +Releases all held memory to upstream. ```cpp showLineNumbers={false} -thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::~disjoint_unsynchronized_pool_resource() +thrust::mr::disjoint_unsynchronized_pool_resource::~disjoint_unsynchronized_pool_resource() ``` --- -## Pool management +## Methods ### release inline @@ -112,7 +116,7 @@ Releases all held memory to upstream. ```cpp showLineNumbers={false} -void thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::release() +void thrust::mr::disjoint_unsynchronized_pool_resource::release() ``` @@ -120,21 +124,17 @@ void thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >:: ```cpp showLineNumbers={false} -void thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::squeeze() +void thrust::mr::disjoint_unsynchronized_pool_resource::squeeze() ``` ---- - -## Allocation - ### do_allocate inline nodiscard virtual Allocates memory of size at least `bytes` and alignment at least `alignment`. ```cpp showLineNumbers={false} -virtual void_ptr thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::do_allocate( +virtual void_ptr thrust::mr::disjoint_unsynchronized_pool_resource::do_allocate( std::size_t bytes, std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT ) override @@ -143,21 +143,23 @@ virtual void_ptr thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bo **Returns:** A pointer to void to the newly allocated memory. +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + **Parameters** -Size, in bytes, that is requested from this allocation. +Size, in bytes, that is requested from this allocation -Alignment that is requested from this allocation. +Alignment that is requested from this allocation ### do_allocate_impl inline nodiscard ```cpp showLineNumbers={false} -void_ptr thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::do_allocate_impl( +void_ptr thrust::mr::disjoint_unsynchronized_pool_resource::do_allocate_impl( std::size_t bytes, std::size_t alignment ) @@ -170,9 +172,9 @@ Deallocates memory pointed to by `p`. ```cpp showLineNumbers={false} -virtual void thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::do_deallocate( +virtual void thrust::mr::disjoint_unsynchronized_pool_resource::do_deallocate( void_ptr p, - std::size_t bytes, + std::size_t n, std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT ) override ``` @@ -181,7 +183,7 @@ virtual void thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookke **Parameters** -Pointer to be deallocated. +Pointer to be deallocated @@ -196,9 +198,9 @@ The size of the allocation. This must be equivalent to the value of `alignment` Allocates memory of size at least `bytes` and alignment at least `alignment`. - + ```cpp showLineNumbers={false} -pointer thrust::mr::memory_resource< Upstream::pointer >::allocate( +pointer thrust::mr::memory_resource::allocate( std::size_t bytes, std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT ) @@ -207,23 +209,25 @@ pointer thrust::mr::memory_resource< Upstream::pointer >::allocate( **Returns:** A pointer to void to the newly allocated memory. +**Throws:** `thrust::bad_alloc` when no memory with requested size and alignment can be allocated. + **Parameters** -Size, in bytes, that is requested from this allocation. +Size, in bytes, that is requested from this allocation -Alignment that is requested from this allocation. +Alignment that is requested from this allocation ### deallocate inline noexcept Deallocates memory pointed to by `p`. - + ```cpp showLineNumbers={false} -void thrust::mr::memory_resource< Upstream::pointer >::deallocate( +void thrust::mr::memory_resource::deallocate( pointer p, std::size_t bytes, std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT @@ -234,7 +238,7 @@ void thrust::mr::memory_resource< Upstream::pointer >::deallocate( **Parameters** -Pointer to be deallocated. +Pointer to be deallocated @@ -245,30 +249,26 @@ The size of the allocation. This must be equivalent to the value of `bytes` that The alignment of the allocation. This must be equivalent to the value of `alignment` that was passed to the allocation function that returned `p`. ---- - -## Comparison - ### is_equal inline const noexcept Compares this resource to the other one. The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. - + ```cpp showLineNumbers={false} -bool thrust::mr::memory_resource< Upstream::pointer >::is_equal( +bool thrust::mr::memory_resource::is_equal( const memory_resource &other ) const noexcept ``` -**Returns:** Whether the two resources are equivalent. +**Returns:** whether the two resources are equivalent. **Parameters** -The other resource to compare this resource to. +The other resource to compare this resource to ### do_is_equal inline const noexcept virtual @@ -277,20 +277,20 @@ Compares this resource to the other one. The default implementation uses identity comparison, which is often the right thing to do and doesn't require RTTI involvement. - + ```cpp showLineNumbers={false} -virtual bool thrust::mr::memory_resource< Upstream::pointer >::do_is_equal( +virtual bool thrust::mr::memory_resource::do_is_equal( const memory_resource &other ) const noexcept ``` -**Returns:** Whether the two resources are equivalent. +**Returns:** whether the two resources are equivalent. **Parameters** -The other resource to compare this resource to. +The other resource to compare this resource to --- @@ -301,9 +301,11 @@ The other resource to compare this resource to. Get the default options for a disjoint pool. +These are meant to be a sensible set of values for many use cases, and as such, may be tuned in the future. This function is exposed so that creating a set of options that are just a slight departure from the defaults is easy. + ```cpp showLineNumbers={false} -static pool_options thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, Bookkeeper >::get_default_options() +static pool_options thrust::mr::disjoint_unsynchronized_pool_resource::get_default_options() ``` @@ -327,16 +329,16 @@ static pool_options thrust::mr::disjoint_unsynchronized_pool_resource< Upstream, ## Member variables -| Name | Type | -|---|---| -| `m_upstream` | `Upstream *` | -| `m_bookkeeper` | `Bookkeeper *` | -| `m_options` | [`pool_options`](/library/api/thrust::mr::pool_options) | -| `m_smallest_block_log2` | `std::size_t` | -| `m_pools` | `pool_vector` | -| `m_allocated` | `chunk_vector` | -| `m_cached_oversized` | `oversized_block_vector` | -| `m_oversized` | `oversized_block_vector` | +| Name | Type | Description | +|---|---|---| +| `m_upstream` | `Upstream *` | | +| `m_bookkeeper` | `Bookkeeper *` | | +| `m_options` | `pool_options` | | +| `m_smallest_block_log2` | `std::size_t` | | +| `m_pools` | `pool_vector` | | +| `m_allocated` | `chunk_vector` | | +| `m_cached_oversized` | `oversized_block_vector` | | +| `m_oversized` | `oversized_block_vector` | | --- @@ -350,11 +352,11 @@ struct thrust::mr::disjoint_unsynchronized_pool_resource::chunk_descriptor ``` -| Member | Type | -|---|---| -| `size` | `std::size_t` | -| `pointer` | `void_ptr` | -| `pool_idx` | `std::size_t` | +| Name | Type | Description | +|---|---|---| +| `size` | `std::size_t` | | +| `pointer` | `void_ptr` | | +| `pool_idx` | `std::size_t` | | ### oversized_block_descriptor @@ -364,11 +366,11 @@ struct thrust::mr::disjoint_unsynchronized_pool_resource::oversized_block_descri ``` -| Member | Type | -|---|---| -| `size` | `std::size_t` | -| `alignment` | `std::size_t` | -| `pointer` | `void_ptr` | +| Name | Type | Description | +|---|---|---| +| `size` | `std::size_t` | | +| `alignment` | `std::size_t` | | +| `pointer` | `void_ptr` | | ### equal_pointers @@ -378,9 +380,9 @@ struct thrust::mr::disjoint_unsynchronized_pool_resource::equal_pointers ``` -| Member | Type | -|---|---| -| `p` | `void_ptr` | +| Name | Type | Description | +|---|---|---| +| `p` | `void_ptr` | | ### matching_alignment @@ -390,9 +392,9 @@ struct thrust::mr::disjoint_unsynchronized_pool_resource::matching_alignment ``` -| Member | Type | -|---|---| -| `requested` | `std::size_t` | +| Name | Type | Description | +|---|---|---| +| `requested` | `std::size_t` | | ### pool @@ -402,7 +404,7 @@ struct thrust::mr::disjoint_unsynchronized_pool_resource::pool ``` -| Member | Type | -|---|---| -| `free_blocks` | `pointer_vector` | -| `previous_allocated_count` | `std::size_t` | +| Name | Type | Description | +|---|---|---| +| `free_blocks` | `pointer_vector` | | +| `previous_allocated_count` | `std::size_t` | | diff --git a/fern/pages/thrust/pointer.mdx b/fern/pages/thrust/pointer_v5.mdx similarity index 98% rename from fern/pages/thrust/pointer.mdx rename to fern/pages/thrust/pointer_v5.mdx index 362e7e4..585b86b 100644 --- a/fern/pages/thrust/pointer.mdx +++ b/fern/pages/thrust/pointer_v5.mdx @@ -15,9 +15,9 @@ The raw pointer encapsulated by a `pointer` may be obtained through its [`get`]( #include ``` - + `pointer` is not a smart pointer; it is the client's responsibility to deallocate memory pointer to by `pointer`. - + **See also:** [device_ptr](/library/api/thrust::device_ptr), @@ -151,7 +151,7 @@ derived_type& thrust::pointer< Element, Tag, Reference, Derived >::operator=( ``` -**Returns:** `derived_type &` +**Returns:** A reference to the derived pointer type. @@ -168,7 +168,7 @@ thrust::pointer< Element, Tag, Reference, Derived >::operator=( ``` -**Returns:** `*this` +**Returns:** A reference to `*this`. **Template parameters** @@ -239,30 +239,30 @@ Base const & thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, ### base_reference - + -inline const +inline ```cpp showLineNumbers={false} -Base const & thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >::base_reference() const +Base & thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >::base_reference() ``` -**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. +**Returns:** A mutable reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. - + -inline +inline const ```cpp showLineNumbers={false} -Base & thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >::base_reference() +Base const & thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >::base_reference() const ``` -**Returns:** A mutable reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. +**Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. @@ -292,4 +292,4 @@ static derived_type thrust::pointer< Element, Tag, Reference, Derived >::pointer | `super_t` | `typename detail::pointer_base< Element, Tag, Reference, Derived >::type` | | | `derived_type` | `typename detail::pointer_base< Element, Tag, Reference, Derived >::derived_type` | | | `raw_pointer` | `typename super_t::base_type` | The type of the raw pointer. | -| `base_type` | `Base` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)'s adapts. | +| `base_type` | `Base` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)'s adapts. | \ No newline at end of file diff --git a/fern/pages/thrust/pointer_v4.mdx b/fern/pages/thrust/pointer_v6.mdx similarity index 73% rename from fern/pages/thrust/pointer_v4.mdx rename to fern/pages/thrust/pointer_v6.mdx index 3c7b616..1020566 100644 --- a/fern/pages/thrust/pointer_v4.mdx +++ b/fern/pages/thrust/pointer_v6.mdx @@ -7,7 +7,7 @@ description: "A tagged pointer type that stores a pointer to an object allocated Like [`device_ptr`](/library/api/thrust::device_ptr), this type ensures type safety when dispatching standard algorithms on ranges resident in memory. -`pointer` generalizes [`device_ptr`](/library/api/thrust::device_ptr) by relaxing the backend system associated with the `pointer`. Instead of the backend system specified by `THRUST_DEVICE_SYSTEM`, `pointer`'s system is given by its second template parameter, `Tag`. For the purpose of Thrust dispatch, [`device_ptr`](/library/api/thrust::device_ptr) and `pointer` are considered equivalent. +`pointer` generalizes [`device_ptr`](/library/api/thrust::device_ptr) by relaxing the backend system associated with the `pointer`. Instead of the backend system specified by `THRUST_DEVICE_SYSTEM`, `pointer's` system is given by its second template parameter, `Tag`. For the purpose of Thrust dispatch, [`device_ptr`](/library/api/thrust::device_ptr) and [`pointer`](/library/api/thrust::pointer::pointer%3CElement,device_system_tag%3E) are considered equivalent. The raw pointer encapsulated by a `pointer` may be obtained through its [`get`](/library/api/thrust::pointer::get) member function or the `raw_pointer_cast` free function. @@ -22,7 +22,7 @@ The raw pointer encapsulated by a `pointer` may be obtained through its [`get`]( **See also:** [device_ptr](/library/api/thrust::device_ptr), reference, -[raw_pointer_cast](/library/api/thrust::raw_pointer_cast) +raw_pointer_cast @@ -36,17 +36,17 @@ Specifies the system with which this `pointer` is associated. This may be any Th -Allows the client to specify the reference type returned upon dereference. By default, this type is `reference`. +Allows the client to specify the reference type returned upon derereference. By default, this type is `reference`. -Allows the client to specify the name of the derived type when `pointer` is used as a base class. This is useful to ensure that arithmetic on values of the derived type return values of the derived type as a result. By default, this type is `pointer`. +Allows the client to specify the name of the derived type when `pointer` is used as a base class. This is useful to ensure that arithmetic on values of the derived type return values of the derived type as a result. By default, this type is [`pointer`](/library/api/thrust::pointer::pointer%3CElement,Tag,Reference%3E). -**Inherits from:** [`thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >`](/library/api/thrust::iterator_adaptor) (public) +**Inherits from:** `thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >` (public) --- @@ -57,11 +57,11 @@ Allows the client to specify the name of the derived type when `pointer` is used -`pointer`'s default constructor initializes its encapsulated pointer to `0`. +`pointer's` default constructor initializes its encapsulated pointer to `0` ```cpp showLineNumbers={false} -thrust::pointer< Element, Tag, Reference, Derived >::pointer() +thrust::pointer::pointer() ``` @@ -70,7 +70,7 @@ thrust::pointer< Element, Tag, Reference, Derived >::pointer() ```cpp showLineNumbers={false} -thrust::pointer< Element, Tag, Reference, Derived >::pointer( +thrust::pointer::pointer( ::cuda::std::nullptr_t ) ``` @@ -81,12 +81,12 @@ thrust::pointer< Element, Tag, Reference, Derived >::pointer( explicit -This constructor allows construction of a `pointer` from a `T *`. +This constructor allows construction of a [`pointer`](/library/api/thrust::pointer::pointer%3Cconst T, ...%3E) from a `T*`. ```cpp showLineNumbers={false} template -thrust::pointer< Element, Tag, Reference, Derived >::pointer( +thrust::pointer::pointer( OtherElement *ptr ) ``` @@ -101,7 +101,7 @@ thrust::pointer< Element, Tag, Reference, Derived >::pointer( **Parameters** -A raw pointer to copy from, presumed to point to a location in `Tag`'s memory. +A raw pointer to copy from, presumed to point to a location in `Tag's` memory. @@ -111,9 +111,8 @@ This constructor allows initialization from another pointer-like object. ```cpp showLineNumbers={false} -template * = nullptr> -thrust::pointer< Element, Tag, Reference, Derived >::pointer( +template +thrust::pointer::pointer( const OtherPointer &other ) ``` @@ -145,14 +144,12 @@ The `OtherPointer` to copy. ```cpp showLineNumbers={false} -derived_type& thrust::pointer< Element, Tag, Reference, Derived >::operator=( +derived_type & thrust::pointer::operator=( ::cuda::std::nullptr_t ) ``` -**Returns:** `derived_type &` - @@ -161,8 +158,7 @@ Assignment operator allows assigning from another pointer-like object whose elem ```cpp showLineNumbers={false} template -detail::enable_if_pointer_is_convertible_t -thrust::pointer< Element, Tag, Reference, Derived >::operator=( +detail::enable_if_pointer_is_convertible_t thrust::pointer::operator=( const OtherPointer &other ) ``` @@ -189,40 +185,40 @@ The other pointer-like object to assign from. ## Methods -### get inline const - -`get` returns this `pointer`'s encapsulated raw pointer. +### dereference inline const ```cpp showLineNumbers={false} -Element * thrust::pointer< Element, Tag, Reference, Derived >::get() const +template +SuperRef thrust::pointer::dereference() const ``` -**Returns:** This `pointer`'s raw pointer. +### get inline const -### operator-> inline const +`get` returns this `pointer's` encapsulated raw pointer. ```cpp showLineNumbers={false} -Element * thrust::pointer< Element, Tag, Reference, Derived >::operator->() const +Element * thrust::pointer::get() const ``` -### operator bool inline explicit const +**Returns:** This `pointer's` raw pointer. + +### operator-> inline const ```cpp showLineNumbers={false} -thrust::pointer< Element, Tag, Reference, Derived >::operator bool() const +Element * thrust::pointer::operator->() const ``` -### dereference inline const +### operator bool inline explicit const ```cpp showLineNumbers={false} -template -SuperRef thrust::pointer< Element, Tag, Reference, Derived >::dereference() const +thrust::pointer::operator bool() const ``` @@ -230,22 +226,20 @@ SuperRef thrust::pointer< Element, Tag, Reference, Derived >::dereference() cons ```cpp showLineNumbers={false} -Base const & thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >::base() const +Base const & thrust::iterator_adaptor::base() const ``` **Returns:** A `const` reference to the `Base` iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor) adapts. -### base_reference +### base_reference inline -inline - ```cpp showLineNumbers={false} -Base & thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >::base_reference() +Base & thrust::iterator_adaptor::base_reference() ``` @@ -254,11 +248,11 @@ Base & thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Refere -inline const +const ```cpp showLineNumbers={false} -Base const & thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, Reference, Difference >::base_reference() const +Base const & thrust::iterator_adaptor::base_reference() const ``` @@ -275,17 +269,12 @@ Base const & thrust::iterator_adaptor< Derived, Base, Value, System, Traversal, ```cpp showLineNumbers={false} -static derived_type thrust::pointer< Element, Tag, Reference, Derived >::pointer_to( - typename detail::pointer_traits_detail::pointer_to_param< Element >::type r +static derived_type thrust::pointer::pointer_to( + typename detail::pointer_traits_detail::pointer_to_param::type r ) ``` -**Parameters** - - - - --- ## Types @@ -297,4 +286,4 @@ static derived_type thrust::pointer< Element, Tag, Reference, Derived >::pointer | `super_t` | `typename detail::pointer_base< Element, Tag, Reference, Derived >::type` | | | `derived_type` | `typename detail::pointer_base< Element, Tag, Reference, Derived >::derived_type` | | | `raw_pointer` | `typename super_t::base_type` | The type of the raw pointer. | -| `base_type` | `Base` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)'s adapts. | +| `base_type` | `Base` | The type of iterator this [`iterator_adaptor`](/library/api/thrust::iterator_adaptor)`'s` `adapts`. |