cyborg/src/shader.rs

259 lines
7.9 KiB
Rust

use notify::{raw_watcher, RawEvent, RecommendedWatcher, RecursiveMode, Watcher};
use parking_lot::{RwLock, RwLockReadGuard};
use slab::Slab;
use std::collections::{BTreeSet, HashMap};
use std::fs::read_to_string;
use std::path::{Path, PathBuf};
use std::sync::mpsc::{channel, Receiver, TryRecvError};
use std::sync::Arc;
/// Opaque, cheaply copyable key identifying a shader module held by a [`ShaderStore`].
///
/// Wraps the slab slot index that the store handed out when the module was inserted.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
#[repr(transparent)]
pub struct ShaderHandle(usize);
/// Shared access to a single shader module inside a [`ShaderStore`].
///
/// The guard keeps the store's read lock for its whole lifetime, so writers
/// (such as a reload) cannot replace the module while it is being used.
pub struct ShaderStoreReadGuard<'store> {
    /// Which slot of the locked slab this guard refers to.
    handle: ShaderHandle,
    /// Read lock on the store's shader slab.
    guard: RwLockReadGuard<'store, Slab<wgpu::ShaderModule>>,
}
impl<'a> AsRef<wgpu::ShaderModule> for ShaderStoreReadGuard<'a> {
    /// Returns the shader module this guard was created for.
    ///
    /// # Panics
    /// Only if the handle is missing from the locked slab. `ShaderStore::get`
    /// checks the handle before building the guard, and the read lock held by
    /// the guard keeps writers out afterwards, so this indicates a bug.
    fn as_ref(&self) -> &wgpu::ShaderModule {
        self.guard
            .get(self.handle.0)
            .expect("ShaderStoreReadGuard handle must stay valid while its read lock is held")
    }
}
/// Errors reported by [`ShaderStore`] operations.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ShaderError {
    /// The handle does not refer to a shader currently held by the store.
    InvalidHandle,
}

impl std::fmt::Display for ShaderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ShaderError::InvalidHandle => f.write_str("invalid shader handle"),
        }
    }
}

// Lets `ShaderError` flow through `?` into `Box<dyn Error>` / `anyhow`-style callers.
impl std::error::Error for ShaderError {}
/// Owns compiled shader modules together with the device used to create them.
///
/// Modules live in a slab behind a read/write lock; a [`ShaderHandle`] wraps the
/// slab index of a stored module.
pub struct ShaderStore {
// Device on which `load_wgsl` creates shader modules.
device: Arc<wgpu::Device>,
// Compiled modules, indexed by `ShaderHandle.0`.
shaders: RwLock<Slab<wgpu::ShaderModule>>,
}
impl ShaderStore {
    /// Creates an empty store whose shader modules will be created on `device`.
    pub fn new(device: Arc<wgpu::Device>) -> Self {
        Self {
            device,
            shaders: Default::default(),
        }
    }

    /// Validates `module`, compiles it into a shader module and stores it.
    ///
    /// Returns the handle under which the new module can be retrieved with [`ShaderStore::get`].
    ///
    /// # Panics
    /// If `module` fails naga validation or cannot be written out as WGSL.
    pub fn load(&self, module: &naga::Module) -> Result<ShaderHandle, ShaderError> {
        let source = self.generate_wgsl(module)?;
        let shader = self.load_wgsl(source)?;
        let index = self.shaders.write().insert(shader);
        Ok(ShaderHandle(index))
    }

    /// Replaces the shader stored under `handle` with one compiled from `module`.
    ///
    /// # Errors
    /// [`ShaderError::InvalidHandle`] if `handle` does not refer to a stored shader.
    ///
    /// # Panics
    /// Under the same conditions as [`ShaderStore::load`].
    pub fn reload(&self, handle: &ShaderHandle, module: &naga::Module) -> Result<(), ShaderError> {
        // Reject stale handles before doing any compilation work. The read guard is a
        // temporary of the `if` condition, so it is released before the body runs.
        if !self.shaders.read().contains(handle.0) {
            return Err(ShaderError::InvalidHandle);
        }
        // Compile without holding any lock so readers (`get`) are not blocked
        // behind validation and device shader-module creation.
        let source = self.generate_wgsl(module)?;
        let shader = self.load_wgsl(source)?;
        let mut shaders = self.shaders.write();
        match shaders.get_mut(handle.0) {
            Some(stored) => {
                *stored = shader;
                Ok(())
            }
            None => Err(ShaderError::InvalidHandle),
        }
    }

