Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions configs/agents/rl/push_cube/gym_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
"id": "PushCubeRL",
"max_episodes": 5,
"env": {
"num_envs": 128,
"sim_steps_per_control": 4,
"events": {
"randomize_cube": {
"func": "randomize_rigid_object_pose",
Expand Down
3 changes: 1 addition & 2 deletions configs/gym/pour_water/gym_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,7 @@
"task_description": "Pour water",
"data_type": "sim"
},
"use_videos": true,
"export_success_only": false
"use_videos": true
}
}
}
Expand Down
3 changes: 1 addition & 2 deletions configs/gym/pour_water/gym_config_simple.json
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,7 @@
"task_description": "Pour water",
"data_type": "sim"
},
"use_videos": true,
"export_success_only": false
"use_videos": true
}
}
}
Expand Down
3 changes: 1 addition & 2 deletions configs/gym/special/simple_task_ur10.json
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@
"task_description": "Oscillatory motion",
"data_type": "sim"
},
"use_videos": false,
"export_success_only": false
"use_videos": false
}
}
}
Expand Down
1 change: 0 additions & 1 deletion docs/source/overview/gym/env.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ The manager operates in a single mode ``"save"`` which handles both recording an
* ``robot_meta``: Robot metadata dictionary (required for LeRobot format).
* ``instruction``: Task instruction dictionary.
* ``use_videos``: Whether to save video recordings of episodes.
* ``export_success_only``: Filter to save only successful episodes (based on ``info["success"]``).

The dataset manager is called automatically during {meth}`~envs.Env.step()`, ensuring all observation-action pairs are recorded without additional user code.

Expand Down
82 changes: 49 additions & 33 deletions embodichain/lab/gym/envs/embodied_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
# ----------------------------------------------------------------------------

from math import log
import os
import torch
import numpy as np
Expand Down Expand Up @@ -176,6 +177,9 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs):
self.episode_action_buffer: Dict[int, List[EnvAction]] = {
i: [] for i in range(self.num_envs)
}
self.episode_success_status: Dict[int, bool] = {
i: False for i in range(self.num_envs)
}

