eval

Evaluation helpers for SKRL policies.

Classes:

JITPolicyRunner: TorchScript policy runner.

ONNXPolicyRunner: ONNX policy runner.

Functions:

load_agent: Load a trained agent for evaluation/inference.

load_policy: Load a policy as a unified runner supporting JIT or ONNX.

JITPolicyRunner

JITPolicyRunner(model_path: str, device: str | device)

Bases: _BasePolicyRunner

TorchScript policy runner.

Initialize a TorchScript policy runner.

Parameters:

model_path (str, required): Path to the JIT model.

device (str | device, required): Target device.

Methods:

act: Compute actions from observations.

act

act(observations: ndarray | Tensor) -> np.ndarray

Compute actions from observations.

Returns:

ndarray: Action array with shape matching the model output.

ONNXPolicyRunner

ONNXPolicyRunner(model_path: str, providers: list | None = None)

Bases: _BasePolicyRunner

ONNX policy runner.

Initialize an ONNX policy runner.

Parameters:

model_path (str, required): Path to the ONNX model.

providers (list | None, default: None): Optional list of ONNX execution providers.

Raises:

RuntimeError: If the ONNX model has no inputs.

Methods:

act: Compute actions from observations.

act

act(observations: ndarray | Tensor) -> np.ndarray

Compute actions from observations.

Returns:

ndarray: Action array with shape matching the model output.

load_agent

load_agent(model_cls: type, agent_cfg: Any, checkpoint_path: str, observation_shape: tuple, action_shape: tuple, state_shape: tuple | None = None, device: str | device | None = None) -> tuple

Load a trained agent for evaluation/inference.

Parameters:

model_cls (type, required): Class of the model (policy).

agent_cfg (Any, required): Agent configuration object.

checkpoint_path (str, required): Path to the checkpoint file.

observation_shape (tuple, required): Tuple defining observation dimensions.

action_shape (tuple, required): Tuple defining action dimensions.

state_shape (tuple | None, default: None): Tuple defining state dimensions.

device (str | device | None, default: None): Torch device (auto-detected if None).

Returns:

tuple: (policy, observation_preprocessor, device).

Raises:

FileNotFoundError: If the checkpoint file is not found.

ImportError: If gymnasium is not installed.
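Because load_agent raises FileNotFoundError for a missing checkpoint, a caller may prefer to validate the path up front and report it clearly. The helper below is a hypothetical convenience, not part of this module; the commented call at the end only illustrates the documented signature and return tuple, with MyPolicy and cfg as placeholder names.

```python
import tempfile
from pathlib import Path


def require_checkpoint(checkpoint_path):
    """Return the path as a string if it points to an existing file, else raise."""
    path = Path(checkpoint_path).expanduser()
    if not path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    return str(path)


# Demo with a temporary file standing in for a real checkpoint.
with tempfile.NamedTemporaryFile(suffix=".pt") as handle:
    found = require_checkpoint(handle.name)
print(found)

# Intended use (placeholders: MyPolicy, cfg; shapes are made up):
# policy, observation_preprocessor, device = load_agent(
#     model_cls=MyPolicy,
#     agent_cfg=cfg,
#     checkpoint_path=require_checkpoint("path/to/checkpoint.pt"),
#     observation_shape=(8,),
#     action_shape=(2,),
# )
```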

load_policy

load_policy(model_path: str, device: str | device = 'cpu') -> JITPolicyRunner | ONNXPolicyRunner

Load a policy as a unified runner supporting JIT or ONNX.

The returned runner exposes act(observations), which returns actions as a NumPy array. Observations may be NumPy arrays or torch tensors. A 1-D observation is treated as a batch containing a single observation, and the returned actions are correspondingly 1-D.
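The 1-D convention can be illustrated with a few lines of NumPy. This is a minimal sketch, not the runners' actual code: act_with_batching and toy_policy are hypothetical names, and the float32 cast is an assumption.

```python
import numpy as np


def act_with_batching(policy_fn, observations):
    """Call a batched policy with either one 1-D observation or a 2-D batch."""
    obs = np.asarray(observations, dtype=np.float32)  # assumption: float32 inputs
    single = obs.ndim == 1
    if single:
        obs = obs[None, :]  # (obs_dim,) -> (1, obs_dim)
    actions = np.asarray(policy_fn(obs))
    # Drop the batch axis again so a 1-D observation yields 1-D actions.
    return actions[0] if single else actions


def toy_policy(batch):
    # Stand-in policy (hypothetical): maps (N, 3) observations to (N, 2) actions.
    weights = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=np.float32)
    return batch @ weights


single_action = act_with_batching(toy_policy, [1.0, 2.0, 3.0])
batch_actions = act_with_batching(toy_policy, np.ones((4, 3)))
print(single_action.shape, batch_actions.shape)
```

A single observation of shape (3,) produces actions of shape (2,), while a batch of shape (4, 3) produces actions of shape (4, 2).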

Returns:

JITPolicyRunner | ONNXPolicyRunner: Loaded policy runner.

Raises:

ValueError: If model_path is None.
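The page does not say how load_policy decides between the two runner types. One plausible rule, shown here purely as an assumption, is to dispatch on the file suffix; choose_backend is a hypothetical helper, and only the ValueError on None comes from the documentation above.

```python
from pathlib import Path


def choose_backend(model_path):
    """Guess which runner a model file needs (suffix rule is an assumption)."""
    if model_path is None:
        # Documented behavior of load_policy: None is rejected with ValueError.
        raise ValueError("model_path must not be None")
    # Assumption: ".onnx" files go to the ONNX runner, everything else to TorchScript.
    return "onnx" if Path(model_path).suffix.lower() == ".onnx" else "jit"


print(choose_backend("policy.onnx"), choose_backend("policy.pt"))
```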