diff --git a/materials/example.wgsl b/shaders/example.wgsl similarity index 100% rename from materials/example.wgsl rename to shaders/example.wgsl diff --git a/src/pass/mesh_shader.wgsl b/shaders/mesh_forward.wgsl similarity index 100% rename from src/pass/mesh_shader.wgsl rename to shaders/mesh_forward.wgsl diff --git a/src/main.rs b/src/main.rs index f55b107..124ca1f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,12 +22,22 @@ fn main() { let shader_store = Arc::new(shader::ShaderStore::new(device.to_owned())); let shaders_dir = std::env::current_dir().unwrap(); - let shaders_dir = shaders_dir.join("materials/"); - let shader_loader = shader::ShaderLoader::new(shader_store.to_owned(), shaders_dir).unwrap(); - shader_loader.add_file("example.wgsl").unwrap(); + let shaders_dir = shaders_dir.join("shaders/"); + let shader_watcher = shader::ShaderWatcher::new(shader_store.to_owned(), shaders_dir).unwrap(); - let mesh_pass = - pass::mesh::MeshPass::new(device.to_owned(), layouts.to_owned(), viewport.get_info()); + let mesh_forward = shader_watcher.add_file("mesh_forward.wgsl").unwrap(); + + let mesh_shaders = pass::mesh::ShaderInfo { + store: shader_store.clone(), + forward: mesh_forward, + }; + + let mesh_pass = pass::mesh::MeshPass::new( + device.to_owned(), + layouts.to_owned(), + viewport.get_info(), + mesh_shaders, + ); let debug_pass = pass::debug::DebugPass::new(device.to_owned(), layouts.to_owned(), viewport.get_info()); @@ -48,7 +58,7 @@ fn main() { } } Event::MainEventsCleared => { - shader_loader.watch(); + shader_watcher.watch(); camera.update(); window.request_redraw(); } diff --git a/src/pass/debug.rs b/src/pass/debug.rs index 8605319..ad226c4 100644 --- a/src/pass/debug.rs +++ b/src/pass/debug.rs @@ -44,7 +44,8 @@ impl DebugPass { layouts: Arc, target_info: ViewportInfo, ) -> Self { - let shader = device.create_shader_module(&wgpu::include_wgsl!("mesh_shader.wgsl")); + // TODO hook into ShaderStore system + let shader = device.create_shader_module(&wgpu::include_wgsl!("debug_shader.wgsl")); let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { label: Some("DebugPass Pipeline Layout"), diff --git a/src/pass/debug_shader.wgsl b/src/pass/debug_shader.wgsl new file mode 100644 index 0000000..ffba437 --- /dev/null +++ b/src/pass/debug_shader.wgsl @@ -0,0 +1,40 @@ +struct CameraUniform { + eye: vec4; + vp: mat4x4; +}; + +struct VertexInput { + [[location(0)]] position: vec3; + [[location(1)]] color: vec3; +}; + +struct VertexOutput { + [[builtin(position)]] clip_position: vec4; + [[location(0)]] position: vec3; + [[location(1)]] color: vec3; +}; + +[[group(0), binding(0)]] +var camera: CameraUniform; + +[[stage(vertex)]] +fn vs_main( + [[builtin(instance_index)]] mesh_idx: u32, + [[builtin(vertex_index)]] vertex_idx: u32, + vertex: VertexInput, +) -> VertexOutput { + let world_pos = vertex.position; + + var out: VertexOutput; + out.clip_position = camera.vp * vec4(world_pos, 1.0); + out.position = world_pos; + out.color = vertex.color; + return out; +} + +[[stage(fragment)]] +fn fs_main( + frag: VertexOutput, +) -> [[location(0)]] vec4 { + return vec4(frag.color, 1.0); +} diff --git a/src/pass/mesh.rs b/src/pass/mesh.rs index c3b53af..d011630 100644 --- a/src/pass/mesh.rs +++ b/src/pass/mesh.rs @@ -2,6 +2,12 @@ use super::*; use crate::mesh::*; use crate::viewport::ViewportInfo; use crate::RenderLayouts; +use crate::shader::{ShaderStore, ShaderHandle}; + +pub struct ShaderInfo { + pub store: Arc, + pub forward: ShaderHandle, +} #[repr(C)] #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)] @@ -37,6 +43,7 @@ pub struct MeshPass { device: Arc, layouts: Arc, attr_store: Arc, + shader_info: ShaderInfo, mesh_pool: Arc, vertex_attr_id: AttrId, index_attr_id: AttrId, @@ -52,6 +59,7 @@ impl MeshPass { device: Arc, layouts: Arc, target_info: ViewportInfo, + shader_info: ShaderInfo, ) -> Self { let attr_store = AttrStore::new(); let mesh_pool = MeshPool::new(device.clone(), attr_store.to_owned()); @@ -110,7 +118,7 @@ impl MeshPass { push_constant_ranges: &[], }); - let shader = device.create_shader_module(&wgpu::include_wgsl!("mesh_shader.wgsl")); + let shader = shader_info.store.get(&shader_info.forward).unwrap(); let targets = &[wgpu::ColorTargetState { format: target_info.output_format, @@ -122,12 +130,12 @@ impl MeshPass { label: Some("Opaque MeshPass Pipeline"), layout: Some(&render_pipeline_layout), vertex: wgpu::VertexState { - module: &shader, + module: shader.as_ref(), entry_point: "vs_main", buffers: &[Vertex::desc()], }, fragment: Some(wgpu::FragmentState { - module: &shader, + module: shader.as_ref(), entry_point: "fs_main", targets, }), @@ -167,10 +175,13 @@ impl MeshPass { let opaque_pipeline = device.create_render_pipeline(&pipeline_desc); + drop(shader); + Self { device, layouts, attr_store, + shader_info, mesh_pool, index_attr_id, vertex_attr_id, diff --git a/src/shader.rs b/src/shader.rs index 7ef7ef1..283289d 100644 --- a/src/shader.rs +++ b/src/shader.rs @@ -1,15 +1,26 @@ use notify::{raw_watcher, RawEvent, RecommendedWatcher, RecursiveMode, Watcher}; use slab::Slab; -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap}; use std::fs::read_to_string; use std::path::{Path, PathBuf}; use std::sync::mpsc::{channel, Receiver, TryRecvError}; -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, RwLock, RwLockReadGuard}; #[repr(transparent)] #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] pub struct ShaderHandle(usize); +pub struct ShaderStoreReadGuard<'a> { + guard: RwLockReadGuard<'a, Slab>, + handle: ShaderHandle, +} + +impl<'a> AsRef for ShaderStoreReadGuard<'a> { + fn as_ref(&self) -> &wgpu::ShaderModule { + self.guard.get(self.handle.0).unwrap() + } +} + #[derive(Debug)] pub enum ShaderError { InvalidHandle, @@ -77,17 +88,110 @@ impl ShaderStore { }); Ok(shader) } + + pub fn get(&self, handle: &ShaderHandle) -> Option { + let guard = self.shaders.read().unwrap(); + if let Some(_) = guard.get(handle.0) { + Some(ShaderStoreReadGuard { + guard, + handle: *handle, + }) + } else { + None + } + } } -pub struct ShaderLoader { +pub struct ShaderInfo { + handle: ShaderHandle, + dependencies: BTreeSet, +} + +pub struct ShaderLoader<'a> { + store: &'a ShaderStore, + source: String, + root_path: PathBuf, + old_handle: Option, + include_stack: Vec, + included: BTreeSet, +} + +impl<'a> ShaderLoader<'a> { + pub fn new( + store: &'a ShaderStore, + root_path: impl AsRef, + old_handle: Option, + ) -> Self { + Self { + store, + source: String::new(), + root_path: root_path.as_ref().to_path_buf(), + old_handle, + include_stack: Default::default(), + included: Default::default(), + } + } + + pub fn include_file(&mut self, path: &PathBuf) { + let path = self.root_path.join(path); + + if self.include_stack.contains(&path) { + panic!("Circular include of {:?}", path); + } + + self.included.insert(path.clone()); + self.include_stack.push(path.clone()); + + let contents = read_to_string(path).unwrap(); + for line in contents.lines() { + let words: Vec<&str> = line.split_whitespace().filter(|w| w.len() > 0).collect(); + + if words.get(0) == Some(&"#include") { + let include_path = words[1]; + self.include_file(&include_path.into()); + } else { + self.source.push_str(line); + self.source.push_str("\r\n"); + } + } + + self.include_stack.pop(); + } + + pub fn finish(self) -> ShaderInfo { + let mut parser = naga::front::wgsl::Parser::new(); + let module = match parser.parse(&self.source) { + Ok(module) => module, + // TODO handle parsing errors + Err(error) => panic!( + "wgsl parsing error:\n{}", + error.emit_to_string(&self.source) + ), + }; + + let handle = if let Some(handle) = self.old_handle { + self.store.reload(&handle, &module).unwrap(); + handle + } else { + self.store.load(&module).unwrap() + }; + + ShaderInfo { + handle, + dependencies: self.included, + } + } +} + +pub struct ShaderWatcher { store: Arc, root_path: PathBuf, _watcher: RecommendedWatcher, notify_rx: Receiver, - file_handles: RwLock>, + file_infos: RwLock>, } -impl ShaderLoader { +impl ShaderWatcher { pub fn new( store: Arc, root_path: impl AsRef, @@ -103,7 +207,7 @@ impl ShaderLoader { root_path, _watcher: watcher, notify_rx, - file_handles: Default::default(), + file_infos: Default::default(), }) } @@ -116,12 +220,8 @@ impl ShaderLoader { op: Ok(op), cookie: _, // TODO use cookie to disambiguate updates } => { - let handles_read = self.file_handles.read().unwrap(); - let path_buf = path.to_path_buf(); - if let Some(handle) = handles_read.get(&path_buf) { - if op.contains(notify::Op::CREATE) { - let _result = self.reload(handle, path_buf); - } + if op.contains(notify::Op::CREATE) { + self.on_changed(&path); } } other => panic!("unexpected shader loader watcher event: {:#?}", other), @@ -132,42 +232,26 @@ impl ShaderLoader { } } - pub fn add_file(&self, shader_path: impl AsRef) -> Result { - let path_buf = self.root_path.join(shader_path); - let module = self.load(&path_buf)?; - let handle = self.store.load(&module)?; - self.file_handles.write().unwrap().insert(path_buf, handle); - Ok(handle) - } - - fn load(&self, shader_path: impl AsRef) -> Result { - // TODO handle IO errors - let source = read_to_string(shader_path).unwrap(); - let mut parser = naga::front::wgsl::Parser::new(); - match parser.parse(&source) { - Ok(module) => Ok(module), - // TODO handle parsing errors - Err(error) => panic!("wgsl parsing error:\n{}", error.emit_to_string(&source)), - } - } - - fn reload( - &self, - handle: &ShaderHandle, - shader_path: impl AsRef, - ) -> Result<(), ShaderError> { - match self.load(shader_path) { - Ok(shader) => match self.store.reload(handle, &shader) { - Ok(()) => return Ok(()), - Err(e) => { - eprintln!("Shader reload error: {:?}", e); - return Err(e); - } - }, - Err(e) => { - eprintln!("Shader reload error: {:?}", e); - return Err(e); + pub fn on_changed(&self, path: &Path) { + let infos_read = self.file_infos.read().unwrap(); + let path_buf = path.to_path_buf(); + match infos_read.get(&path_buf) { + Some(handle) => { + let mut loader = ShaderLoader::new(&self.store, &self.root_path, Some(*handle)); + loader.include_file(&path_buf); + let _info = loader.finish(); // TODO update dependencies } + _ => {} } } + + pub fn add_file(&self, shader_path: impl AsRef) -> Result { + let shader_path_buf = shader_path.as_ref().to_path_buf(); + let mut loader = ShaderLoader::new(&self.store, &self.root_path, None); + loader.include_file(&shader_path_buf); + let info = loader.finish(); + let mut infos_write = self.file_infos.write().unwrap(); + infos_write.insert(self.root_path.join(shader_path_buf), info.handle); + Ok(info.handle) + } }