1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
use crate::{contexted_call, contexted_new, device::*, error::*, *};
use cuda::*;
use std::ffi::*;
/// Handle to a single kernel function (`CUfunction`) looked up from a [`Module`].
///
/// The `'module` lifetime ties the handle to the [`Module`] it was obtained from,
/// so a `Kernel` cannot outlive the loaded module that owns the function.
#[derive(Debug)]
pub struct Kernel<'module> {
    // Raw driver handle filled in by `cuModuleGetFunction` in `Module::get_kernel`.
    pub(crate) func: CUfunction,
    // Owning module; its context backs this type's `Contexted` implementation.
    module: &'module Module,
}
impl Contexted for Kernel<'_> {
fn sync(&self) -> Result<()> {
self.module.context.sync()
}
fn version(&self) -> Result<u32> {
self.module.context.version()
}
fn guard(&self) -> Result<ContextGuard> {
self.module.context.guard()
}
fn get_ref(&self) -> ContextRef {
self.module.get_ref()
}
}
/// A CUDA module loaded into a [`Context`].
///
/// Created by [`Module::load`] or [`Module::from_str`]; the module is unloaded
/// with `cuModuleUnload` when this value is dropped.
#[derive(Debug, Contexted)]
pub struct Module {
    // Raw driver handle produced by `cuModuleLoadData` / `cuModuleLoad`.
    module: CUmodule,
    // Clone of the context the module was loaded into. Presumably the
    // `Contexted` derive delegates to this field — TODO confirm against the derive.
    context: Context,
}
impl Drop for Module {
fn drop(&mut self) {
if let Err(e) = unsafe { contexted_call!(&self.context, cuModuleUnload, self.module) } {
log::error!("Failed to unload module: {:?}", e);
}
}
}
impl Module {
    /// Loads a CUDA module into `context` from an [`Instruction`].
    ///
    /// In-memory PTX and cubin images are passed to `cuModuleLoadData`.
    /// File-backed instructions (`PTXFile` / `CubinFile`) are loaded by path
    /// with `cuModuleLoad`. The returned [`Module`] keeps a clone of `context`.
    ///
    /// # Errors
    ///
    /// Returns an error if the driver call that loads the module fails.
    ///
    /// # Panics
    ///
    /// For file-backed instructions, panics if the path is not valid UTF-8 or
    /// contains an interior NUL byte, because the driver expects a C string.
    pub fn load(context: &Context, data: &Instruction) -> Result<Self> {
        // Each arm only obtains the raw handle. The owning wrapper, which is
        // responsible for unloading on drop, is built once below.
        let module = match data {
            Instruction::PTX(ptx) => unsafe {
                contexted_new!(context, cuModuleLoadData, ptx.as_ptr() as *const _)?
            },
            Instruction::Cubin(bin) => unsafe {
                contexted_new!(context, cuModuleLoadData, bin.as_ptr() as *const _)?
            },
            Instruction::PTXFile(path) | Instruction::CubinFile(path) => {
                let path_str = path.to_str().expect("Module path must be valid UTF-8");
                let filename = CString::new(path_str).expect("Invalid Path");
                // `filename` outlives the call, so the pointer handed to the driver stays valid.
                unsafe { contexted_new!(context, cuModuleLoad, filename.as_ptr())? }
            }
        };
        Ok(Self {
            module,
            context: context.clone(),
        })
    }

    /// Loads a module from PTX source text.
    ///
    /// Convenience wrapper that builds an [`Instruction`] with
    /// [`Instruction::ptx`] and passes it to [`Module::load`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Module::load`].
    pub fn from_str(context: &Context, ptx: &str) -> Result<Self> {
        Self::load(context, &Instruction::ptx(ptx))
    }

    /// Looks up the kernel function called `name` in this module.
    ///
    /// The returned [`Kernel`] borrows `self`, so it cannot outlive the module.
    ///
    /// # Errors
    ///
    /// Returns an error if the `cuModuleGetFunction` driver call fails.
    ///
    /// # Panics
    ///
    /// Panics if `name` contains an interior NUL byte.
    pub fn get_kernel(&self, name: &str) -> Result<Kernel<'_>> {
        let name = CString::new(name).expect("Invalid Kernel name");
        // `self` (a `Contexted` module) supplies the context the lookup runs in.
        let func =
            unsafe { contexted_new!(self, cuModuleGetFunction, self.module, name.as_ptr()) }?;
        Ok(Kernel { func, module: self })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal PTX module exporting a single empty kernel named `do_nothing`.
    const DO_NOTHING_PTX: &str = r#"
.version 3.2
.target sm_30
.address_size 64
.visible .entry do_nothing()
{
ret;
}
"#;

    #[test]
    fn load_do_nothing() -> Result<()> {
        let device = Device::nth(0)?;
        let ctx = device.create_context();
        let _mod = Module::from_str(&ctx, DO_NOTHING_PTX)?;
        Ok(())
    }

    /// `get_kernel` must find the entry point declared in the loaded PTX.
    #[test]
    fn get_do_nothing_kernel() -> Result<()> {
        let device = Device::nth(0)?;
        let ctx = device.create_context();
        let module = Module::from_str(&ctx, DO_NOTHING_PTX)?;
        let _kernel = module.get_kernel("do_nothing")?;
        Ok(())
    }
}