eval
Evaluation helpers for SKRL policies.
Classes:
| Name | Description |
|---|---|
| JITPolicyRunner | TorchScript policy runner. |
| ONNXPolicyRunner | ONNX policy runner. |
Functions:
| Name | Description |
|---|---|
| load_agent | Load a trained agent for evaluation/inference. |
| load_policy | Load a policy as a unified runner supporting JIT or ONNX. |
JITPolicyRunner
JITPolicyRunner(model_path: str, device: str | device)
Bases: _BasePolicyRunner
TorchScript policy runner.
Initialize a TorchScript policy runner.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_path | str | Path to the TorchScript (JIT) model. | required |
| device | str \| device | Target device. | required |
Methods:
| Name | Description |
|---|---|
| act | Compute actions from observations. |
act
act(observations: ndarray | Tensor) -> np.ndarray
Compute actions from observations.
Returns:
| Type | Description |
|---|---|
| ndarray | Action array with shape matching the model output. |
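Here is a minimal sketch of what a TorchScript runner's act() can look like. It assumes PyTorch and NumPy are installed. SketchJITRunner and TinyPolicy are hypothetical names used only for illustration; they are not this module's classes. The sketch scripts a tiny policy, saves it, reloads it with torch.jit.load onto the target device, and converts outputs to NumPy under torch.no_grad().

```python
import tempfile
from pathlib import Path
from typing import Union

import numpy as np
import torch


class TinyPolicy(torch.nn.Module):
    """Hypothetical stand-in policy: one linear layer from observations to actions."""

    def __init__(self, obs_dim: int, act_dim: int) -> None:
        super().__init__()
        self.net = torch.nn.Linear(obs_dim, act_dim)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)


class SketchJITRunner:
    """Hypothetical TorchScript runner sketch (not this module's JITPolicyRunner)."""

    def __init__(self, model_path: str, device: Union[str, torch.device]) -> None:
        self.device = torch.device(device)
        # map_location places the loaded parameters directly on the target device.
        self.model = torch.jit.load(model_path, map_location=self.device)
        self.model.eval()

    @torch.no_grad()
    def act(self, observations: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
        # as_tensor accepts NumPy arrays and tensors alike and avoids copies when it can.
        obs = torch.as_tensor(observations, dtype=torch.float32, device=self.device)
        return self.model(obs).cpu().numpy()


# Usage: script, save, reload, and run a batch of 3 observations of size 4.
with tempfile.TemporaryDirectory() as tmp:
    path = str(Path(tmp) / "policy.pt")
    torch.jit.script(TinyPolicy(obs_dim=4, act_dim=2)).save(path)
    runner = SketchJITRunner(path, "cpu")
    actions = runner.act(np.zeros((3, 4), dtype=np.float32))

print(actions.shape)  # (3, 2)
```

Running inference under no_grad keeps autograd from tracking the forward pass, which is also what allows .numpy() to be called on the output.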
ONNXPolicyRunner
ONNXPolicyRunner(model_path: str, providers: list | None = None)
Bases: _BasePolicyRunner
ONNX policy runner.
Initialize an ONNX policy runner.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_path | str | Path to the ONNX model. | required |
| providers | list \| None | Optional list of ONNX Runtime execution providers. | None |
Raises:
| Type | Description |
|---|---|
| RuntimeError | If the ONNX model has no inputs. |
Methods:
| Name | Description |
|---|---|
| act | Compute actions from observations. |
act
act(observations: ndarray | Tensor) -> np.ndarray
Compute actions from observations.
Returns:
| Type | Description |
|---|---|
| ndarray | Action array with shape matching the model output. |
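The ONNX path can be sketched in the same spirit. This assumes the model is served by ONNX Runtime, whose InferenceSession(path, providers=...) exposes get_inputs() and run(output_names, input_feed). So that the sketch runs without onnxruntime installed, the hypothetical SketchONNXRunner takes an already-built session, and FakeSession is a test double. Three choices here are assumptions, not confirmed behavior of ONNXPolicyRunner: the float32 cast, feeding the first input, and returning the first output.

```python
import numpy as np


class SketchONNXRunner:
    """Hypothetical ONNX runner sketch (not this module's ONNXPolicyRunner).

    `session` is assumed to behave like onnxruntime.InferenceSession: it has
    get_inputs() returning objects with a .name, and run(output_names, feed).
    """

    def __init__(self, session) -> None:
        inputs = session.get_inputs()
        if not inputs:
            # Mirrors the documented RuntimeError for a model without inputs.
            raise RuntimeError("ONNX model has no inputs")
        self.session = session
        self.input_name = inputs[0].name  # feed observations to the first input

    def act(self, observations) -> np.ndarray:
        # torch tensors expose .detach(); convert them without importing torch here.
        if hasattr(observations, "detach"):
            observations = observations.detach().cpu().numpy()
        obs = np.asarray(observations, dtype=np.float32)
        # output_names=None returns every model output; the first is taken as the action.
        outputs = self.session.run(None, {self.input_name: obs})
        return np.asarray(outputs[0])


class _FakeInput:
    name = "obs"


class FakeSession:
    """Hypothetical stand-in that doubles its input, so the sketch runs without onnxruntime."""

    def get_inputs(self):
        return [_FakeInput()]

    def run(self, output_names, input_feed):
        return [input_feed["obs"] * 2.0]


runner = SketchONNXRunner(FakeSession())
actions = runner.act([[1.0, 2.0]])
print(actions)  # [[2. 4.]]
```

With onnxruntime installed, FakeSession() would be replaced by onnxruntime.InferenceSession(model_path, providers=providers).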
load_agent
load_agent(model_cls: type, agent_cfg: Any, checkpoint_path: str, observation_shape: tuple, action_shape: tuple, state_shape: tuple | None = None, device: str | device | None = None) -> tuple
Load a trained agent for evaluation/inference.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_cls | type | Class of the model (policy). | required |
| agent_cfg | Any | Agent configuration object. | required |
| checkpoint_path | str | Path to the checkpoint file. | required |
| observation_shape | tuple | Shape of the observations. | required |
| action_shape | tuple | Shape of the actions. | required |
| state_shape | tuple \| None | Shape of the states, if any. | None |
| device | str \| device \| None | Torch device; auto-detected if None. | None |
Returns:
| Type | Description |
|---|---|
| tuple | A (policy, observation_preprocessor, device) tuple. |
Raises:
| Type | Description |
|---|---|
| FileNotFoundError | If the checkpoint file is not found. |
| ImportError | If gymnasium is not installed. |
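The documented error cases and device auto-detection suggest a preflight step like the one below. preflight_load_agent is a hypothetical helper. The order of the checks and the "CUDA if available, else CPU" rule are assumptions, not the confirmed behavior of load_agent. The helper uses only the standard library and imports torch lazily.

```python
import importlib.util
import os
from typing import Optional


def preflight_load_agent(checkpoint_path: str, device: Optional[str] = None) -> str:
    """Hypothetical checks mirroring load_agent's documented errors.

    Returns the device string that would be used for loading.
    """
    # Documented: FileNotFoundError when the checkpoint is missing.
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    # Documented: ImportError when gymnasium is not installed.
    if importlib.util.find_spec("gymnasium") is None:
        raise ImportError("load_agent requires gymnasium")
    if device is not None:
        return str(device)
    # Assumed auto-detection rule: CUDA when torch sees a GPU, otherwise CPU.
    try:
        import torch
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


# Usage: a missing checkpoint fails fast, before any model is built.
missing = os.path.join("no", "such", "dir", "agent.pt")
try:
    preflight_load_agent(missing)
    raised = False
except FileNotFoundError:
    raised = True
print(raised)  # True
```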
load_policy
load_policy(model_path: str, device: str | device = 'cpu') -> JITPolicyRunner | ONNXPolicyRunner
Load a policy as a unified runner supporting JIT or ONNX.
The returned runner exposes act(observations), which returns actions as a NumPy array.
Observations may be NumPy arrays or torch tensors. A 1-D observation is treated
as a batch of one, and the returned actions are likewise 1-D.
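The 1-D batching rule can be illustrated independently of either backend. act_with_batching and linear_policy are hypothetical names. The sketch shows only the documented contract: NumPy or torch input, NumPy output, and a 1-D observation handled as a batch of one. It is not the library's code.

```python
import numpy as np


def act_with_batching(policy_fn, observations) -> np.ndarray:
    """Hypothetical wrapper illustrating the documented act() input/output contract.

    policy_fn is assumed to map a (batch, obs_dim) float32 array to a
    (batch, act_dim) array.
    """
    # torch tensors expose .detach(); convert them without importing torch here.
    if hasattr(observations, "detach"):
        observations = observations.detach().cpu().numpy()
    obs = np.asarray(observations, dtype=np.float32)
    single = obs.ndim == 1
    if single:
        obs = obs[np.newaxis, :]  # treat one 1-D observation as a batch of one
    actions = np.asarray(policy_fn(obs))
    return actions[0] if single else actions  # drop the batch axis again for 1-D input


WEIGHTS = np.ones((3, 2), dtype=np.float32)


def linear_policy(batch: np.ndarray) -> np.ndarray:
    """Hypothetical policy: a fixed linear map from 3 observation dims to 2 actions."""
    return batch @ WEIGHTS


print(act_with_batching(linear_policy, [1.0, 2.0, 3.0]))  # [6. 6.]
print(act_with_batching(linear_policy, np.zeros((5, 3))).shape)  # (5, 2)
```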
Returns:
| Type | Description |
|---|---|
| JITPolicyRunner \| ONNXPolicyRunner | Loaded policy runner. |
Raises:
| Type | Description |
|---|---|
| ValueError | If |