如下C程序执行结果正常
// $ gcc test.c -I/opt/cuda/include -lcuda && ./a.out && rm a.out
// wtf?
// cuInit returns 0
// cuDeviceGet returns 0
// cuCtxCreate returns 0
// cuModuleLoad returns 0
// 0代表没出错
#include<stdio.h>
#include<cuda.h>
/*
 * Minimal CUDA driver-API smoke test.
 *
 * Initializes the driver, grabs device 0, creates a context on it and loads
 * "test.ptx" (resolved relative to the current working directory), printing
 * the raw status code of every call. A status of 0 means the call succeeded.
 *
 * Returns 0 when every call succeeded and 1 otherwise, so that a shell
 * pipeline such as `./a.out && rm a.out` behaves predictably. The previous
 * `void main()` left the process exit status unspecified.
 */
int main(void) {
    int failed = 0;

    printf("wtf?\n");

    int e = cuInit(0);
    printf("cuInit returns %d\n", e);
    failed |= (e != 0);

    CUdevice device;
    e = cuDeviceGet(&device, 0);
    printf("cuDeviceGet returns %d\n", e);
    failed |= (e != 0);

    CUcontext ctx = NULL;
    e = cuCtxCreate(&ctx, 0, device);
    printf("cuCtxCreate returns %d\n", e);
    failed |= (e != 0);

    CUmodule module;
    e = cuModuleLoad(&module, "test.ptx");
    printf("cuModuleLoad returns %d\n", e);
    failed |= (e != 0);

    /* Keep going after a failure (as before) so all status codes are shown. */
    return failed;
}
如下Rust程序执行结果报错
// rustc --edition 2024 test.rs -o test && ./test && rm ./test
// wtf?
// init returns 0
// DeviceGet 0
// CtxCreate returns 0
// ModuleLoad returns 200
// 200代表出错了
/// Handle to a CUDA context, treated here as an untyped pointer.
type CUcontext = *const i8;
/// Handle to a loaded CUDA module, treated here as an untyped pointer.
// Proper opaque types were skipped out of laziness; both aliases are
// pointer-sized, which is all these declarations rely on.
type CUmodule = *const u8;

/// Raw bindings to the subset of the CUDA driver API used by this program.
///
/// Every function returns the driver's status code as an `i32`; `0` means
/// success and any other value is an error code.
#[link(name = "cuda")]
unsafe extern "C" {
    /// Initializes the driver API; `flag` is passed through unchanged.
    fn cuInit(flag: i32) -> i32;
    /// Writes the handle of device number `item` into `*dev`.
    fn cuDeviceGet(dev: *mut i32, item: i32) -> i32;
    /// Creates a context on `device` and writes its handle into `*ctx`.
    fn cuCtxCreate_v2(ctx: *mut CUcontext, flag: i32, device: i32) -> i32;
    /// Loads the module stored in the file named by the NUL-terminated C
    /// string `file` and writes its handle into `*module`.
    ///
    /// `file` is declared as `*const c_char` rather than `*const i8` because
    /// C's `char` is unsigned on some targets (e.g. aarch64 Linux), where
    /// `CStr::as_ptr()` yields `*const u8` and an `i8` pointer would not
    /// type-check.
    fn cuModuleLoad(module: *mut CUmodule, file: *const core::ffi::c_char) -> i32;
}
fn main() {
unsafe {
println!("wtf?");
println!("init returns {}", cuInit(0));
let mut device = 0i32;
println!("DeviceGet {}", cuDeviceGet(&mut device, 0));
let mut ctx = core::ptr::null();
println!("CtxCreate returns {}", cuCtxCreate_v2(&mut ctx, 0, device));
let mut ctx = core::ptr::null();
println!("ModuleLoad returns {}", cuModuleLoad(&mut ctx, c"test.ptx".as_ptr()));
}
}
猜猜为什么
用到的ptx如下,是完全没有问题的,根正苗红的NVPTX:
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-35583870
// Cuda compilation tools, release 12.8, V12.8.93
// Based on NVVM 20.0.0
//
// Module overview: one kernel entry point, _Z6stampaf, which formats its
// thread index and its single float argument through a call to vprintf.
.version 8.7
.target sm_120
.address_size 64
// .globl _Z6stampaf
// vprintf is only declared here (.extern), not defined in this module.
// It takes two 64-bit parameters and returns a 32-bit value.
.extern .func (.param .b32 func_retval0) vprintf
(
.param .b64 vprintf_param_0,
.param .b64 vprintf_param_1
)
;
// The 22 initializer bytes are the ASCII text "Hello thread %d, f=%f\n".
// The array has 23 elements; the last one is not explicitly initialized
// (presumably the terminating NUL of the format string).
.global .align 1 .b8 $str[23] = {72, 101, 108, 108, 111, 32, 116, 104, 114, 101, 97, 100, 32, 37, 100, 44, 32, 102, 61, 37, 102, 10};
// Kernel entry point with a single .f32 parameter. The name is an Itanium
// C++ mangled name that demangles to stampa(float).
.visible .entry _Z6stampaf(
.param .f32 _Z6stampaf_param_0
)
{
// 16-byte local buffer that holds the packed vprintf arguments.
.local .align 16 .b8 __local_depot0[16];
.reg .b64 %SP;
.reg .b64 %SPL;
.reg .b32 %r<4>;
.reg .f32 %f<2>;
.reg .b64 %rd<5>;
.reg .f64 %fd<2>;
// %SPL = local-space address of the buffer, %SP = its generic address.
mov.u64 %SPL, __local_depot0;
cvta.local.u64 %SP, %SPL;
ld.param.f32 %f1, [_Z6stampaf_param_0];
// %rd1 = generic address of the buffer (passed to vprintf below),
// %rd2 = local address of the buffer (used for the stores below).
add.u64 %rd1, %SP, 0;
add.u64 %rd2, %SPL, 0;
mov.u32 %r1, %tid.x;
// Widen the float argument to f64 before storing it.
cvt.f64.f32 %fd1, %f1;
// Buffer layout: u32 thread index at offset 0, f64 value at offset 8.
st.local.u32 [%rd2], %r1;
st.local.f64 [%rd2+8], %fd1;
// %rd4 = generic address of the format string $str.
mov.u64 %rd3, $str;
cvta.global.u64 %rd4, %rd3;
// Call vprintf(format string address, argument buffer address).
{ // callseq 0, 0
.param .b64 param0;
st.param.b64 [param0+0], %rd4;
.param .b64 param1;
st.param.b64 [param1+0], %rd1;
.param .b32 retval0;
call.uni (retval0),
vprintf,
(
param0,
param1
);
// The 32-bit return value is loaded into %r2, which is not read again.
ld.param.b32 %r2, [retval0+0];
} // callseq 0
ret;
}
如果大家一定觉得PTX有问题,我可以加一个误导项: 从PTX中删除如下代码后,程序就可以正常执行
// Excerpt of the kernel's vprintf call sequence: passes %rd4 and %rd1 as the
// two 64-bit arguments and loads the 32-bit return value into %r2.
{ // callseq 0, 0
.param .b64 param0;
st.param.b64 [param0+0], %rd4;
.param .b64 param1;
st.param.b64 [param1+0], %rd1;
.param .b32 retval0;
call.uni (retval0),
vprintf,
(
param0,
param1
);
ld.param.b32 %r2, [retval0+0];
} // callseq 0
如果大家有闲情雅致的话可以来猜猜这里出了什么见鬼的问题
1
共 1 条评论, 1 页
评论区
写评论intriguing me...