Enhance shaders + include preprocessing

This commit is contained in:
mars 2022-05-05 06:24:30 -06:00
parent a3d808eb1b
commit 451f292605
7 changed files with 203 additions and 57 deletions

View File

@ -22,12 +22,22 @@ fn main() {
let shader_store = Arc::new(shader::ShaderStore::new(device.to_owned()));
let shaders_dir = std::env::current_dir().unwrap();
let shaders_dir = shaders_dir.join("materials/");
let shader_loader = shader::ShaderLoader::new(shader_store.to_owned(), shaders_dir).unwrap();
shader_loader.add_file("example.wgsl").unwrap();
let shaders_dir = shaders_dir.join("shaders/");
let shader_watcher = shader::ShaderWatcher::new(shader_store.to_owned(), shaders_dir).unwrap();
let mesh_pass =
pass::mesh::MeshPass::new(device.to_owned(), layouts.to_owned(), viewport.get_info());
let mesh_forward = shader_watcher.add_file("mesh_forward.wgsl").unwrap();
let mesh_shaders = pass::mesh::ShaderInfo {
store: shader_store.clone(),
forward: mesh_forward,
};
let mesh_pass = pass::mesh::MeshPass::new(
device.to_owned(),
layouts.to_owned(),
viewport.get_info(),
mesh_shaders,
);
let debug_pass =
pass::debug::DebugPass::new(device.to_owned(), layouts.to_owned(), viewport.get_info());
@ -48,7 +58,7 @@ fn main() {
}
}
Event::MainEventsCleared => {
shader_loader.watch();
shader_watcher.watch();
camera.update();
window.request_redraw();
}

View File

