
python - TRITON - Strange error with matrix multiplication - Stack Overflow


I have two matrices, P and V, and when I take their dot product with Triton I get results that are inconsistent with PyTorch.

The P and V matrices are as follows. P is basically a softmax output, which is why it is all 0s except for the final column, so every row of the dot product should equal the final row of V.

P = torch.zeros((32,32), device = 'cuda',  dtype = torch.float32)
P[:,-1] = 1
V = torch.arange(32*64, 64*64, device = 'cuda', dtype = torch.float32).reshape(32, 64)

On calling tl.dot(P, V), both matrices load correctly (or so they appear to me), but the output is

[4032., 4032., 4034., 4034., 4036., 4036., 4038., 4038., 4040., 4040.,
         4042., 4042., 4044., 4044., 4046., 4046., 4048., 4048., 4050., 4050.,
         4052., 4052., 4054., 4054., 4056., 4056., 4058., 4058., 4060., 4060.,
         4062., 4062., 4064., 4064., 4066., 4066., 4068., 4068., 4070., 4070.,
         4072., 4072., 4074., 4074., 4076., 4076., 4078., 4078., 4080., 4080.,
         4082., 4082., 4084., 4084., 4086., 4086., 4088., 4088., 4090., 4090.,
         4092., 4092., 4094., 4094.]

instead of what I get from torch.matmul, which is

[4032., 4033., 4034., 4035., 4036., 4037., 4038., 4039., 4040., 4041.,
         4042., 4043., 4044., 4045., 4046., 4047., 4048., 4049., 4050., 4051.,
         4052., 4053., 4054., 4055., 4056., 4057., 4058., 4059., 4060., 4061.,
         4062., 4063., 4064., 4065., 4066., 4067., 4068., 4069., 4070., 4071.,
         4072., 4073., 4074., 4075., 4076., 4077., 4078., 4079., 4080., 4081.,
         4082., 4083., 4084., 4085., 4086., 4087., 4088., 4089., 4090., 4091.,
         4092., 4093., 4094., 4095.]

The following is the code I'm testing this out in:

import triton
import triton.language as tl
import torch
torch.cuda.is_available()
torch.set_printoptions(profile="full")
@triton.jit
def test_kernel(x_ptr, y_ptr, output_ptr,
                M, K, N,
                stride_xm, stride_xk,
                stride_yk, stride_yn,
                stride_om, stride_on,
                BLOCK_SIZE_M: tl.constexpr,
                BLOCK_SIZE_K: tl.constexpr,
                BLOCK_SIZE_N: tl.constexpr):
    # starting row/column of this program's output tile
    pid_m = tl.program_id(axis = 0) * BLOCK_SIZE_M
    pid_n = tl.program_id(axis = 1) * BLOCK_SIZE_N
    # pointers to a (BLOCK_SIZE_M, BLOCK_SIZE_K) tile of x and a (BLOCK_SIZE_K, BLOCK_SIZE_N) tile of y
    x_ptr += (pid_m + tl.arange(0, BLOCK_SIZE_M))[:,None] * stride_xm + (pid_n + tl.arange(0, BLOCK_SIZE_K))[None,:] * stride_xk
    y_ptr += (pid_m + tl.arange(0, BLOCK_SIZE_K))[:,None] * stride_yk + (pid_n + tl.arange(0, BLOCK_SIZE_N))[None,:] * stride_yn
    x = tl.load(x_ptr)
    y = tl.load(y_ptr)
    # write the tile product back to the output
    output_offset = (pid_m + tl.arange(0, BLOCK_SIZE_M))[:,None] * stride_om + (pid_n + tl.arange(0, BLOCK_SIZE_N))[None, :] * stride_on
    tl.store(output_ptr + output_offset, tl.dot(x, y))

def helper(x: torch.Tensor, y: torch.Tensor):
    M , K = x.shape
    K1, N = y.shape
    assert K == K1
    output = torch.empty((M, N), device = 'cuda', dtype = torch.float32)
    assert x.is_cuda and y.is_cuda and output.is_cuda

    # one program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(N, meta['BLOCK_SIZE_N']),)
    test_kernel[grid](x, y, output, 
                        M, K, N, 
                        x.stride(0), x.stride(1),
                        y.stride(0), y.stride(1),
                        output.stride(0), output.stride(1),
                        BLOCK_SIZE_N = 64, 
                        BLOCK_SIZE_K = 32, 
                        BLOCK_SIZE_M = 32, 
                        )
    return output

The strangest thing is that when I define V = torch.arange(0, 32*64, device = 'cuda', dtype = torch.float32).reshape(32, 64) instead, it works as expected. Is there something about the pointer operations that I'm missing here?
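For reference, the outputs above can be reproduced with a driver along these lines (a sketch; the exact calls are not shown in the post, but helper and the imports are the ones defined above):

P = torch.zeros((32, 32), device = 'cuda', dtype = torch.float32)
P[:, -1] = 1
V = torch.arange(32*64, 64*64, device = 'cuda', dtype = torch.float32).reshape(32, 64)

out_triton = helper(P, V)       # Triton kernel defined above
out_torch = torch.matmul(P, V)  # PyTorch reference

print(out_triton[0])                        # first row of the Triton result
print(out_torch[0])                         # first row of the PyTorch result: 4032., 4033., ..., 4095.
print(torch.equal(out_triton, out_torch))   # False with the default tl.dot precision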



1 Answer


The problem has been solved (thanks to dai on Discord). The issue is that input_precision defaults to tf32 for tl.dot, and tf32 keeps only a 10-bit mantissa, so the trailing digits of the inputs are lost. The values in the last row of V here lie in [2048, 4096), where tf32 can only represent every second integer, so the odd entries collapse onto neighbouring even values, which is exactly the duplicated pattern in the output above. The effect was even more pronounced with V = torch.arange(4096, 4096 + 2048, device = 'cuda', dtype = torch.float32), where the representable values are 4 apart and the output was [6080., 6080., 6080., 6080., 6084., 6084., 6084., 6084., 6088., ...]. Switching to IEEE input precision with tl.dot(x, y, input_precision = "ieee") solved the issue.
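Applied to the kernel from the question, this is a one-argument change in the final tl.dot call; a sketch of the corrected kernel (identical to the original except for that call and the added comment):

import triton
import triton.language as tl

@triton.jit
def test_kernel(x_ptr, y_ptr, output_ptr,
                M, K, N,
                stride_xm, stride_xk,
                stride_yk, stride_yn,
                stride_om, stride_on,
                BLOCK_SIZE_M: tl.constexpr,
                BLOCK_SIZE_K: tl.constexpr,
                BLOCK_SIZE_N: tl.constexpr):
    pid_m = tl.program_id(axis = 0) * BLOCK_SIZE_M
    pid_n = tl.program_id(axis = 1) * BLOCK_SIZE_N
    x_ptr += (pid_m + tl.arange(0, BLOCK_SIZE_M))[:,None] * stride_xm + (pid_n + tl.arange(0, BLOCK_SIZE_K))[None,:] * stride_xk
    y_ptr += (pid_m + tl.arange(0, BLOCK_SIZE_K))[:,None] * stride_yk + (pid_n + tl.arange(0, BLOCK_SIZE_N))[None,:] * stride_yn
    x = tl.load(x_ptr)
    y = tl.load(y_ptr)
    output_offset = (pid_m + tl.arange(0, BLOCK_SIZE_M))[:,None] * stride_om + (pid_n + tl.arange(0, BLOCK_SIZE_N))[None, :] * stride_on
    # "ieee" keeps full float32 precision; the default tf32 keeps only a 10-bit mantissa
    tl.store(output_ptr + output_offset, tl.dot(x, y, input_precision = "ieee"))

With this change, helper(P, V) matches torch.matmul(P, V) exactly for the inputs above.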
