diff --git a/embodichain/lab/sim/objects/articulation.py b/embodichain/lab/sim/objects/articulation.py index c23e187b..493f26c1 100644 --- a/embodichain/lab/sim/objects/articulation.py +++ b/embodichain/lab/sim/objects/articulation.py @@ -128,6 +128,9 @@ def __init__( if self.device.type == "cuda" else self.dof ) + self._target_qpos = torch.zeros( + (self.num_instances, max_dof), dtype=torch.float32, device=self.device + ) self._qpos = torch.zeros( (self.num_instances, max_dof), dtype=torch.float32, device=self.device ) @@ -248,6 +251,34 @@ def qpos(self) -> torch.Tensor: ) return self._qpos[:, : self.dof].clone() + @property + def target_qpos(self) -> torch.Tensor: + """Get the target positions (target_qpos) of the articulation. + + Returns: + torch.Tensor: The target positions of the articulation with shape of (num_instances, dof). + """ + if self.device.type == "cpu": + # Fetch target_qpos from CPU entities + return torch.as_tensor( + # TODO: cpu get joint target position + np.array( + [ + entity.get_current_qpos(is_target=True) + for entity in self.entities + ], + ), + dtype=torch.float32, + device=self.device, + ) + else: + self.ps.gpu_fetch_joint_data( + data=self._target_qpos, + gpu_indices=self.gpu_indices, + data_type=ArticulationGPUAPIReadType.JOINT_TARGET_POSITION, + ) + return self._target_qpos[:, : self.dof].clone() + @property def qvel(self) -> torch.Tensor: """Get the current velocities (qvel) of the articulation. @@ -976,10 +1007,10 @@ def set_qpos( else: # TODO: trigger qpos getter to sync data, otherwise crash if joint_ids is not None: - self.body_data.qpos + self.body_data.target_qpos indices = self.body_data.gpu_indices[local_env_ids] - qpos_set = self.body_data._qpos[local_env_ids] + qpos_set = self.body_data._target_qpos[local_env_ids] qpos_set[:, local_joint_ids] = qpos self._ps.gpu_apply_joint_data( data=qpos_set, @@ -1189,6 +1220,9 @@ def reallocate_body_data(self) -> None: self._data._qpos = torch.zeros( (self.num_instances, max_dof), dtype=torch.float32, device=self.device ) + self._data._target_qpos = torch.zeros( + (self.num_instances, max_dof), dtype=torch.float32, device=self.device + ) self._data._qvel = torch.zeros( (self.num_instances, max_dof), dtype=torch.float32, device=self.device )