    /// Validates `module` and renders it back out as WGSL source text.
    ///
    /// # Panics
    /// If validation fails or the WGSL writer reports an error.
    fn generate_wgsl(&self, module: &naga::Module) -> Result<String, ShaderError> {
        // TODO surface validation/writer failures as `ShaderError` variants instead of panicking.
        use naga::back::wgsl::{Writer, WriterFlags};
        use naga::valid::{Capabilities, ValidationFlags, Validator};
        let mut validator = Validator::new(ValidationFlags::all(), Capabilities::empty());
        let module_info = validator
            .validate(module)
            .unwrap_or_else(|error| panic!("naga validation failed: {:?}", error));
        let mut wgsl_buffer = String::new();
        let mut wgsl_writer = Writer::new(&mut wgsl_buffer, WriterFlags::empty());
        wgsl_writer
            .write(module, &module_info)
            .unwrap_or_else(|error| panic!("wgsl write failed: {:?}", error));
        Ok(wgsl_buffer)
    }

    /// Creates an unlabeled shader module on the store's device from WGSL source.
    fn load_wgsl(&self, wgsl_source: String) -> Result<wgpu::ShaderModule, ShaderError> {
        let shader = self
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: None,
                source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Owned(wgsl_source)),
            });
        Ok(shader)
    }

    /// Returns a read guard giving access to the module stored under `handle`,
    /// or `None` if the handle is not present.
    ///
    /// The guard holds the store's read lock; calling [`ShaderStore::reload`] or
    /// [`ShaderStore::load`] while holding it on the same thread will block.
    pub fn get(&self, handle: &ShaderHandle) -> Option<ShaderStoreReadGuard> {
        let guard = self.shaders.read();
        if guard.contains(handle.0) {
            Some(ShaderStoreReadGuard {
                guard,
                handle: *handle,
            })
        } else {
            None
        }
    }
}
pub struct ShaderInfo {
pub handle: ShaderHandle,
pub dependencies: BTreeSet<PathBuf>,
}
/// Assembles WGSL source text from a root file plus its `#include`d files,
/// then loads the result into a [`ShaderStore`].
pub struct ShaderLoader<'store> {
    /// Store the finished shader is loaded into.
    store: &'store ShaderStore,
    /// When set, the finished shader replaces this one instead of being added.
    old_handle: Option<ShaderHandle>,
    /// Directory that file and include paths are joined onto.
    root_path: PathBuf,
    /// Source text accumulated so far.
    source: String,
    /// Files currently being expanded, outermost first; used to detect cycles.
    include_stack: Vec<PathBuf>,
    /// Every file read so far.
    included: BTreeSet<PathBuf>,
}
impl<'a> ShaderLoader<'a> {
    /// Creates a loader that resolves paths against `root_path`.
    ///
    /// If `old_handle` is `Some`, [`ShaderLoader::finish`] reloads that shader in
    /// place; otherwise it adds a new shader to `store`.
    pub fn new(
        store: &'a ShaderStore,
        root_path: impl AsRef<Path>,
        old_handle: Option<ShaderHandle>,
    ) -> Self {
        Self {
            store,
            source: String::new(),
            root_path: root_path.as_ref().to_path_buf(),
            old_handle,
            include_stack: Default::default(),
            included: Default::default(),
        }
    }

    /// Appends the file at `path` (joined onto the root path) to the source,
    /// recursively expanding `#include <path>` lines in place.
    ///
    /// Include targets are also joined onto the root path and may optionally be
    /// wrapped in double quotes (`#include "common.wgsl"`). Every other line is
    /// copied verbatim and terminated with `\r\n`.
    ///
    /// # Panics
    /// On a circular include, an unreadable file, or an `#include` with no path.
    pub fn include_file(&mut self, path: &PathBuf) {
        let path = self.root_path.join(path);
        if self.include_stack.contains(&path) {
            panic!("Circular include of {:?}", path);
        }
        self.included.insert(path.clone());
        let contents = read_to_string(&path)
            .unwrap_or_else(|error| panic!("failed to read shader source {:?}: {}", path, error));
        self.include_stack.push(path);
        for line in contents.lines() {
            match Self::parse_include(line) {
                Some(include_path) => self.include_file(&PathBuf::from(include_path)),
                None => {
                    self.source.push_str(line);
                    self.source.push_str("\r\n");
                }
            }
        }
        self.include_stack.pop();
    }

    /// Returns the target of an `#include` directive, or `None` if `line` is not one.
    ///
    /// The directive must be the first whitespace-separated word; the target is the
    /// second word, with one surrounding pair of double quotes removed if present.
    ///
    /// # Panics
    /// If the directive has no target.
    fn parse_include(line: &str) -> Option<&str> {
        let mut words = line.split_whitespace();
        if words.next() != Some("#include") {
            return None;
        }
        let target = words
            .next()
            .unwrap_or_else(|| panic!("#include without a path: {:?}", line));
        let unquoted = target
            .strip_prefix('"')
            .and_then(|rest| rest.strip_suffix('"'))
            .unwrap_or(target);
        Some(unquoted)
    }

