use notify::{raw_watcher, RawEvent, RecommendedWatcher, RecursiveMode, Watcher}; use parking_lot::{RwLock, RwLockReadGuard}; use slab::Slab; use std::collections::{BTreeSet, HashMap}; use std::fs::read_to_string; use std::path::{Path, PathBuf}; use std::sync::mpsc::{channel, Receiver, TryRecvError}; use std::sync::Arc; #[repr(transparent)] #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] pub struct ShaderHandle(usize); pub struct ShaderStoreReadGuard<'a> { guard: RwLockReadGuard<'a, Slab>, handle: ShaderHandle, } impl<'a> AsRef for ShaderStoreReadGuard<'a> { fn as_ref(&self) -> &wgpu::ShaderModule { self.guard.get(self.handle.0).unwrap() } } #[derive(Debug)] pub enum ShaderError { InvalidHandle, } pub struct ShaderStore { device: Arc, shaders: RwLock>, } impl ShaderStore { pub fn new(device: Arc) -> Self { Self { device, shaders: Default::default(), } } pub fn load(&self, module: &naga::Module) -> Result { let source = self.generate_wgsl(module)?; let shader = self.load_wgsl(source)?; let index = self.shaders.write().insert(shader); Ok(ShaderHandle(index)) } pub fn reload(&self, handle: &ShaderHandle, module: &naga::Module) -> Result<(), ShaderError> { let mut write = self.shaders.write(); match write.get_mut(handle.0) { Some(stored) => { let source = self.generate_wgsl(module)?; let shader = self.load_wgsl(source)?; let _old = std::mem::replace(stored, shader); Ok(()) } None => Err(ShaderError::InvalidHandle), } } fn generate_wgsl(&self, module: &naga::Module) -> Result { // TODO handle all the errors that can happen here use naga::back::wgsl::{Writer, WriterFlags}; use naga::valid::{Capabilities, ValidationFlags, Validator}; let validation_flags = ValidationFlags::all(); let capabilities = Capabilities::empty(); let mut validator = Validator::new(validation_flags, capabilities); let module_info = validator.validate(&module).unwrap(); let wgsl_flags = WriterFlags::empty(); let mut wgsl_buffer = String::new(); let mut wgsl_writer = Writer::new(&mut wgsl_buffer, wgsl_flags); wgsl_writer .write(&module, &module_info) .expect("wgsl write failed"); Ok(wgsl_buffer) } fn load_wgsl(&self, wgsl_source: String) -> Result { let shader = self .device .create_shader_module(wgpu::ShaderModuleDescriptor { label: None, source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Owned(wgsl_source)), }); Ok(shader) } pub fn get(&self, handle: &ShaderHandle) -> Option { let guard = self.shaders.read(); if let Some(_) = guard.get(handle.0) { Some(ShaderStoreReadGuard { guard, handle: *handle, }) } else { None } } } pub struct ShaderInfo { pub handle: ShaderHandle, pub dependencies: BTreeSet, } pub struct ShaderLoader<'a> { store: &'a ShaderStore, source: String, root_path: PathBuf, old_handle: Option, include_stack: Vec, included: BTreeSet, } impl<'a> ShaderLoader<'a> { pub fn new( store: &'a ShaderStore, root_path: impl AsRef, old_handle: Option, ) -> Self { Self { store, source: String::new(), root_path: root_path.as_ref().to_path_buf(), old_handle, include_stack: Default::default(), included: Default::default(), } } pub fn include_file(&mut self, path: &PathBuf) { let path = self.root_path.join(path); if self.include_stack.contains(&path) { panic!("Circular include of {:?}", path); } self.included.insert(path.clone()); self.include_stack.push(path.clone()); let contents = read_to_string(path).unwrap(); for line in contents.lines() { let words: Vec<&str> = line.split_whitespace().filter(|w| w.len() > 0).collect(); if words.get(0) == Some(&"#include") { let include_path = words[1]; self.include_file(&include_path.into()); } else { self.source.push_str(line); self.source.push_str("\r\n"); } } self.include_stack.pop(); } pub fn finish(self) -> ShaderInfo { let mut parser = naga::front::wgsl::Parser::new(); let module = match parser.parse(&self.source) { Ok(module) => module, // TODO handle parsing errors Err(error) => panic!( "wgsl parsing error:\n{}", error.emit_to_string(&self.source) ), }; let handle = if let Some(handle) = self.old_handle { self.store.reload(&handle, &module).unwrap(); handle } else { self.store.load(&module).unwrap() }; ShaderInfo { handle, dependencies: self.included, } } } pub struct ShaderWatcher { store: Arc, root_path: PathBuf, _watcher: RecommendedWatcher, notify_rx: Receiver, file_infos: RwLock>, } impl ShaderWatcher { pub fn new( store: Arc, root_path: impl AsRef, ) -> Result { let root_path = root_path.as_ref().to_path_buf(); let (notify_tx, notify_rx) = channel(); let mut watcher = raw_watcher(notify_tx)?; watcher.watch(root_path.clone(), RecursiveMode::Recursive)?; Ok(Self { store, root_path, _watcher: watcher, notify_rx, file_infos: Default::default(), }) } pub fn watch(&self) { loop { match self.notify_rx.try_recv() { Ok(event) => match event { RawEvent { path: Some(path), op: Ok(op), cookie: _, // TODO use cookie to disambiguate updates } => { if op.contains(notify::Op::CREATE) { self.on_changed(&path); } } other => panic!("unexpected shader loader watcher event: {:#?}", other), }, Err(TryRecvError::Empty) => break, Err(TryRecvError::Disconnected) => panic!("shader loader watcher disconnected"), } } } pub fn on_changed(&self, path: &Path) { let infos_read = self.file_infos.read(); let path_buf = path.to_path_buf(); match infos_read.get(&path_buf) { Some(handle) => { let mut loader = ShaderLoader::new(&self.store, &self.root_path, Some(*handle)); loader.include_file(&path_buf); let _info = loader.finish(); // TODO update dependencies } _ => {} } } pub fn add_file(&self, shader_path: impl AsRef) -> Result { let shader_path_buf = shader_path.as_ref().to_path_buf(); let mut loader = ShaderLoader::new(&self.store, &self.root_path, None); loader.include_file(&shader_path_buf); let info = loader.finish(); let mut infos_write = self.file_infos.write(); infos_write.insert(self.root_path.join(shader_path_buf), info.handle); Ok(info.handle) } }