loader
Policy loading and inference runners for SKRL evaluation.
Classes:
| Name | Description |
|---|---|
| BasePolicyRunner | Base policy runner interface. |
| JITPolicyRunner | TorchScript policy runner. |
| ONNXPolicyRunner | ONNX policy runner. |
| TensorRTPolicyRunner | TensorRT policy runner for serialized engine artifacts. |
| SKRLPolicyRunner | Adapter exposing a unified act(observations) interface for SKRL. |
Functions:
| Name | Description |
|---|---|
| load_agent | Load a trained PyTorch agent for evaluation/inference. |
| load_policy | Load a policy as a unified runner supporting JIT, ONNX, or TensorRT. |
BasePolicyRunner
BasePolicyRunner(action_dim: int | None = None)
Bases: ABC
Base policy runner interface.
Initialize the base runner.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| action_dim | int \| None | Optional action dimension hint. | None |
Methods:
| Name | Description |
|---|---|
| act | Return action predictions for the given observations. |
act (abstractmethod)
act(observations: ndarray | Tensor) -> np.ndarray
Return action predictions for the given observations.
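A minimal sketch of a custom runner, assuming only what the interface above documents: act accepts a NumPy array or a torch tensor and returns a NumPy array. To stay self-contained it defines a local stand-in for BasePolicyRunner instead of importing the real class (its import path is not shown here); ConstantPolicyRunner is a hypothetical name used for illustration.

```python
from __future__ import annotations

from abc import ABC, abstractmethod

import numpy as np


class BasePolicyRunner(ABC):
    """Local stand-in mirroring the documented interface, not the real class."""

    def __init__(self, action_dim: int | None = None):
        # Optional action dimension hint, as in the documented constructor.
        self.action_dim = action_dim

    @abstractmethod
    def act(self, observations) -> np.ndarray:
        """Return action predictions for the given observations."""


class ConstantPolicyRunner(BasePolicyRunner):
    """Hypothetical runner that returns the same action for every observation."""

    def __init__(self, action):
        action = np.asarray(action, dtype=np.float32)
        super().__init__(action_dim=action.shape[-1])
        self._action = action

    def act(self, observations) -> np.ndarray:
        # np.asarray accepts NumPy arrays and CPU torch tensors that do not require grad.
        obs = np.asarray(observations, dtype=np.float32)
        if obs.ndim == 1:
            # A single observation yields a single 1-D action.
            return self._action.copy()
        # A batch of N observations yields an (N, action_dim) array.
        return np.tile(self._action, (obs.shape[0], 1))
```

Because act is declared abstract, the base class cannot be instantiated directly; every concrete runner must provide its own act.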
JITPolicyRunner
JITPolicyRunner(model_path: str, device: str | device)
Bases: BasePolicyRunner
TorchScript policy runner.
Initialize a TorchScript policy runner.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_path | str | Path to the JIT model. | required |
| device | str \| device | Target device. | required |
Methods:
| Name | Description |
|---|---|
| act | Compute actions from observations. |
act
act(observations: ndarray | Tensor) -> np.ndarray
Compute actions from observations.
Returns:
| Type | Description |
|---|---|
| ndarray | Action array with shape matching the model output. |
ONNXPolicyRunner
ONNXPolicyRunner(model_path: str, providers: list | None = None, session_options: Any | None = None)
Bases: BasePolicyRunner
ONNX policy runner.
Initialize an ONNX policy runner.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_path | str | Path to the ONNX model. | required |
| providers | list \| None | Optional ONNX Runtime execution providers list. | None |
| session_options | Any \| None | Optional ONNX Runtime session options. | None |
Raises:
| Type | Description |
|---|---|
| ImportError | If onnxruntime is not installed. |
| RuntimeError | If the ONNX model has no inputs. |
Methods:
| Name | Description |
|---|---|
| act | Compute actions from observations. |
act
act(observations: ndarray | Tensor) -> np.ndarray
Compute actions from observations.
Returns:
| Type | Description |
|---|---|
| ndarray | Action array with shape matching the model output. |
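The providers argument is an ordered priority list of ONNX Runtime execution providers. A minimal sketch of building one, assuming you want CUDA when it is available and CPU as the fallback; pick_providers is a hypothetical helper, and in practice its available argument would come from onnxruntime.get_available_providers().

```python
from __future__ import annotations


def pick_providers(available: list[str], prefer_gpu: bool = True) -> list[str]:
    """Return an ordered ONNX Runtime provider list (hypothetical helper)."""
    order = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    if not prefer_gpu:
        order = ["CPUExecutionProvider"]
    # Keep only providers this onnxruntime build actually offers, in priority order.
    chosen = [p for p in order if p in available]
    if not chosen:
        raise RuntimeError(f"No usable execution provider in {available!r}")
    return chosen
```

The result would then be passed as, for example, ONNXPolicyRunner("policy.onnx", providers=pick_providers(onnxruntime.get_available_providers())), where "policy.onnx" is a placeholder path.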
TensorRTPolicyRunner
TensorRTPolicyRunner(model_path: str)
Bases: BasePolicyRunner
TensorRT policy runner for serialized engine artifacts.
Initialize a TensorRT policy runner.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_path | str | Path to a serialized TensorRT engine. | required |
Raises:
| Type | Description |
|---|---|
| FileNotFoundError | If the TensorRT engine file is missing. |
| RuntimeError | If CUDA is unavailable. |
| RuntimeError | If the engine cannot be deserialized or has no I/O. |
Methods:
| Name | Description |
|---|---|
| convert_onnx_to_engine | Build a serialized TensorRT engine from an ONNX policy. |
| act | Compute actions from observations with a TensorRT engine. |
convert_onnx_to_engine (classmethod)
convert_onnx_to_engine(onnx_path: str, engine_path: str | None = None, *, workspace_size_bytes: int = 1 << 30, fp16: bool = True, input_shapes: dict[str, tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]] | None = None) -> str
Build a serialized TensorRT engine from an ONNX policy.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| onnx_path | str | Path to the ONNX model. | required |
| engine_path | str \| None | Output path for the serialized engine. Defaults to the ONNX path with its suffix replaced. | None |
| workspace_size_bytes | int | TensorRT workspace memory pool size, in bytes. | 1 << 30 |
| fp16 | bool | Whether to enable FP16 builder mode when supported. | True |
| input_shapes | dict[str, tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]] \| None | Optional dynamic-shape profiles keyed by input tensor name, each given as a tuple of three shapes. | None |
Returns:
| Type | Description |
|---|---|
| str | Path to the serialized TensorRT engine. |
Raises:
| Type | Description |
|---|---|
| FileNotFoundError | If the ONNX model does not exist. |
| ImportError | If TensorRT is unavailable. |
| RuntimeError | If parsing or engine building fails. |
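A sketch of building and sanity-checking an input_shapes argument. It assumes the three shapes in each tuple are the minimum, optimum, and maximum shapes of a TensorRT optimization profile, which must share a rank and satisfy min <= opt <= max in every dimension; the input name "obs" and the helper check_input_shapes are hypothetical. The default workspace_size_bytes of 1 << 30 is 1 GiB.

```python
from __future__ import annotations

# (min, opt, max) shapes for one input; this ordering is an assumption.
ShapeTriple = tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]

WORKSPACE_SIZE_DEFAULT = 1 << 30  # documented default: 1,073,741,824 bytes (1 GiB)


def check_input_shapes(input_shapes: dict[str, ShapeTriple]) -> None:
    """Raise ValueError if any (min, opt, max) triple is inconsistent."""
    for name, (lo, opt, hi) in input_shapes.items():
        if not len(lo) == len(opt) == len(hi):
            raise ValueError(f"{name}: min/opt/max shapes differ in rank")
        for a, b, c in zip(lo, opt, hi):
            if not 0 <= a <= b <= c:
                raise ValueError(f"{name}: expected min <= opt <= max, got {a}, {b}, {c}")


# Hypothetical input "obs" with 48 features and a batch dimension from 1 to 64.
shapes = {"obs": ((1, 48), (16, 48), (64, 48))}
check_input_shapes(shapes)
```

A validated dictionary would then be passed as TensorRTPolicyRunner.convert_onnx_to_engine("policy.onnx", input_shapes=shapes), with "policy.onnx" as a placeholder path.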
act
act(observations: ndarray | Tensor) -> np.ndarray
Compute actions from observations with a TensorRT engine.
SKRLPolicyRunner
SKRLPolicyRunner(policy: Any, observation_preprocessor: Any, device: str | device)
Bases: BasePolicyRunner
Adapter exposing a unified act(observations) interface for SKRL.
Initialize an SKRL policy runner.
Methods:
| Name | Description |
|---|---|
| act | Run policy inference and return numpy actions. |
act
act(observations: ndarray | Tensor) -> np.ndarray
Run policy inference and return numpy actions.
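The adapter idea can be sketched as: apply the observation preprocessor, run the policy, and hand the result back as a NumPy array. This is an illustration under stated assumptions, not the class's actual code: AdapterSketch is hypothetical, and the policy and preprocessor are plain callables, whereas the real class wraps SKRL model objects whose call convention may differ.

```python
from __future__ import annotations

from typing import Any, Callable

import numpy as np


class AdapterSketch:
    """Hypothetical stand-in for the adapter idea: preprocess, infer, return NumPy."""

    def __init__(
        self,
        policy: Callable[[np.ndarray], Any],
        observation_preprocessor: Callable[[np.ndarray], np.ndarray] | None = None,
    ):
        self.policy = policy
        self.preprocess = observation_preprocessor

    def act(self, observations: Any) -> np.ndarray:
        obs = np.asarray(observations, dtype=np.float32)
        if self.preprocess is not None:
            # e.g. observation normalization fitted during training
            obs = self.preprocess(obs)
        # Whatever the policy returns is handed back as a float32 NumPy array.
        return np.asarray(self.policy(obs), dtype=np.float32)
```

Keeping the preprocessor separate from the policy mirrors the (policy, observation_preprocessor, device) tuple that load_agent returns.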
load_agent
load_agent(model_cls: type, agent_cfg: Any, checkpoint_path: str, observation_shape: tuple, action_shape: tuple, state_shape: tuple | None = None, device: str | device | None = None) -> tuple
Load a trained PyTorch agent for evaluation/inference.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_cls | type | Class of the model (policy). | required |
| agent_cfg | Any | Agent configuration object. | required |
| checkpoint_path | str | Path to the checkpoint file. | required |
| observation_shape | tuple | Tuple defining observation dimensions. | required |
| action_shape | tuple | Tuple defining action dimensions. | required |
| state_shape | tuple \| None | Tuple defining state dimensions. | None |
| device | str \| device \| None | Torch device (auto-detected if None). | None |
Returns:
| Type | Description |
|---|---|
| tuple | (policy, observation_preprocessor, device). |
Raises:
| Type | Description |
|---|---|
| FileNotFoundError | If the checkpoint file is not found. |
| ImportError | If gymnasium is not installed. |
load_policy
load_policy(model_path: str, device: str | device = 'cpu', providers: list | None = None, session_options: Any | None = None) -> BasePolicyRunner
Load a policy as a unified runner supporting JIT, ONNX, or TensorRT.
The returned runner exposes act(observations), which returns actions as a NumPy array.
Observations may be NumPy arrays or torch tensors. A 1-D observation is treated
as a single observation (a batch of one), and the returned actions are 1-D as well.
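A sketch of the input handling described above: accept a NumPy array or a torch tensor, add a batch axis to a 1-D observation, and remove it from the result. Tensors are detected by duck typing (detach and cpu methods) so the sketch runs without torch installed; run_with_batching and its infer callable are hypothetical.

```python
from __future__ import annotations

from typing import Any, Callable

import numpy as np


def _as_numpy(x: Any) -> np.ndarray:
    # torch.Tensor: drop autograd history and move to host memory first.
    if hasattr(x, "detach") and hasattr(x, "cpu"):
        x = x.detach().cpu().numpy()
    return np.asarray(x, dtype=np.float32)


def run_with_batching(infer: Callable[[np.ndarray], Any], observations: Any) -> np.ndarray:
    """Run a batched infer function on one observation or a batch of them."""
    obs = _as_numpy(observations)
    single = obs.ndim == 1
    if single:
        obs = obs[np.newaxis, :]  # (obs_dim,) -> (1, obs_dim)
    actions = _as_numpy(infer(obs))
    # Strip the batch axis again so a 1-D input yields 1-D actions.
    return actions[0] if single else actions
```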
Returns:
| Type | Description |
|---|---|
| BasePolicyRunner | Loaded policy runner. |
Raises:
| Type | Description |
|---|---|
| ValueError | If … |
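How load_policy chooses between the JIT, ONNX, and TensorRT runners is not documented here. A minimal sketch assuming selection by file extension, raising ValueError for anything unrecognized; the suffix-to-runner mapping (including which suffixes denote TorchScript and TensorRT artifacts) is an assumption, and runner_kind is hypothetical.

```python
from __future__ import annotations

from pathlib import Path

# Assumed mapping; the real load_policy may use different suffixes or checks.
_SUFFIX_TO_RUNNER = {
    ".pt": "jit",
    ".onnx": "onnx",
    ".engine": "tensorrt",
    ".plan": "tensorrt",
}


def runner_kind(model_path: str) -> str:
    """Map a model file path to a runner kind (hypothetical dispatch rule)."""
    suffix = Path(model_path).suffix.lower()
    try:
        return _SUFFIX_TO_RUNNER[suffix]
    except KeyError:
        raise ValueError(f"Unsupported policy format: {model_path!r}") from None
```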