eval
Evaluation helpers for SKRL policies.
Classes:

| Name | Description |
|---|---|
| `JITPolicyRunner` | TorchScript policy runner. |
| `ONNXPolicyRunner` | ONNX policy runner. |
Functions:

| Name | Description |
|---|---|
| `load_agent` | Load a trained agent for evaluation/inference. |
| `load_policy` | Load a policy as a unified runner supporting JIT or ONNX. |
JITPolicyRunner
JITPolicyRunner(model_path: str, device: str | device)
Bases: _BasePolicyRunner
TorchScript policy runner.
Initialize a TorchScript policy runner.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_path` | `str` | Path to the JIT model. | required |
| `device` | str \| device | Target device. | required |
Methods:

| Name | Description |
|---|---|
| `act` | Compute actions from observations. |
act
act(observations: ndarray | Tensor) -> np.ndarray
Compute actions from observations.
Returns:

| Type | Description |
|---|---|
| `ndarray` | Action array with shape matching the model output. |
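The batching behavior can be sketched with plain numpy. `act_like` and the lambda policy below are hypothetical stand-ins, assuming the runner treats a 1-D observation as a batch of one and strips the batch dimension from the output (as described for `load_policy`):

```python
import numpy as np

def act_like(model_fn, observations: np.ndarray) -> np.ndarray:
    # Hypothetical sketch of the act() batching convention:
    # a 1-D observation is wrapped as a batch of one, and the
    # batch dimension is removed from the returned actions.
    single = observations.ndim == 1
    batch = observations[None, :] if single else observations
    actions = np.asarray(model_fn(batch))
    return actions[0] if single else actions

# Toy stand-in policy: negate the observations.
single_out = act_like(lambda x: -x, np.array([1.0, 2.0]))  # shape (2,)
batch_out = act_like(lambda x: -x, np.ones((4, 2)))        # shape (4, 2)
```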
ONNXPolicyRunner
ONNXPolicyRunner(model_path: str, providers: list | None = None)
Bases: _BasePolicyRunner
ONNX policy runner.
Initialize an ONNX policy runner.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_path` | `str` | Path to the ONNX model. | required |
| `providers` | list \| None | Optional ONNX providers list. | `None` |
Raises:

| Type | Description |
|---|---|
| `RuntimeError` | If the ONNX model has no inputs. |
Methods:

| Name | Description |
|---|---|
| `act` | Compute actions from observations. |
act
act(observations: ndarray | Tensor) -> np.ndarray
Compute actions from observations.
Returns:

| Type | Description |
|---|---|
| `ndarray` | Action array with shape matching the model output. |
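The `RuntimeError` presumably comes from an input-name check at construction time. A minimal sketch, where `_first_input_name` and `_FakeInput` are illustrative assumptions (a real runner would inspect `onnxruntime.InferenceSession.get_inputs()`):

```python
def _first_input_name(model_inputs) -> str:
    # Hypothetical check mirroring the runner's init: a model with no
    # declared inputs cannot be fed observations, so fail loudly.
    if not model_inputs:
        raise RuntimeError("ONNX model has no inputs")
    return model_inputs[0].name

class _FakeInput:
    # Stand-in for onnxruntime's NodeArg, which exposes a .name attribute.
    def __init__(self, name: str):
        self.name = name

name = _first_input_name([_FakeInput("observations")])
try:
    _first_input_name([])
except RuntimeError as exc:
    error_message = str(exc)
```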
load_agent
load_agent(model_cls: type, agent_cfg: Any, checkpoint_path: str, observation_shape: tuple, action_shape: tuple, state_shape: tuple | None = None, device: str | device | None = None) -> tuple
Load a trained agent for evaluation/inference.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_cls` | `type` | Class of the model (policy). | required |
| `agent_cfg` | `Any` | Agent configuration object. | required |
| `checkpoint_path` | `str` | Path to the checkpoint file. | required |
| `observation_shape` | `tuple` | Tuple defining observation dimensions. | required |
| `action_shape` | `tuple` | Tuple defining action dimensions. | required |
| `state_shape` | tuple \| None | Tuple defining state dimensions. | `None` |
| `device` | str \| device \| None | Torch device (auto-detected if `None`). | `None` |
Returns:

| Name | Type | Description |
|---|---|---|
| `tuple` | `tuple` | (policy, observation_preprocessor, device). |
Raises:

| Type | Description |
|---|---|
| `FileNotFoundError` | If the checkpoint file is not found. |
| `ImportError` | If gymnasium is not installed. |
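The device auto-detection can be sketched as follows; `pick_device` is a hypothetical helper, assuming `load_agent` honors an explicit device argument and otherwise falls back to CUDA when available:

```python
def pick_device(device=None) -> str:
    # Hypothetical auto-detection mirroring load_agent's device parameter:
    # an explicit choice wins; otherwise prefer CUDA when torch reports it.
    if device is not None:
        return str(device)
    try:
        import torch
        return "cuda" if torch.cuda.is_available() else "cpu"
    except ImportError:
        # Without torch installed, fall back to the CPU.
        return "cpu"
```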
load_policy
load_policy(model_path: str, device: str | device = 'cpu') -> JITPolicyRunner | ONNXPolicyRunner
Load a policy as a unified runner supporting JIT or ONNX.
The returned object exposes act(observations) and returns numpy actions.
Observations may be numpy arrays or torch tensors. A 1-D observation is treated
as a single batch and the returned actions are 1-D accordingly.
Returns:

| Type | Description |
|---|---|
| JITPolicyRunner \| ONNXPolicyRunner | Loaded policy runner. |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If the model format is not supported. |
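Dispatch by file extension can be sketched as below; `select_runner_kind` and the exact suffix set are assumptions for illustration, not the library's actual rule:

```python
from pathlib import Path

def select_runner_kind(model_path: str) -> str:
    # Hypothetical dispatch mirroring load_policy: TorchScript files go to
    # JITPolicyRunner, .onnx files to ONNXPolicyRunner, anything else errors.
    suffix = Path(model_path).suffix.lower()
    if suffix in {".pt", ".jit"}:
        return "jit"
    if suffix == ".onnx":
        return "onnx"
    raise ValueError(f"Unsupported model format: {suffix!r}")
```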