loader
Policy loading and inference runners for SKRL evaluation.
Classes:
| Name | Description |
|---|---|
| BasePolicyRunner | Base policy runner interface. |
| JITPolicyRunner | TorchScript policy runner. |
| ONNXPolicyRunner | ONNX policy runner. |
| SKRLPolicyRunner | Adapter exposing a unified act(observations) interface for SKRL. |
Functions:
| Name | Description |
|---|---|
| load_agent | Load a trained PyTorch agent for evaluation/inference. |
| load_policy | Load a policy as a unified runner supporting JIT or ONNX. |
BasePolicyRunner
BasePolicyRunner(action_dim: int | None = None)
Bases: ABC
Base policy runner interface.
Initialize the base runner.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| action_dim | int \| None | Optional action dimension hint. | None |
Methods:
| Name | Description |
|---|---|
| act | Return action predictions for the given observations. |
act
abstractmethod
act(observations: ndarray | Tensor) -> np.ndarray
Return action predictions for the given observations.
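To illustrate the contract that act() is expected to satisfy, the following minimal sketch defines a local stand-in for the base class plus a hypothetical ZeroPolicyRunner that always returns zero actions. Both classes are illustrative assumptions, not part of this module; the stand-in only mirrors the documented constructor and method signatures. Torch tensors are detected by duck typing (a detach() method), so the sketch runs without PyTorch installed.

```python
from abc import ABC, abstractmethod
from typing import Optional

import numpy as np


class BasePolicyRunnerSketch(ABC):
    """Local stand-in mirroring the documented interface; not the real base class."""

    def __init__(self, action_dim: Optional[int] = None) -> None:
        self.action_dim = action_dim  # optional hint, as in the documented signature

    @abstractmethod
    def act(self, observations) -> np.ndarray:
        """Return action predictions for the given observations."""


class ZeroPolicyRunner(BasePolicyRunnerSketch):
    """Hypothetical runner that always returns zero actions."""

    def __init__(self, action_dim: int) -> None:
        super().__init__(action_dim=action_dim)

    def act(self, observations) -> np.ndarray:
        # Torch tensors expose .detach(); copy them to host memory before NumPy conversion.
        if hasattr(observations, "detach"):
            observations = observations.detach().cpu().numpy()
        obs = np.asarray(observations, dtype=np.float32)
        # Keep every leading (batch) axis and replace the last one with action_dim,
        # so (obs_dim,) -> (action_dim,) and (batch, obs_dim) -> (batch, action_dim).
        return np.zeros(obs.shape[:-1] + (self.action_dim,), dtype=np.float32)
```

Calling ZeroPolicyRunner(action_dim=2).act(np.ones(4)) returns a zero array of shape (2,), and a (3, 4) batch yields shape (3, 2). Because act is abstract, instantiating the stand-in base class directly raises TypeError.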
JITPolicyRunner
JITPolicyRunner(model_path: str, device: str | device)
Bases: BasePolicyRunner
TorchScript policy runner.
Initialize a TorchScript policy runner.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_path | str | Path to the TorchScript (JIT) model. | required |
| device | str \| device | Target device for the model. | required |
Methods:
| Name | Description |
|---|---|
| act | Compute actions from observations. |
act
act(observations: ndarray | Tensor) -> np.ndarray
Compute actions from observations.
Returns:
| Type | Description |
|---|---|
| ndarray | Action array with shape matching the model output. |
ONNXPolicyRunner
ONNXPolicyRunner(model_path: str, providers: list | None = None)
Bases: BasePolicyRunner
ONNX policy runner.
Initialize an ONNX policy runner.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_path | str | Path to the ONNX model. | required |
| providers | list \| None | Optional list of ONNX Runtime execution providers. | None |
Raises:
| Type | Description |
|---|---|
| ImportError | If onnxruntime is not installed. |
| RuntimeError | If the ONNX model has no inputs. |
Methods:
| Name | Description |
|---|---|
| act | Compute actions from observations. |
act
act(observations: ndarray | Tensor) -> np.ndarray
Compute actions from observations.
Returns:
| Type | Description |
|---|---|
| ndarray | Action array with shape matching the model output. |
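A rough sketch of how an ONNX runner with the behaviour documented above might be assembled on top of onnxruntime's InferenceSession. The names OnnxRunnerSketch, require_onnxruntime and first_input_name are hypothetical, and the sketch assumes the model's first input takes float32 observations and its first output holds the actions; the real ONNXPolicyRunner may differ.

```python
import numpy as np


def require_onnxruntime():
    """Import onnxruntime lazily so the dependency stays optional."""
    try:
        import onnxruntime
    except ImportError as exc:
        raise ImportError("onnxruntime is required to run ONNX policies") from exc
    return onnxruntime


def first_input_name(session):
    """Return the name of the model's first input, failing if the graph has none."""
    inputs = session.get_inputs()
    if not inputs:
        raise RuntimeError("ONNX model has no inputs")
    return inputs[0].name


class OnnxRunnerSketch:
    """Hypothetical ONNX runner; not the documented ONNXPolicyRunner implementation."""

    def __init__(self, model_path, providers=None):
        ort = require_onnxruntime()
        # Fall back to every provider this onnxruntime build offers, in its priority order.
        self.session = ort.InferenceSession(
            model_path, providers=providers or ort.get_available_providers()
        )
        self.input_name = first_input_name(self.session)

    def act(self, observations):
        if hasattr(observations, "detach"):  # torch.Tensor -> host NumPy array
            observations = observations.detach().cpu().numpy()
        obs = np.asarray(observations, dtype=np.float32)
        single = obs.ndim == 1
        if single:
            obs = obs[np.newaxis, :]  # (obs_dim,) -> (1, obs_dim)
        # run(None, feeds) returns one array per graph output; keep the first.
        actions = np.asarray(self.session.run(None, {self.input_name: obs})[0])
        return actions[0] if single else actions
```

Importing onnxruntime inside the constructor keeps it optional: the module can be imported without it, and the ImportError only surfaces when an ONNX policy is actually requested. When providers is None the sketch passes the full list of available providers, since GPU-enabled onnxruntime builds can reject a session created without an explicit providers list.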
SKRLPolicyRunner
SKRLPolicyRunner(policy: Any, observation_preprocessor: Any, device: str | device)
Bases: BasePolicyRunner
Adapter exposing a unified act(observations) interface for SKRL.
Initialize an SKRL policy runner.
Methods:
| Name | Description |
|---|---|
| act | Run policy inference and return numpy actions. |
act
act(observations: ndarray | Tensor) -> np.ndarray
Run policy inference and return numpy actions.
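The adapter idea can be sketched with plain callables standing in for the SKRL policy and observation preprocessor. CallableAdapterSketch is hypothetical: real SKRL models and preprocessors have their own calling conventions, which this sketch deliberately does not reproduce. It only shows the order of operations: convert the input, preprocess it, run the policy, and return a NumPy array.

```python
import numpy as np


class CallableAdapterSketch:
    """Hypothetical adapter: preprocess observations, run a policy, return NumPy actions.

    policy and observation_preprocessor are plain callables here, not SKRL objects.
    """

    def __init__(self, policy, observation_preprocessor=None):
        self.policy = policy
        self.observation_preprocessor = observation_preprocessor

    def act(self, observations):
        if hasattr(observations, "detach"):  # torch.Tensor -> host NumPy array
            observations = observations.detach().cpu().numpy()
        obs = np.asarray(observations, dtype=np.float32)
        # Apply the preprocessor (e.g. a normalizer fitted during training) first.
        if self.observation_preprocessor is not None:
            obs = self.observation_preprocessor(obs)
        return np.asarray(self.policy(obs), dtype=np.float32)
```

With policy=lambda o: o[:, :2] * 10.0 and observation_preprocessor=lambda o: o - 1.0, act([[1.0, 2.0, 3.0]]) returns [[0.0, 10.0]], which shows that preprocessing runs before the policy.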
load_agent
load_agent(model_cls: type, agent_cfg: Any, checkpoint_path: str, observation_shape: tuple, action_shape: tuple, state_shape: tuple | None = None, device: str | device | None = None) -> tuple
Load a trained PyTorch agent for evaluation/inference.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_cls | type | Class of the model (policy). | required |
| agent_cfg | Any | Agent configuration object. | required |
| checkpoint_path | str | Path to the checkpoint file. | required |
| observation_shape | tuple | Shape of the observations. | required |
| action_shape | tuple | Shape of the actions. | required |
| state_shape | tuple \| None | Shape of the states. | None |
| device | str \| device \| None | Torch device (auto-detected if None). | None |
Returns:
| Type | Description |
|---|---|
| tuple | A (policy, observation_preprocessor, device) tuple. |
Raises:
| Type | Description |
|---|---|
| FileNotFoundError | If the checkpoint file is not found. |
| ImportError | If gymnasium is not installed. |
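Two of the documented behaviours, device auto-detection and the FileNotFoundError for a missing checkpoint, might look like the following. Both helpers (resolve_device, check_checkpoint) are hypothetical, and the "prefer CUDA when available, else CPU" rule is an assumption; the reference above only says the device is auto-detected when None.

```python
from pathlib import Path


def resolve_device(device=None):
    """Return device unchanged, or pick one when it is None (assumed rule: CUDA first)."""
    if device is not None:
        return device
    try:
        import torch  # imported lazily so this helper also works without PyTorch
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


def check_checkpoint(checkpoint_path):
    """Fail early with FileNotFoundError if the checkpoint file does not exist."""
    path = Path(checkpoint_path)
    if not path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    return path
```

Checking the path before any model is built gives a clear error message instead of a failure deep inside checkpoint deserialization.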
load_policy
load_policy(model_path: str, device: str | device = 'cpu') -> BasePolicyRunner
Load a policy as a unified runner supporting JIT or ONNX.
The returned object exposes act(observations), which returns actions as a NumPy array.
Observations may be NumPy arrays or torch tensors. A 1-D observation is treated
as a batch of one, and the returned actions are correspondingly 1-D.
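A minimal sketch of the 1-D handling described above, assuming it amounts to adding a leading batch axis before inference and removing it afterwards. The helper names to_batched_numpy and restore_batch are hypothetical.

```python
import numpy as np


def to_batched_numpy(observations):
    """Convert observations to a 2-D float32 array and report whether the input was 1-D."""
    # Torch tensors (possibly on GPU or requiring grad) are detached and copied to host first.
    if hasattr(observations, "detach"):
        observations = observations.detach().cpu().numpy()
    obs = np.asarray(observations, dtype=np.float32)
    was_single = obs.ndim == 1
    if was_single:
        obs = obs[np.newaxis, :]  # (obs_dim,) -> (1, obs_dim)
    return obs, was_single


def restore_batch(actions, was_single):
    """Drop the leading batch axis again when the caller passed a single observation."""
    actions = np.asarray(actions)
    return actions[0] if was_single else actions
```

to_batched_numpy([0.1, 0.2, 0.3]) returns a (1, 3) array together with True, and restore_batch then turns a (1, action_dim) result back into shape (action_dim,).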
Returns:
| Type | Description |
|---|---|
| BasePolicyRunner | Loaded policy runner. |
Raises:
| Type | Description |
|---|---|
| ValueError | If |