def _init_sim_state(self, **kwargs):
"""Initialize the simulation state at the beginning of scene creation."""
Expand Down Expand Up @@ -317,18 +321,11 @@ def _hook_after_sim_step(
self.episode_obs_buffer[env_id].append(single_obs)
self.episode_action_buffer[env_id].append(single_action)

# Call dataset manager with mode="save": it will record and auto-save if dones=True
if self.cfg.dataset:
if "save" in self.dataset_manager.available_modes:
self.dataset_manager.apply(
mode="save",
env_ids=None,
obs=obs,
action=action,
dones=dones,
terminateds=terminateds,
info=info,
)
# Update success status if episode is done
if dones[env_id].item():
if "success" in info:
success_value = info["success"]
self.episode_success_status[env_id] = success_value[env_id].item()

def _extend_obs(self, obs: EnvObs, **kwargs) -> EnvObs:
if self.observation_manager:
Expand Down Expand Up @@ -368,14 +365,49 @@ def _update_sim_state(self, **kwargs) -> None:
def _initialize_episode(
self, env_ids: Sequence[int] | None = None, **kwargs
) -> None:
# Clear episode buffers for environments that are being reset
save_data = kwargs.get("save_data", True)

# Determine which environments to process
if env_ids is None:
env_ids = range(self.num_envs)
for env_id in env_ids:
# Convert to int if it's a tensor
env_id = int(env_id) if isinstance(env_id, torch.Tensor) else env_id
env_ids_to_process = list(range(self.num_envs))
elif isinstance(env_ids, torch.Tensor):
env_ids_to_process = env_ids.cpu().tolist()
else:
env_ids_to_process = list(env_ids)

# Save dataset before clearing buffers for environments that are being reset
if save_data and self.cfg.dataset:
if "save" in self.dataset_manager.available_modes:

current_task_success = self.is_task_success()

# Filter to only save successful episodes
successful_env_ids = [
env_id
for env_id in env_ids_to_process
if (
self.episode_success_status.get(env_id, False)
or current_task_success[env_id].item()
)
]

if successful_env_ids:
# Convert back to tensor if needed
successful_env_ids_tensor = torch.tensor(
successful_env_ids, device=self.device
)
self.dataset_manager.apply(
mode="save",
env_ids=successful_env_ids_tensor,
)
else:
logger.log_warning("No successful episodes to save.")

# Clear episode buffers and reset success status for environments being reset
for env_id in env_ids_to_process:
self.episode_obs_buffer[env_id].clear()
self.episode_action_buffer[env_id].clear()
self.episode_success_status[env_id] = False

# apply events such as randomization for environments that need a reset
if self.cfg.events:
Expand Down Expand Up @@ -571,22 +603,6 @@ def is_task_success(self, **kwargs) -> torch.Tensor:

return torch.ones(self.num_envs, dtype=torch.bool, device=self.device)

def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> torch.Tensor:
"""Check if the episode is truncated.

Args:
obs: The observation from the environment.
info: The info dictionary.

Returns:
A boolean tensor indicating truncation for each environment in the batch.
"""
# Check if action sequence has reached its end
if self.action_length > 0:
return self._elapsed_steps >= self.action_length

return super().check_truncated(obs, info)

def close(self) -> None:
"""Close the environment and release resources."""
# Finalize dataset if present
Expand Down
22 changes: 3 additions & 19 deletions embodichain/lab/gym/envs/managers/dataset_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ class DatasetManager(ManagerBase):
>>> "robot_meta": {...},
>>> "instruction": {"lang": "pick and place"},
>>> "extra": {"scene_type": "kitchen"},
>>> "save_path": "/data/datasets",
>>> "export_success_only": True,
>>> "save_path": "/data/datasets"
>>> }
>>> )
>>> }
Expand Down Expand Up @@ -177,26 +176,16 @@ def apply(
self,
mode: str,
env_ids: Union[Sequence[int], torch.Tensor, None] = None,
obs: Optional[EnvObs] = None,
action: Optional[EnvAction] = None,
dones: Optional[torch.Tensor] = None,
terminateds: Optional[torch.Tensor] = None,
info: Optional[Dict[str, Any]] = None,
) -> None:
"""Apply dataset functors for the specified mode.

This method follows the same pattern as EventManager.apply() for consistency.
Currently only supports mode="save" which handles both recording and auto-saving.
This method saves completed episodes by reading data from the environment's
episode buffers. It should be called before clearing the buffers during reset.

Args:
mode: The mode to apply (currently only "save" is supported).
env_ids: The indices of the environments to apply the functor to.
Defaults to None, in which case the functor is applied to all environments.
obs: Observation from the environment (batched for all envs).
action: Action applied to the environment (batched for all envs).
dones: Boolean tensor indicating which envs completed episodes.
terminateds: Boolean tensor indicating termination (success/fail).
info: Info dict containing success/fail information.
"""
# check if mode is valid
if mode not in self._mode_functor_names:
Expand All @@ -210,11 +199,6 @@ def apply(
functor_cfg.func(
self._env,
env_ids,
obs,
action,
dones,
terminateds,
info,
**functor_cfg.params,
)

Expand Down
58 changes: 12 additions & 46 deletions embodichain/lab/gym/envs/managers/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def __init__(self, cfg: DatasetFunctorCfg, env: EmbodiedEnv):
- use_videos: Whether to save videos
- image_writer_threads: Number of threads for image writing
- image_writer_processes: Number of processes for image writing
- export_success_only: Whether to export only successful episodes
env: The environment instance
"""
if not LEROBOT_AVAILABLE:
Expand All @@ -90,7 +89,6 @@ def __init__(self, cfg: DatasetFunctorCfg, env: EmbodiedEnv):
self.instruction = params.get("instruction", None)
self.extra = params.get("extra", {})
self.use_videos = params.get("use_videos", False)
self.export_success_only = params.get("export_success_only", False)

# LeRobot dataset instance
self.dataset: Optional[LeRobotDataset] = None
Expand All @@ -114,52 +112,34 @@ def __call__(
self,
env: EmbodiedEnv,
env_ids: Union[torch.Tensor, None],
obs: EnvObs,
action: EnvAction,
dones: torch.Tensor,
terminateds: torch.Tensor,
info: Dict[str, Any],
save_path: Optional[str] = None,
id: Optional[str] = None,
robot_meta: Optional[Dict] = None,
instruction: Optional[str] = None,
extra: Optional[Dict] = None,
use_videos: bool = False,
export_success_only: bool = False,
) -> None:
"""Main entry point for the recorder functor.

This method is called by DatasetManager.apply(mode="save") with runtime arguments
as positional parameters and configuration parameters from cfg.params.
This method is called by DatasetManager.apply(mode="save") to save completed episodes.
It reads data from the environment's episode buffers.

Args:
env: The environment instance.
env_ids: Environment IDs (for consistency with EventManager pattern).
obs: Observation from the environment.
action: Action applied to the environment.
dones: Boolean tensor indicating which envs completed episodes.
terminateds: Termination flags (success/fail).
info: Info dict containing success/fail information.
save_path: Root directory (already set in __init__).
id: Dataset identifier (already set in __init__).
robot_meta: Robot metadata (already set in __init__).
instruction: Task instruction (already set in __init__).
extra: Extra metadata (already set in __init__).
use_videos: Whether to save videos (already set in __init__).
export_success_only: Whether to export only successful episodes (already set in __init__).
env_ids: Environment IDs to save. If None, attempts to save all environments.
"""
# If env_ids is None, check all environments for completed episodes
if env_ids is None:
env_ids = torch.arange(env.num_envs, device=env.device)
elif isinstance(env_ids, (list, range)):
env_ids = torch.tensor(list(env_ids), device=env.device)

# Check if any episodes are done and save them
done_env_ids = dones.nonzero(as_tuple=False).squeeze(-1)
if len(done_env_ids) > 0:
# Save completed episodes
self._save_episodes(done_env_ids, terminateds, info)
# Save episodes for specified environments
if len(env_ids) > 0:
self._save_episodes(env_ids)

def _save_episodes(
self,
env_ids: torch.Tensor,
terminateds: Optional[torch.Tensor] = None,
info: Optional[Dict[str, Any]] = None,
) -> None:
"""Save completed episodes for specified environments."""
task = self.instruction.get("lang", "unknown_task")
Expand Down Expand Up @@ -187,19 +167,6 @@ def _save_episodes(
self.total_time += current_episode_time
episode_extra_info["total_time"] = self.total_time
self._update_dataset_info({"extra": episode_extra_info})
is_success = False
if info is not None and "success" in info:
success_tensor = info["success"]
if isinstance(success_tensor, torch.Tensor):
is_success = success_tensor[env_id].item()
else:
is_success = success_tensor
elif terminateds is not None:
is_success = terminateds[env_id].item()

if self.export_success_only and not is_success:
logger.log_info(f"Skipping failed episode for env {env_id}")
continue

try:
for obs, action in zip(obs_list, action_list):
Expand All @@ -210,8 +177,7 @@ def _save_episodes(

logger.log_info(
f"[LeRobotRecorder] Saved dataset to: {self.dataset_path}\n"
f" Episode {self.curr_episode} (env {env_id}): "
f"{'successful' if is_success else 'failed'}, {len(obs_list)} frames"
f" Episode {self.curr_episode} (env {env_id}): {len(obs_list)} frames"
)

self.curr_episode += 1
Expand Down
22 changes: 4 additions & 18 deletions embodichain/lab/scripts/run_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,26 +89,12 @@ def generate_function(
valid = generate_and_execute_action_list(env, trajectory_idx, debug_mode)

if not valid:
_, _ = env.reset()
# Failed execution: reset without saving invalid data
_, _ = env.reset(options={"save_data": False})
break

# Check task success for all environments
if not debug_mode:
success = env.get_wrapper_attr("is_task_success")()
# For multiple environments, check if all succeeded
all_success = (
success.all().item() if success.numel() > 1 else success.item()
)
if all_success:
pass
# TODO: Add data saving and online data streaming logic here.
else:
log_warning(f"Task fail, Skip to next generation.")
valid = False
break
else:
# In debug mode, skip success check
pass
# Successful execution: reset and save data
_, _ = env.reset()

if valid:
break
Expand Down
Loading