4 changes: 4 additions & 0 deletions bionemo-recipes/recipes/evo2_megatron/Dockerfile
@@ -25,6 +25,10 @@ RUN uv venv --system-site-packages --seed $VIRTUAL_ENV
 # accidentally, though your pyproject.toml overrides handle the critical ones.
 RUN pip freeze | grep transformer_engine > pip-constraints.txt
 
+## Trying to workaround CAUSAL_CONV1D build issues
+ENV CAUSAL_CONV1D_FORCE_BUILD=1
+ENV PIP_NO_BINARY=causal-conv1d
+
 # 5. Install package and dependencies
 RUN --mount=type=secret,id=netrc,target=/root/.netrc \
     --mount=type=cache,target=/root/.cache/uv \
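Aside (not part of the diff): PIP_NO_BINARY=causal-conv1d tells pip never to install a prebuilt wheel for that distribution, and CAUSAL_CONV1D_FORCE_BUILD is consulted by the package's setup script when deciding whether to compile the CUDA extension from source instead of fetching a prebuilt binary (the exact accepted values are upstream's and worth verifying). A minimal sanity check to run inside the built image, assuming the upstream module layout:

# Sanity check (assumption: causal_conv1d_cuda is the compiled extension's
# module name in upstream packaging; the import fails if the source build broke).
import causal_conv1d
import causal_conv1d_cuda

print(causal_conv1d.__version__)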
@@ -288,15 +288,16 @@ def forward(self, x, layer_past=None, inference_context=None, _hyena_use_cp=True
         Returns:
             Tuple of (output tensor, bias)
         """
-        # CP control
-        if _hyena_use_cp:
+        # CP control: disable CP during inference because the inference path
+        # does not split sequences across CP ranks (the full sequence is on each rank).
+        # The AllToAll operations in Hyena operators assume sequence-split input which
+        # only happens during training.
+        if inference_context is not None:
+            _proj_use_cp = False
+        elif _hyena_use_cp:
             cp_group = self.pg_collection.cp
             cp_size = cp_group.size()
-        else:
-            cp_group = None
-            cp_size = 1
-        if cp_group is not None and cp_size > 1:
-            _proj_use_cp = True
+            _proj_use_cp = cp_group is not None and cp_size > 1
         else:
             _proj_use_cp = False
 
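Aside (a sketch, not PR code): the gating rule the hunk above implements, written as a standalone predicate; it assumes only a process-group-like object exposing .size(), as torch.distributed groups do.

def should_use_cp(inference_context, hyena_use_cp, cp_group):
    # CP applies only while training with a CP group larger than one rank;
    # inference keeps the full sequence on every rank, so CP is always off there.
    if inference_context is not None:
        return False
    return hyena_use_cp and cp_group is not None and cp_group.size() > 1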
@@ -984,8 +984,8 @@ def get_filter_state(filter_name)
             key = f"{filter_name}_filter_state_dict"
             return getattr(inference_context, key, {}).get(id(self))
 
-        # x1, x2, v all of shape torch.Size([1, 4096, 63])
-        u = torch.cat([x2, x1, v], dim=1)  # torch.Size([1, 12288, 63])
+        # x1, x2, v all of shape [B, width_per_tp_group, L]
+        u = torch.cat([x2, x1, v], dim=1)  # [B, 3 * width_per_tp_group, L]
         L = u.shape[-1]  # noqa: N806
-        poles = rearrange(self.filter.p, "d n -> d n 1")  # n = 16
+        poles = self.filter.get_logp()
@@ -994,13 +994,13 @@ def get_filter_state(filter_name)
         iir_state = get_filter_state("iir")
         if iir_state is None:
             y, iir_state = engine.parallel_iir(
-                z_pre=u,  # [1 d l]
-                h=h,  # must be in [1 d l]
-                D=bias,  # self.short_filter_bias,
+                z_pre=u,  # [B, 3 * width_per_tp_group, L]
+                h=h,  # [width_per_tp_group, L]
+                D=bias,  # [width_per_tp_group]
                 L=L,
                 poles=poles,
-                t=self.filter.get_t(L),  # torch.Size([1, 1, L])
-                hidden_size=self.hidden_size,
+                t=self.filter.get_t(L),  # [1, 1, L]
+                hidden_size=self.width_per_tp_group,
                 compute_state=inference_context is not None,
             )
             # y = rearrange(y, "b d l -> b l d")
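Aside (illustrative only, example sizes): the concatenation stacks the three projections along the channel dimension, so parallel_iir sees three times the per-TP-rank width, which is why hidden_size is now passed as width_per_tp_group.

import torch

B, W, L = 1, 4096, 63  # batch, width_per_tp_group, sequence length
x1, x2, v = (torch.randn(B, W, L) for _ in range(3))
u = torch.cat([x2, x1, v], dim=1)
assert u.shape == (B, 3 * W, L)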
@@ -1093,6 +1093,18 @@ def forward(self, x1, x2, v, _hyena_use_cp=True, inference_context=None):
             cp_group = None
             cp_size = 1
 
+        # When CP is disabled (e.g., during inference), ensure the ImplicitModalFilter
+        # also uses full (non-CP-sharded) parameters. The filter stores _cp_size at init
+        # and slices its parameters (R, p, gamma) by CP rank in compute_filter/get_logp.
+        # During inference the full parameter set is needed since data is not CP-split.
+        _filter_cp_override = False
+        if cp_group is None and isinstance(self.filter, ImplicitModalFilter) and self.filter._cp_size > 1:
+            self._saved_filter_cp_size = self.filter._cp_size
+            self._saved_filter_cp_rank = self.filter._cp_rank
+            self.filter._cp_size = 1
+            self.filter._cp_rank = 0
+            _filter_cp_override = True
+
         # The kernel length must be adjusted in CP settings
         _L_kernel = L if cp_group is None else L * cp_size  # noqa: N806
         if self.use_medium_hyena:
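Aside (a sketch, not PR code): the same save/override/restore pattern expressed as a context manager; the attribute names mirror the hunk above, and the filter object is assumed.

from contextlib import contextmanager

@contextmanager
def filter_cp_disabled(filt):
    # Temporarily present the filter's full (non-CP-sharded) parameters and
    # restore the original shard view even if the wrapped body raises.
    saved = (filt._cp_size, filt._cp_rank)
    filt._cp_size, filt._cp_rank = 1, 0
    try:
        yield filt
    finally:
        filt._cp_size, filt._cp_rank = saved

This is also why the forward body in the next hunk is wrapped in try/finally: the override must be undone even when generation raises mid-forward.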
@@ -1133,30 +1145,47 @@ def forward(self, x1, x2, v, _hyena_use_cp=True, inference_context=None):

         h = h.repeat_interleave(self.group_dim, dim=-2)
 
-        if inference_context is not None:  # Needs full length x1 x2 v
-            if self.operator_type == "hyena_medium_conv":
-                return self.forward_medium(x1=x1, x2=x2, v=v, h=h, bias=conv_bias, inference_context=inference_context)
-            elif self.operator_type == "hyena":
-                return self.forward_long(x1=x1, x2=x2, v=v, h=h, bias=conv_bias, inference_context=inference_context)
-        else:  # Needs full length z (post gating)
-            # with torch.autocast("cuda"):
-            z = fftconv_func(
-                u=z.to(torch.float32),
-                k=h.to(torch.float32),
-                D=conv_bias.to(torch.float32),
-                dropout_mask=None,
-                gelu=False,
-                bidirectional=self.bidirectional,
-                use_subquadratic_ops=self.use_subquadratic_ops,
-            )
-            z = z.to(v.dtype)
-
-            if cp_group is not None and cp_size > 1:
-                z = AllToAllSingleFunction.apply(z, cp_group, "full_to_split", True)
-                # [ B, H, L // num_ranks]
-
-            z = x1 * z
-        return z  # [B, (G, DG), L]
+        try:
+            if inference_context is not None:  # Needs full length x1 x2 v
+                if self.operator_type == "hyena_medium_conv":
+                    z = self.forward_medium(
+                        x1=x1, x2=x2, v=v, h=h, bias=conv_bias, inference_context=inference_context
+                    )
+                elif self.operator_type == "hyena":
+                    z = self.forward_long(x1=x1, x2=x2, v=v, h=h, bias=conv_bias, inference_context=inference_context)
+                else:
+                    raise ValueError(f"Unsupported operator_type for inference: {self.operator_type}")
+
+                # Reverse AllToAll: convert from channel-split back to sequence-split
+                # [B, H/cp_size, L_full] -> [B, H, L_local]
+                if cp_group is not None and cp_size > 1:
+                    z = AllToAllSingleFunction.apply(z, cp_group, "full_to_split", True)
+
+                return z
+            else:  # Needs full length z (post gating)
+                # with torch.autocast("cuda"):
+                z = fftconv_func(
+                    u=z.to(torch.float32),
+                    k=h.to(torch.float32),
+                    D=conv_bias.to(torch.float32),
+                    dropout_mask=None,
+                    gelu=False,
+                    bidirectional=self.bidirectional,
+                    use_subquadratic_ops=self.use_subquadratic_ops,
+                )
+                z = z.to(v.dtype)
+
+                if cp_group is not None and cp_size > 1:
+                    z = AllToAllSingleFunction.apply(z, cp_group, "full_to_split", True)
+                    # [ B, H, L // num_ranks]
+
+                z = x1 * z
+                return z  # [B, (G, DG), L]
+        finally:
+            # Restore filter CP settings if they were overridden
+            if _filter_cp_override:
+                self.filter._cp_size = self._saved_filter_cp_size
+                self.filter._cp_rank = self._saved_filter_cp_rank
 
     def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
         """Sharded state dictionary for the ParallelHyenaOperator."""
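Aside (shape-only sketch, single process): what the "full_to_split" AllToAll accomplishes across cp_size ranks, per the comment in the hunk: each rank trades its channel shard of the full sequence for all channels of its own sequence slice. The peer/chunk ordering here is illustrative.

import torch

B, H, L, cp = 2, 8, 16, 4
# Per-rank inputs: each of the cp ranks holds H/cp channels over the full length L.
per_rank = [torch.randn(B, H // cp, L) for _ in range(cp)]
# After the exchange, rank r holds all H channels but only its L/cp sequence slice.
rank = 0
out = torch.cat([t.chunk(cp, dim=-1)[rank] for t in per_rank], dim=1)
assert out.shape == (B, H, L // cp)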
@@ -277,7 +277,9 @@ def setup_inference_engine(
     model_provider.tensor_model_parallel_size = tensor_parallel_size
     model_provider.pipeline_model_parallel_size = pipeline_model_parallel_size
     model_provider.context_parallel_size = context_parallel_size
-    model_provider.sequence_parallel = tensor_parallel_size > 1
+    # Disable sequence parallelism for inference - Megatron's inference engine
+    # does not support it for non-MoE models.
+    model_provider.sequence_parallel = False
 
     # Enable flash decode for inference
     model_provider.flash_decode = True

@@ -489,6 +491,12 @@ def parse_args() -> argparse.Namespace:
         default=default_prompt,
         help="Prompt text for generation",
     )
+    ap.add_argument(
+        "--prompt-file",
+        type=Path,
+        default=None,
+        help="Read prompt from a text file (overrides --prompt). Useful for long prompts that exceed shell argument limits.",
+    )
     ap.add_argument("--max-new-tokens", type=int, default=100, help="Maximum tokens to generate")
     ap.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature")
     ap.add_argument("--top-k", type=int, default=0, help="Top-k sampling (0 = disabled)")

@@ -607,8 +615,15 @@ def infer(
 def main() -> None:
     """CLI entry point for Evo2 text generation."""
     args = parse_args()
+
+    # Read prompt from file if specified (overrides --prompt)
+    prompt = args.prompt
+    if args.prompt_file is not None:
+        with open(args.prompt_file) as f:
+            prompt = f.read().strip()
 
     infer(
-        prompt=args.prompt,
+        prompt=prompt,
         ckpt_dir=args.ckpt_dir,
         max_new_tokens=args.max_new_tokens,
         temperature=args.temperature,
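Aside (a sketch, not PR code): the prompt-source precedence that main() now implements, with a hypothetical helper name.

from pathlib import Path

def resolve_prompt(prompt_arg: str, prompt_file: Path | None) -> str:
    # --prompt-file wins over --prompt when both are given.
    if prompt_file is not None:
        return prompt_file.read_text().strip()
    return prompt_arg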