@ -44,7 +44,8 @@ impl DebugPass {
layouts: Arc<RenderLayouts>,
target_info: ViewportInfo,
) -> Self {
let shader = device.create_shader_module(&wgpu::include_wgsl!("mesh_shader.wgsl"));
// TODO hook into ShaderStore system
let shader = device.create_shader_module(&wgpu::include_wgsl!("debug_shader.wgsl"));
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("DebugPass Pipeline Layout"),

View File

@ -0,0 +1,40 @@
struct CameraUniform {
eye: vec4<f32>;
vp: mat4x4<f32>;
};
struct VertexInput {
[[location(0)]] position: vec3<f32>;
[[location(1)]] color: vec3<f32>;
};
struct VertexOutput {
[[builtin(position)]] clip_position: vec4<f32>;
[[location(0)]] position: vec3<f32>;
[[location(1)]] color: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> camera: CameraUniform;
[[stage(vertex)]]
fn vs_main(
[[builtin(instance_index)]] mesh_idx: u32,
[[builtin(vertex_index)]] vertex_idx: u32,
vertex: VertexInput,
) -> VertexOutput {
let world_pos = vertex.position;
var out: VertexOutput;
out.clip_position = camera.vp * vec4<f32>(world_pos, 1.0);
out.position = world_pos;
out.color = vertex.color;
return out;
}
[[stage(fragment)]]
fn fs_main(
frag: VertexOutput,
) -> [[location(0)]] vec4<f32> {
return vec4<f32>(frag.color, 1.0);
}

View File

@ -2,6 +2,12 @@ use super::*;
use crate::mesh::*;
use crate::viewport::ViewportInfo;
use crate::RenderLayouts;
use crate::shader::{ShaderStore, ShaderHandle};
pub struct ShaderInfo {
pub store: Arc<ShaderStore>,
pub forward: ShaderHandle,
}
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
@ -37,6 +43,7 @@ pub struct MeshPass {
device: Arc<wgpu::Device>,
layouts: Arc<RenderLayouts>,
attr_store: Arc<AttrStore>,
shader_info: ShaderInfo,
mesh_pool: Arc<MeshPool>,
vertex_attr_id: AttrId,
index_attr_id: AttrId,
@ -52,6 +59,7 @@ impl MeshPass {
device: Arc<wgpu::Device>,
layouts: Arc<RenderLayouts>,
target_info: ViewportInfo,
shader_info: ShaderInfo,
) -> Self {
let attr_store = AttrStore::new();
let mesh_pool = MeshPool::new(device.clone(), attr_store.to_owned());
@ -110,7 +118,7 @@ impl MeshPass {
push_constant_ranges: &[],
});
let shader = device.create_shader_module(&wgpu::include_wgsl!("mesh_shader.wgsl"));
let shader = shader_info.store.get(&shader_info.forward).unwrap();
let targets = &[wgpu::ColorTargetState {
format: target_info.output_format,
@ -122,12 +130,12 @@ impl MeshPass {
label: Some("Opaque MeshPass Pipeline"),
layout: Some(&render_pipeline_layout),
vertex: wgpu::VertexState {
module: &shader,
module: shader.as_ref(),
entry_point: "vs_main",
buffers: &[Vertex::desc()],
},
fragment: Some(wgpu::FragmentState {
module: &shader,
module: shader.as_ref(),
entry_point: "fs_main",
targets,
}),
@ -167,10 +175,13 @@ impl MeshPass {
let opaque_pipeline = device.create_render_pipeline(&pipeline_desc);
drop(shader);
Self {
device,
layouts,
attr_store,
shader_info,
mesh_pool,
index_attr_id,
vertex_attr_id,

View File

@ -1,15 +1,26 @@
use notify::{raw_watcher, RawEvent, RecommendedWatcher, RecursiveMode, Watcher};
use slab::Slab;
use std::collections::HashMap;
use std::collections::{BTreeSet, HashMap};
use std::fs::read_to_string;
use std::path::{Path, PathBuf};
use std::sync::mpsc::{channel, Receiver, TryRecvError};
use std::sync::{Arc, RwLock};
use std::sync::{Arc, RwLock, RwLockReadGuard};
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub struct ShaderHandle(usize);
pub struct ShaderStoreReadGuard<'a> {
guard: RwLockReadGuard<'a, Slab<wgpu::ShaderModule>>,
handle: ShaderHandle,
}
impl<'a> AsRef<wgpu::ShaderModule> for ShaderStoreReadGuard<'a> {
fn as_ref(&self) -> &wgpu::ShaderModule {
self.guard.get(self.handle.0).unwrap()
}
}
#[derive(Debug)]
pub enum ShaderError {
InvalidHandle,
@ -77,17 +88,110 @@ impl ShaderStore {
});
Ok(shader)
}
pub fn get(&self, handle: &ShaderHandle) -> Option<ShaderStoreReadGuard> {
let guard = self.shaders.read().unwrap();
if let Some(_) = guard.get(handle.0) {
Some(ShaderStoreReadGuard {
guard,
handle: *handle,
})
} else {
None
}
}
}
pub struct ShaderLoader {
pub struct ShaderInfo {
handle: ShaderHandle,
dependencies: BTreeSet<PathBuf>,
}
pub struct ShaderLoader<'a> {
store: &'a ShaderStore,
source: String,
root_path: PathBuf,
old_handle: Option<ShaderHandle>,
include_stack: Vec<PathBuf>,
included: BTreeSet<PathBuf>,
}
impl<'a> ShaderLoader<'a> {
pub fn new(
store: &'a ShaderStore,
root_path: impl AsRef<Path>,
old_handle: Option<ShaderHandle>,
) -> Self {
Self {
store,
source: String::new(),
root_path: root_path.as_ref().to_path_buf(),
old_handle,
include_stack: Default::default(),
included: Default::default(),
}
}
pub fn include_file(&mut self, path: &PathBuf) {
let path = self.root_path.join(path);
if self.include_stack.contains(&path) {
panic!("Circular include of {:?}", path);
}
self.included.insert(path.clone());
self.include_stack.push(path.clone());
let contents = read_to_string(path).unwrap();
for line in contents.lines() {
let words: Vec<&str> = line.split_whitespace().filter(|w| w.len() > 0).collect();
if words.get(0) == Some(&"#include") {
let include_path = words[1];
self.include_file(&include_path.into());
} else {
self.source.push_str(line);
self.source.push_str("\r\n");
}
}
self.include_stack.pop();
}
pub fn finish(self) -> ShaderInfo {
let mut parser = naga::front::wgsl::Parser::new();
let module = match parser.parse(&self.source) {
Ok(module) => module,
// TODO handle parsing errors
Err(error) => panic!(
"wgsl parsing error:\n{}",
error.emit_to_string(&self.source)
),
};
let handle = if let Some(handle) = self.old_handle {
self.store.reload(&handle, &module).unwrap();
handle
} else {
self.store.load(&module).unwrap()
};
ShaderInfo {
handle,
dependencies: self.included,
}
}
}
pub struct ShaderWatcher {
store: Arc<ShaderStore>,
root_path: PathBuf,
_watcher: RecommendedWatcher,
notify_rx: Receiver<RawEvent>,
file_handles: RwLock<HashMap<PathBuf, ShaderHandle>>,
file_infos: RwLock<HashMap<PathBuf, ShaderHandle>>,
}
impl ShaderLoader {
impl ShaderWatcher {
pub fn new(
store: Arc<ShaderStore>,
root_path: impl AsRef<Path>,
@ -103,7 +207,7 @@ impl ShaderLoader {
root_path,
_watcher: watcher,
notify_rx,
file_handles: Default::default(),
file_infos: Default::default(),
})
}
@ -116,12 +220,8 @@ impl ShaderLoader {
op: Ok(op),
cookie: _, // TODO use cookie to disambiguate updates
} => {
let handles_read = self.file_handles.read().unwrap();
let path_buf = path.to_path_buf();
if let Some(handle) = handles_read.get(&path_buf) {
if op.contains(notify::Op::CREATE) {
let _result = self.reload(handle, path_buf);
}
if op.contains(notify::Op::CREATE) {
self.on_changed(&path);
}
}
other => panic!("unexpected shader loader watcher event: {:#?}", other),
@ -132,42 +232,26 @@ impl ShaderLoader {
}
}
pub fn add_file(&self, shader_path: impl AsRef<Path>) -> Result<ShaderHandle, ShaderError> {
let path_buf = self.root_path.join(shader_path);
let module = self.load(&path_buf)?;
let handle = self.store.load(&module)?;
self.file_handles.write().unwrap().insert(path_buf, handle);
Ok(handle)
}
fn load(&self, shader_path: impl AsRef<Path>) -> Result<naga::Module, ShaderError> {
// TODO handle IO errors
let source = read_to_string(shader_path).unwrap();
let mut parser = naga::front::wgsl::Parser::new();
match parser.parse(&source) {
Ok(module) => Ok(module),
// TODO handle parsing errors
Err(error) => panic!("wgsl parsing error:\n{}", error.emit_to_string(&source)),
}
}
fn reload(
&self,
handle: &ShaderHandle,
shader_path: impl AsRef<Path>,
) -> Result<(), ShaderError> {
match self.load(shader_path) {
Ok(shader) => match self.store.reload(handle, &shader) {
Ok(()) => return Ok(()),
Err(e) => {
eprintln!("Shader reload error: {:?}", e);
return Err(e);
}
},
Err(e) => {
eprintln!("Shader reload error: {:?}", e);
return Err(e);
pub fn on_changed(&self, path: &Path) {
let infos_read = self.file_infos.read().unwrap();
let path_buf = path.to_path_buf();
match infos_read.get(&path_buf) {
Some(handle) => {
let mut loader = ShaderLoader::new(&self.store, &self.root_path, Some(*handle));
loader.include_file(&path_buf);
let _info = loader.finish(); // TODO update dependencies
}
_ => {}
}
}
pub fn add_file(&self, shader_path: impl AsRef<Path>) -> Result<ShaderHandle, ShaderError> {
let shader_path_buf = shader_path.as_ref().to_path_buf();
let mut loader = ShaderLoader::new(&self.store, &self.root_path, None);
loader.include_file(&shader_path_buf);
let info = loader.finish();
let mut infos_write = self.file_infos.write().unwrap();
infos_write.insert(self.root_path.join(shader_path_buf), info.handle);
Ok(info.handle)
}
}