Update autotune configuration to avoid crash on AMD devices
When running on an AMD device, autotuning with 32 warps crashes with `RuntimeError: Triton Error [HIP]: Code: 1, Messsage: invalid argument`. This change removes that configuration when the device name contains "AMD", which is the case for the MI250, MI300, and MI355. Tested on an MI300.
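For quick reference, the core of the change boils down to the device check below. This is only a condensed sketch of the patch (it filters the config list instead of popping the last entry, and assumes a single visible GPU); the actual code is in the diff that follows.

```python
import torch
import triton

# Candidate autotune configurations, from 1 to 32 warps.
autotune_configs = [triton.Config({}, num_warps=w) for w in (1, 2, 4, 8, 16, 32)]

if torch.cuda.is_available() and "AMD" in torch.cuda.get_device_name():
    # AMD devices support at most 16 warps, so drop the 32-warp config
    # to avoid the HIP "invalid argument" error during autotuning.
    autotune_configs = [c for c in autotune_configs if c.num_warps <= 16]
```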
torch-ext/triton_layer_norm/layer_norm.py
CHANGED
```diff
@@ -16,6 +16,22 @@ import triton
 import triton.language as tl
 
 
+autotune_configs = [
+    triton.Config({}, num_warps=1),
+    triton.Config({}, num_warps=2),
+    triton.Config({}, num_warps=4),
+    triton.Config({}, num_warps=8),
+    triton.Config({}, num_warps=16),
+    triton.Config({}, num_warps=32),
+]
+
+if torch.cuda.is_available():
+    is_amd_device = ("AMD" in torch.cuda.get_device_name())
+    # AMD devices have a maximum of 16 warps, so we remove the 32 warps autotune config
+    if is_amd_device and autotune_configs[-1].num_warps == 32:
+        autotune_configs.pop()
+
+
 def layer_norm_ref(
     x,
     weight,
@@ -128,14 +144,7 @@ def rms_norm_ref(
 
 
 @triton.autotune(
-    configs=[
-        triton.Config({}, num_warps=1),
-        triton.Config({}, num_warps=2),
-        triton.Config({}, num_warps=4),
-        triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
-    ],
+    configs=autotune_configs[:],
     key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
 )
 # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@@ -407,14 +416,7 @@ def _layer_norm_fwd(
 
 
 @triton.autotune(
-    configs=[
-        triton.Config({}, num_warps=1),
-        triton.Config({}, num_warps=2),
-        triton.Config({}, num_warps=4),
-        triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
-    ],
+    configs=autotune_configs[:],
     key=[
         "N",
         "HAS_DRESIDUAL",
```
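To sanity-check the detection logic on a given machine, something like the following (a hypothetical snippet, not part of the patch) prints the device name that the `"AMD"` substring check matches against:

```python
import torch

# Hypothetical check, not part of the patch: shows what the "AMD" substring
# match sees on this machine and whether the 32-warp config would be dropped.
if torch.cuda.is_available():
    name = torch.cuda.get_device_name()
    dropped = "AMD" in name
    print(f"{name}: {'dropping' if dropped else 'keeping'} the num_warps=32 config")
else:
    print("No CUDA/ROCm device visible; autotune_configs is left unchanged.")
```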