device-utils
GPU detection and device resolution utilities.
Probe the system for available GPU hardware.
Backends are checked in priority order, and the first available one is used: ROCm (AMD) > CUDA (NVIDIA) > MPS (Apple) > CPU.
Returns:

| Type | Description |
|---|---|
| `GPUInfo` | `GPUInfo` describing the best available backend. |

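The priority order above can be expressed as a small pure function. This is a minimal sketch, not the library's implementation. The name `pick_backend` and its boolean parameters are hypothetical. In a real probe the flags would plausibly come from `torch.version.hip is not None`, `torch.cuda.is_available()` and `torch.backends.mps.is_available()`. They are passed in here so the ordering logic runs without PyTorch or a GPU.

```python
def pick_backend(has_hip: bool, cuda_available: bool, mps_available: bool) -> str:
    """Hypothetical helper: choose a backend using ROCm > CUDA > MPS > CPU.

    has_hip        -- assumed to mirror `torch.version.hip is not None`
    cuda_available -- assumed to mirror `torch.cuda.is_available()`
    mps_available  -- assumed to mirror `torch.backends.mps.is_available()`
    """
    # ROCm builds of PyTorch reuse the torch.cuda API, so they also report
    # cuda_available == True. Checking the HIP flag first separates AMD from NVIDIA.
    if has_hip and cuda_available:
        return "ROCm"
    if cuda_available:
        return "CUDA"
    if mps_available:
        return "MPS"
    return "CPU"
```

Keeping the probe results as inputs makes the ordering testable on machines without any accelerator.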
Resolve and validate a device string.
Accepts:
- "auto" / None — auto-detect via detect_gpu()
- "cuda" — CUDA device 0
- "cuda:N" — specific CUDA device index N (e.g. "cuda:1")
- "rocm" — AMD ROCm (maps to "cuda" in PyTorch)
- "rocm:N" — specific ROCm device index N
- "mps" — Apple Silicon
- "cpu" — CPU
Returns:

| Type | Description |
|---|---|
| `str` | Validated PyTorch device string (e.g. "cuda", "cuda:1", "mps", "cpu"). |

Raises:

| Type | Description |
|---|---|
| `DeviceError` | If the requested backend is unavailable or the device index is out of range. |

Bases: BaseModel
Information about the available GPU backend.
Attributes:

| Name | Type | Description |
|---|---|---|
| `vendor` | `str` | GPU vendor name (e.g. "NVIDIA", "AMD", "Apple", "CPU"). |
| `backend` | `str` | PyTorch backend in use ("CUDA", "ROCm", "MPS", "CPU"). |
| `version` | `str \| None` | Backend version string (CUDA or ROCm version). None for MPS/CPU. |
| `devices` | `list[str]` | List of device names detected. |
| `vram_gb` | `list[float]` | VRAM in GB per device. Empty for MPS and CPU. |

`device_str` (property): PyTorch device string to pass to `torch.device()`.
Resolve one or more device strings for multi-GPU dispatch.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `requested` | `str \| list[str] \| None` | A single device string, a list of device strings, or "all" to use every available CUDA device. | `None` |

Returns:

| Type | Description |
|---|---|
| `list[str]` | List of validated PyTorch device strings. Always has at least one entry. |

Examples:

- `resolve_devices("cuda")` → `["cuda"]`
- `resolve_devices("cuda:1")` → `["cuda:1"]`
- `resolve_devices(["cuda:0", "cuda:1"])` → `["cuda:0", "cuda:1"]`
- `resolve_devices("all")` → `["cuda:0", "cuda:1", ...]` (all GPUs)
Clear GPU memory cache for the given device. No-op for CPU.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `device` | `str` | PyTorch device string ("cuda", "mps", or "cpu"). | *required* |
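A cache-clearing helper of this kind could look like the sketch below. The name `clear_cache_sketch` is hypothetical. It relies on `torch.cuda.empty_cache()`, which releases unoccupied blocks held by PyTorch's caching allocator (on ROCm builds too, through the same `torch.cuda` API), and on `torch.mps.empty_cache()`, assumed available in PyTorch 2.x. Neither call frees tensors that are still referenced. The sketch clears the current CUDA device only; it does not switch to an indexed device such as `"cuda:1"`.

```python
def clear_cache_sketch(device: str) -> None:
    """Hypothetical sketch: release cached allocator memory for `device`."""
    if device == "cpu":
        return  # No-op on CPU; returning early also avoids importing torch.

    import torch  # Deferred so the CPU path works without PyTorch installed.

    if device.startswith("cuda"):
        # Frees unused cached blocks on the current CUDA/ROCm device.
        torch.cuda.empty_cache()
    elif device == "mps":
        # Assumed present in PyTorch 2.x; frees unused cached MPS memory.
        torch.mps.empty_cache()
```

Deferring the `torch` import is a design choice for the sketch: the CPU no-op stays cheap and dependency-free.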