    /// Parses the accumulated WGSL and loads it into the store, reloading the
    /// old shader if one was given to [`ShaderLoader::new`].
    ///
    /// # Panics
    /// If the source fails to parse, or if the store rejects the shader
    /// (e.g. the old handle is no longer valid).
    pub fn finish(self) -> ShaderInfo {
        let mut parser = naga::front::wgsl::Parser::new();
        let module = match parser.parse(&self.source) {
            Ok(module) => module,
            // TODO handle parsing errors
            Err(error) => panic!(
                "wgsl parsing error:\n{}",
                error.emit_to_string(&self.source)
            ),
        };
        let handle = match self.old_handle {
            Some(handle) => {
                self.store
                    .reload(&handle, &module)
                    .expect("failed to reload shader into the store");
                handle
            }
            None => self
                .store
                .load(&module)
                .expect("failed to load shader into the store"),
        };
        ShaderInfo {
            handle,
            dependencies: self.included,
        }
    }
}
/// Watches `root_path` recursively and reloads registered shaders into `store`
/// when filesystem events for their files arrive (see `watch`).
pub struct ShaderWatcher {
// Store that registered shaders are loaded into and reloaded in.
store: Arc<ShaderStore>,
// Directory being watched; shader paths are joined onto it.
root_path: PathBuf,
// Stored but never read; presumably kept so the watch stays active — confirm with notify docs.
_watcher: RecommendedWatcher,
// Receives raw events from `_watcher`; drained by `watch`.
notify_rx: Receiver<RawEvent>,
// Maps each registered shader's path (joined onto `root_path`) to its store handle.
file_infos: RwLock<HashMap<PathBuf, ShaderHandle>>,
}
impl ShaderWatcher {
    /// Starts a recursive watch on `root_path`. Shaders registered with
    /// [`ShaderWatcher::add_file`] are resolved relative to it.
    ///
    /// # Errors
    /// Any error from creating the watcher or registering the watch.
    pub fn new(
        store: Arc<ShaderStore>,
        root_path: impl AsRef<Path>,
    ) -> Result<Self, notify::Error> {
        let root_path = root_path.as_ref().to_path_buf();
        let (notify_tx, notify_rx) = channel();
        let mut watcher = raw_watcher(notify_tx)?;
        watcher.watch(root_path.clone(), RecursiveMode::Recursive)?;
        Ok(Self {
            store,
            root_path,
            _watcher: watcher,
            notify_rx,
            file_infos: Default::default(),
        })
    }

    /// Drains all pending filesystem events without blocking. Each event whose
    /// operation includes `CREATE` is passed to [`ShaderWatcher::on_changed`].
    ///
    /// # Panics
    /// If the watcher reports an error, if its channel disconnects, or if a
    /// reload panics.
    pub fn watch(&self) {
        loop {
            match self.notify_rx.try_recv() {
                Ok(RawEvent {
                    path: Some(path),
                    op: Ok(op),
                    cookie: _, // TODO use cookie to disambiguate updates
                }) => {
                    if op.contains(notify::Op::CREATE) {
                        self.on_changed(&path);
                    }
                }
                // Path-less events (e.g. RESCAN when a backend's event queue overflows)
                // name no file we could reload; ignore them rather than panicking.
                Ok(RawEvent {
                    path: None,
                    op: Ok(_),
                    cookie: _,
                }) => {}
                Ok(other) => panic!("unexpected shader loader watcher event: {:#?}", other),
                Err(TryRecvError::Empty) => break,
                Err(TryRecvError::Disconnected) => panic!("shader loader watcher disconnected"),
            }
        }
    }

    /// Reloads the shader registered under `path`, if there is one.
    ///
    /// `path` must match the registration key, which is the root path joined
    /// with the path given to [`ShaderWatcher::add_file`]. Unregistered paths
    /// are ignored.
    pub fn on_changed(&self, path: &Path) {
        // Copy the handle out so the map lock is not held while recompiling.
        let handle = self.file_infos.read().get(path).copied();
        if let Some(handle) = handle {
            // `ShaderLoader::include_file` joins its argument onto `root_path`, so pass
            // the path relative to it. Re-joining the full key would double the root
            // when `root_path` is relative.
            let relative = path
                .strip_prefix(&self.root_path)
                .unwrap_or(path)
                .to_path_buf();
            let mut loader = ShaderLoader::new(&self.store, &self.root_path, Some(handle));
            loader.include_file(&relative);
            let _info = loader.finish(); // TODO update dependencies
        }
    }

    /// Loads the shader at `shader_path` (relative to the root path) and
    /// registers it for reloading.
    ///
    /// # Panics
    /// If loading the shader panics (see [`ShaderLoader::include_file`] and
    /// [`ShaderLoader::finish`]).
    pub fn add_file(&self, shader_path: impl AsRef<Path>) -> Result<ShaderHandle, ShaderError> {
        let shader_path = shader_path.as_ref().to_path_buf();
        let mut loader = ShaderLoader::new(&self.store, &self.root_path, None);
        loader.include_file(&shader_path);
        let info = loader.finish();
        self.file_infos
            .write()
            .insert(self.root_path.join(&shader_path), info.handle);
        Ok(info.handle)
    }
}