
Triton Naive Matrix multiplication implementation


I am trying to implement a naive matrix multiplication in Triton, but I am getting the following error:

**IndexError: map::at**
    214         if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
    215             passes.llvmir.add_di_scope(pm)
--> 216         pm.run(mod)
    217         # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
    218         llvm.init_targets()
  • I made sure the indices used for loading and storing values are valid by using masks that check for out-of-bounds access.
  • When I disable the dot product of the input matrices and store dummy accum values in the output, the kernel runs to completion, which suggests the issue is in the way I am doing the dot product.

I would much appreciate any suggestions on how to resolve this issue.

Below is the code:

import triton
import triton.language as tl
import torch
import numpy as np

#matmul(MxK, KxN) = MxN
# M = 16
# K = 16
# N = 16
# ROW_BLOCK_SIZE = 16
# COL_BLOCK_SIZE = 16
# K_SIZE = 16

@triton.jit
def mat_mul_gpu(mat1, mat2, result, M, N, K, ROW_BLOCK_SIZE : tl.constexpr, COL_BLOCK_SIZE : tl.constexpr, K_SIZE: tl.constexpr):
  row_block_id = tl.program_id(axis=0)
  col_block_id = tl.program_id(axis=1)

  row = tl.arange(0, ROW_BLOCK_SIZE)[:, None]
  col = tl.arange(0, COL_BLOCK_SIZE)[None, :]
  tmp = tl.arange(0, K_SIZE)
  mat1_idx = (row * K) + tmp[None, :]
  mat2_idx = (tmp * N)[:, None] + col

  # Masks to prevent out-of-bound memory access
  mat1_mask = (row < M) & (tmp[None, :] < K)
  mat2_mask = (tmp[:, None] < K) & (col < N)

  # Load matrix values and compute dot product
  accum = tl.zeros((ROW_BLOCK_SIZE, COL_BLOCK_SIZE), dtype=tl.float32)
  a = tl.load(mat1 + mat1_idx, mask = mat1_mask)
  b = tl.load(mat2 + mat2_idx, mask = mat2_mask)
  accum += tl.dot(a, b)

  result_mask = (row < M) & (col < N)
  result_ptr = result + (N * row) + col
  tl.store(result_ptr, accum, mask=result_mask)


mat1_shape = (16, 16)
mat2_shape = (16, 16)
result_shape = (16, 16)
torch.manual_seed(0)

mat1 = torch.rand(mat1_shape, device="cuda", dtype=torch.float)
mat2 = torch.rand(mat2_shape, device="cuda", dtype=torch.float)
d_result = torch.empty(result_shape, device="cuda", dtype=torch.float)
result = torch.empty(result_shape, dtype=torch.float)

ROW_BLOCK_SIZE = 16
COL_BLOCK_SIZE = 16
num_row_blocks = triton.cdiv(result_shape[0], ROW_BLOCK_SIZE)
num_col_blocks = triton.cdiv(result_shape[1], COL_BLOCK_SIZE)
grid = (num_row_blocks, num_col_blocks)
mat_mul_gpu[grid](mat1, mat2, d_result, result_shape[0], result_shape[1], mat1_shape[1], ROW_BLOCK_SIZE, COL_BLOCK_SIZE, mat1_shape[1])

asked Feb 5 at 8:47 by Sampath

1 Answer


Fixed the issue by changing from F32 to F16. I was running it in Google Colab on a T4 GPU, and it looks like there is an issue with F32 support for this on the T4: https://github.com/triton-lang/triton/issues/5557
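The answer does not show the modified code. Below is a minimal sketch of one way to apply the fix, assuming the change is to cast the loaded tiles to F16 before tl.dot while keeping the F32 output (tl.dot accumulates in F32 by default). The kernel name mat_mul_gpu_fp16 and the single-block launch are assumptions for illustration, not taken from the original post.

import triton
import triton.language as tl
import torch

@triton.jit
def mat_mul_gpu_fp16(mat1, mat2, result, M, N, K,
                     ROW_BLOCK_SIZE: tl.constexpr, COL_BLOCK_SIZE: tl.constexpr, K_SIZE: tl.constexpr):
    # Same indexing as in the question; a single block covers the whole 16x16 problem.
    row = tl.arange(0, ROW_BLOCK_SIZE)[:, None]
    col = tl.arange(0, COL_BLOCK_SIZE)[None, :]
    tmp = tl.arange(0, K_SIZE)

    mat1_idx = (row * K) + tmp[None, :]
    mat2_idx = (tmp * N)[:, None] + col
    mat1_mask = (row < M) & (tmp[None, :] < K)
    mat2_mask = (tmp[:, None] < K) & (col < N)

    # Cast the tiles to F16 before the dot product (the assumed fix);
    # tl.dot still produces an F32 result.
    a = tl.load(mat1 + mat1_idx, mask=mat1_mask).to(tl.float16)
    b = tl.load(mat2 + mat2_idx, mask=mat2_mask).to(tl.float16)
    accum = tl.dot(a, b)

    result_mask = (row < M) & (col < N)
    tl.store(result + (N * row) + col, accum, mask=result_mask)

torch.manual_seed(0)
mat1 = torch.rand((16, 16), device="cuda", dtype=torch.float)
mat2 = torch.rand((16, 16), device="cuda", dtype=torch.float)
d_result = torch.empty((16, 16), device="cuda", dtype=torch.float)

mat_mul_gpu_fp16[(1, 1)](mat1, mat2, d_result, 16, 16, 16, 16, 16, 16)
# The inputs are reduced to F16 precision, so compare against torch.matmul with a loose tolerance.
print(torch.allclose(d_result, mat1 @ mat2, atol=1e-2))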
