Example Workloads

This document provides example workloads and configurations for using the AMD GPU device plugin in Kubernetes.

Basic GPU Pod Example

This example runs a basic PyTorch workload on an AMD GPU. The pod creates two small tensors on the GPU and adds them to verify that the GPU works. ROCm builds of PyTorch expose AMD GPUs through the torch.cuda API, which is why the script calls torch.cuda.is_available(). Because the workload runs once and exits, restartPolicy is set to Never so that the pod is not restarted after it completes.

The pod requests one AMD GPU through the amd.com/gpu resource limit:

apiVersion: v1
kind: Pod
metadata:
  name: pytorch-gpu-pod-example
spec:
  restartPolicy: Never
  containers:
  - name: gpu-container
    image: rocm/pytorch:latest
    command:
    - python3
    - "-c"
    - |
      import torch
      if torch.cuda.is_available():
        print(f"GPU is available. Device count: {torch.cuda.device_count()}")
        print(f"Device name: {torch.cuda.get_device_name(0)}")
        x = torch.ones(3, 3, device='cuda')
        y = torch.ones(3, 3, device='cuda') * 2
        z = x + y
        print(f"Result of tensor addition on GPU: {z}")
      else:
        print("No GPU available.")
    resources:
      limits:
        amd.com/gpu: 1  # Request 1 AMD GPU

To run the example:

kubectl create -f https://raw.githubusercontent.com/ROCm/k8s-device-plugin/master/example/pod/pytorch.yaml

Check the output with:

kubectl logs pytorch-gpu-pod-example

This example manifest is available for download here: https://raw.githubusercontent.com/ROCm/k8s-device-plugin/master/example/pod/pytorch.yaml
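
When the check needs to run unattended (for example in a CI job), the log output can be inspected by a small script instead of by eye. The following is a minimal sketch, not part of the device plugin: the helper names `parse_device_count` and `check_gpu_log` are hypothetical, and the regular expression assumes the exact "Device count: N" wording printed by the example script above.

```python
import re


def parse_device_count(log_text):
    """Return the GPU count reported in the pod log, or None if none was reported.

    Assumes the log contains the line printed by the example script:
    "GPU is available. Device count: <N>".
    """
    match = re.search(r"Device count: (\d+)", log_text)
    return int(match.group(1)) if match else None


def check_gpu_log(log_text, expected_gpus=1):
    """Return True when the log reports exactly `expected_gpus` devices."""
    return parse_device_count(log_text) == expected_gpus


# Hypothetical log excerpt, built from the print format used in the example pod.
sample_log = "GPU is available. Device count: 1\n"
print("GPU check passed" if check_gpu_log(sample_log) else "GPU check failed")
```

In practice the text passed to `check_gpu_log` would be the captured output of `kubectl logs pytorch-gpu-pod-example`.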

Multiple GPU Example

This example shows how to use multiple GPUs in a JAX application. It runs one matrix multiplication per GPU in parallel using jax.pmap, which maps the function over the leading axis of its inputs, so the input arrays are created with one 1000 x 1000 matrix per device.

apiVersion: v1
kind: Pod
metadata:
  name: jax-multi-gpu-pod
spec:
  restartPolicy: Never
  containers:
  - name: multi-gpu-container
    image: rocm/jax:latest
    command:
    - /bin/bash
    - "-c"
    - |
      python3 -c "
      import jax
      import jax.numpy as jnp
      print('Available JAX devices:', jax.devices())

      # Create data to process in parallel
      n_devices = jax.device_count()
      print(f'Number of devices: {n_devices}')

      # Create matrices for each device
      x = jnp.ones((n_devices, 1000, 1000))
      y = jnp.ones((n_devices, 1000, 1000))

      # Define computation to run in parallel
      @jax.pmap
      def parallel_matmul(a, b):
          return jnp.matmul(a, b)

      # Run computation in parallel across GPUs
      result = parallel_matmul(x, y)

      print(f'Parallel computation complete across {n_devices} devices')
      print('Result shape:', result.shape)
      print('Device mapping:', jax.devices())
      "
    resources:
      limits:
        amd.com/gpu: 2  # Request 2 AMD GPUs

To run the example:

kubectl create -f https://raw.githubusercontent.com/ROCm/k8s-device-plugin/master/example/pod/jax-mult-gpu.yaml

Check the output with:

kubectl logs jax-multi-gpu-pod

This example manifest is available for download here: https://raw.githubusercontent.com/ROCm/k8s-device-plugin/master/example/pod/jax-mult-gpu.yaml
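
The GPU count is the main thing that changes between the single-GPU and multi-GPU manifests, so it can be convenient to generate the manifest from a parameter. The sketch below is an illustration under stated assumptions rather than a tool shipped with the device plugin: `gpu_pod_manifest` is a hypothetical helper, the container name is derived from the pod name, and the short command is a placeholder for the full JAX script shown above. It emits JSON, which kubectl accepts in place of YAML.

```python
import json


def gpu_pod_manifest(name, image, command, gpu_count):
    """Build a Pod manifest dict that requests `gpu_count` AMD GPUs."""
    if gpu_count < 1:
        raise ValueError("gpu_count must be at least 1")
    return {
        "apiVersion": "v1",
        "kind": "Pod",
        "metadata": {"name": name},
        "spec": {
            "restartPolicy": "Never",  # run once, do not restart on completion
            "containers": [
                {
                    "name": f"{name}-container",  # hypothetical naming scheme
                    "image": image,
                    "command": command,
                    # Same resource name as in the manifests above.
                    "resources": {"limits": {"amd.com/gpu": gpu_count}},
                }
            ],
        },
    }


manifest = gpu_pod_manifest(
    name="jax-multi-gpu-pod",
    image="rocm/jax:latest",
    # Placeholder command; the real example runs the pmap script.
    command=["python3", "-c", "import jax; print(jax.device_count())"],
    gpu_count=2,
)
print(json.dumps(manifest, indent=2))
```

Saving the printed JSON to a file and passing that file to `kubectl create -f` creates the pod in the same way as the YAML manifest.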

Non-privileged Pod with GPU Access Example

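This example runs the same JAX workload as above in a non-privileged container: the securityContext sets privileged: false and allowPrivilegeEscalation: false. Note that the manifest also sets hostIPC: true and uses an Unconfined seccomp profile.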

apiVersion: v1
kind: Pod
metadata:
  name: jax-non-privileged-multi-gpu-pod
spec:
  restartPolicy: Never
  hostIPC: true
  containers:
  - name: jax-multi-gpu-container
    image: rocm/jax:latest
    command:
    - python3
    - "-c"
    - |
      import jax
      import jax.numpy as jnp
      print('Available JAX devices:', jax.devices())

      # Create data to process in parallel
      n_devices = jax.device_count()
      print(f'Number of devices: {n_devices}')

      # Create matrices for each device
      x = jnp.ones((n_devices, 1000, 1000))
      y = jnp.ones((n_devices, 1000, 1000))

      # Define computation to run in parallel
      @jax.pmap
      def parallel_matmul(a, b):
          return jnp.matmul(a, b)

      # Run computation in parallel across GPUs
      result = parallel_matmul(x, y)

      print(f'Parallel computation complete across {n_devices} devices')
      print('Result shape:', result.shape)
      print('Device mapping:', jax.devices())
    resources:
      limits:
        amd.com/gpu: 2  # Request 2 AMD GPUs
    securityContext:
      privileged: false
      allowPrivilegeEscalation: false
      seccompProfile:
        type: Unconfined

To run the example:

kubectl create -f https://raw.githubusercontent.com/ROCm/k8s-device-plugin/master/example/pod/jax-non-privileged.yaml

Check the output with:

kubectl logs jax-non-privileged-multi-gpu-pod

This example manifest is available for download here: https://raw.githubusercontent.com/ROCm/k8s-device-plugin/master/example/pod/jax-non-privileged.yaml
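
To make the security-relevant settings of a manifest like this one easy to review, they can be listed programmatically. The following is a minimal sketch: `audit_pod_security` is a hypothetical helper that works on a manifest already loaded into a Python dict, and it treats privilege escalation as allowed when the field is unset, which is an assumption of this sketch.

```python
def audit_pod_security(manifest):
    """Return a list of human-readable findings for a Pod manifest dict."""
    findings = []
    spec = manifest.get("spec", {})
    if spec.get("hostIPC"):
        findings.append("pod shares the host IPC namespace (hostIPC: true)")
    for container in spec.get("containers", []):
        name = container.get("name", "<unnamed>")
        ctx = container.get("securityContext", {})
        if ctx.get("privileged", False):
            findings.append(f"{name}: runs privileged")
        # Assumption of this sketch: unset means escalation is allowed.
        if ctx.get("allowPrivilegeEscalation", True):
            findings.append(f"{name}: privilege escalation is allowed")
        if ctx.get("seccompProfile", {}).get("type") == "Unconfined":
            findings.append(f"{name}: seccomp profile is Unconfined")
    return findings


# Security-relevant fields copied from the non-privileged example above.
example = {
    "spec": {
        "hostIPC": True,
        "containers": [
            {
                "name": "jax-multi-gpu-container",
                "securityContext": {
                    "privileged": False,
                    "allowPrivilegeEscalation": False,
                    "seccompProfile": {"type": "Unconfined"},
                },
            }
        ],
    }
}

for finding in audit_pod_security(example):
    print(finding)
```

For the example manifest this reports the host IPC setting and the Unconfined seccomp profile, and nothing for the privileged or privilege escalation fields.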