From bc8426e37a979539648a8a879562d621b9e63696 Mon Sep 17 00:00:00 2001 From: mars Date: Wed, 27 Apr 2022 21:14:58 -0600 Subject: [PATCH] Add ShaderLoader and materials/ --- Cargo.toml | 5 ++ materials/example.wgsl | 1 + src/lib.rs | 1 + src/main.rs | 9 +++ src/shader.rs | 173 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 189 insertions(+) create mode 100644 materials/example.wgsl create mode 100644 src/shader.rs diff --git a/Cargo.toml b/Cargo.toml index 4ff6641..6bcaab8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" bytemuck = { version = "^1.0", features = ["derive"] } glam = "0.20" multimap = "0.8" +notify = "^4" pollster = "0.2" rayon = "1" slab = "^0.4" @@ -15,3 +16,7 @@ smallvec = "^1.0" strum = { version = "0.24", features = ["derive"] } wgpu = "^0.12" winit = "0.26" + +[dependencies.naga] +version = "0.8.5" +features = ["wgsl-in", "glsl-in", "wgsl-out", "serialize", "deserialize"] diff --git a/materials/example.wgsl b/materials/example.wgsl new file mode 100644 index 0000000..0559dfa --- /dev/null +++ b/materials/example.wgsl @@ -0,0 +1 @@ +// example shader for when the time is right diff --git a/src/lib.rs b/src/lib.rs index 47ca92d..b106a5d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,7 @@ pub mod camera; pub mod mesh; pub mod pass; pub mod phase; +pub mod shader; pub mod staging; pub mod viewport; diff --git a/src/main.rs b/src/main.rs index f17192b..f55b107 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +use cyborg::shader; use cyborg::{pass, viewport::*, Renderer}; use std::sync::Arc; use winit::{ @@ -18,6 +19,13 @@ fn main() { let device = renderer.get_device(); let layouts = renderer.get_layouts(); + let shader_store = Arc::new(shader::ShaderStore::new(device.to_owned())); + + let shaders_dir = std::env::current_dir().unwrap(); + let shaders_dir = shaders_dir.join("materials/"); + let shader_loader = shader::ShaderLoader::new(shader_store.to_owned(), shaders_dir).unwrap(); + shader_loader.add_file("example.wgsl").unwrap(); + let mesh_pass = pass::mesh::MeshPass::new(device.to_owned(), layouts.to_owned(), viewport.get_info()); let debug_pass = @@ -40,6 +48,7 @@ fn main() { } } Event::MainEventsCleared => { + shader_loader.watch(); camera.update(); window.request_redraw(); } diff --git a/src/shader.rs b/src/shader.rs new file mode 100644 index 0000000..a39ac7f --- /dev/null +++ b/src/shader.rs @@ -0,0 +1,173 @@ +use notify::{raw_watcher, RawEvent, RecommendedWatcher, RecursiveMode, Watcher}; +use slab::Slab; +use std::collections::HashMap; +use std::fs::read_to_string; +use std::path::{Path, PathBuf}; +use std::sync::mpsc::{channel, Receiver, TryRecvError}; +use std::sync::{Arc, RwLock}; + +#[repr(transparent)] +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] +pub struct ShaderHandle(usize); + +#[derive(Debug)] +pub enum ShaderError { + InvalidHandle, +} + +pub struct ShaderStore { + device: Arc, + shaders: RwLock>, +} + +impl ShaderStore { + pub fn new(device: Arc) -> Self { + Self { + device, + shaders: Default::default(), + } + } + + pub fn load(&self, module: &naga::Module) -> Result { + let source = self.generate_wgsl(module)?; + let shader = self.load_wgsl(source)?; + let index = self.shaders.write().unwrap().insert(shader); + Ok(ShaderHandle(index)) + } + + pub fn reload(&self, handle: &ShaderHandle, module: &naga::Module) -> Result<(), ShaderError> { + let mut write = self.shaders.write().unwrap(); + match write.get_mut(handle.0) { + Some(stored) => { + let source = self.generate_wgsl(module)?; + let shader = self.load_wgsl(source)?; + let _old = std::mem::replace(stored, shader); + Ok(()) + } + None => Err(ShaderError::InvalidHandle), + } + } + + fn generate_wgsl(&self, module: &naga::Module) -> Result { + // TODO handle all the errors that can happen here + use naga::back::wgsl::{Writer, WriterFlags}; + use naga::valid::{Capabilities, ValidationFlags, Validator}; + + let validation_flags = ValidationFlags::all(); + let capabilities = Capabilities::empty(); + let mut validator = Validator::new(validation_flags, capabilities); + let module_info = validator.validate(&module).unwrap(); + + let wgsl_flags = WriterFlags::empty(); + let mut wgsl_buffer = String::new(); + let mut wgsl_writer = Writer::new(&mut wgsl_buffer, wgsl_flags); + wgsl_writer + .write(&module, &module_info) + .expect("wgsl write failed"); + + Ok(wgsl_buffer) + } + + fn load_wgsl(&self, wgsl_source: String) -> Result { + let shader = self + .device + .create_shader_module(&wgpu::ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Owned(wgsl_source)), + }); + Ok(shader) + } +} + +pub struct ShaderLoader { + store: Arc, + root_path: PathBuf, + _watcher: RecommendedWatcher, + notify_rx: Receiver, + file_handles: RwLock>, +} + +impl ShaderLoader { + pub fn new( + store: Arc, + root_path: impl AsRef, + ) -> Result { + let root_path = root_path.as_ref().to_path_buf(); + + let (notify_tx, notify_rx) = channel(); + let mut watcher = raw_watcher(notify_tx)?; + watcher.watch(root_path.clone(), RecursiveMode::Recursive)?; + + Ok(Self { + store, + root_path, + _watcher: watcher, + notify_rx, + file_handles: Default::default(), + }) + } + + pub fn watch(&self) { + loop { + match self.notify_rx.try_recv() { + Ok(event) => match event { + RawEvent { + path: Some(path), + op: Ok(op), + cookie: _, // TODO use cookie to disambiguate updates + } => { + let handles_read = self.file_handles.read().unwrap(); + let path_buf = path.to_path_buf(); + if let Some(handle) = handles_read.get(&path_buf) { + if op.contains(notify::Op::CREATE) { + let _result = self.reload(handle, path_buf); + } + } + } + other => panic!("unexpected shader loader watcher event: {:#?}", other), + }, + Err(TryRecvError::Empty) => break, + Err(TryRecvError::Disconnected) => panic!("shader loader watcher disconnected"), + } + } + } + + pub fn add_file(&self, shader_path: impl AsRef) -> Result { + let path_buf = self.root_path.join(shader_path); + let module = self.load(&path_buf)?; + let handle = self.store.load(&module)?; + self.file_handles.write().unwrap().insert(path_buf, handle); + Ok(handle) + } + + fn load(&self, shader_path: impl AsRef) -> Result { + // TODO handle IO errors + let source = read_to_string(shader_path).unwrap(); + let mut parser = naga::front::wgsl::Parser::new(); + match parser.parse(&source) { + Ok(module) => Ok(module), + // TODO handle parsing errors + Err(error) => panic!("wgsl parsing error:\n{}", error), + } + } + + fn reload( + &self, + handle: &ShaderHandle, + shader_path: impl AsRef, + ) -> Result<(), ShaderError> { + match self.load(shader_path) { + Ok(shader) => match self.store.reload(handle, &shader) { + Ok(()) => return Ok(()), + Err(e) => { + eprintln!("Shader reload error: {:?}", e); + return Err(e); + } + }, + Err(e) => { + eprintln!("Shader reload error: {:?}", e); + return Err(e); + } + } + } +}