Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub(crate) mod manifest;
pub(crate) mod pdk;
mod plugin;
mod plugin_builder;
mod pool;
mod readonly_dir;
mod timer;

Expand All @@ -43,6 +44,7 @@ pub use plugin::{
CancelHandle, CompiledPlugin, Plugin, WasmInput, EXTISM_ENV_MODULE, EXTISM_USER_MODULE,
};
pub use plugin_builder::{DebugOptions, PluginBuilder};
pub use pool::{Pool, PoolBuilder, PoolPlugin};

pub(crate) use internal::{Internal, Wasi};
pub(crate) use timer::{Timer, TimerAction};
Expand Down
1 change: 1 addition & 0 deletions runtime/src/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ pub(crate) fn profiling_strategy() -> ProfilingStrategy {
/// Defines an input type for Wasm data.
///
/// Types that implement `Into<WasmInput>` can be passed directly into `Plugin::new`
#[derive(Clone)]
pub enum WasmInput<'a> {
/// Raw Wasm module
Data(std::borrow::Cow<'a, [u8]>),
Expand Down
1 change: 1 addition & 0 deletions runtime/src/plugin_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ impl Default for DebugOptions {
}

/// PluginBuilder is used to configure and create `Plugin` instances
#[derive(Clone)]
pub struct PluginBuilder<'a> {
pub(crate) source: WasmInput<'a>,
pub(crate) config: Option<wasmtime::Config>,
Expand Down
207 changes: 207 additions & 0 deletions runtime/src/pool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
use std::collections::HashMap;

use crate::{Error, FromBytesOwned, Plugin, PluginBuilder, ToBytes};

// `PoolBuilder` is used to configure and create `Pool`s
#[derive(Debug, Clone)]
pub struct PoolBuilder {
/// Max number of concurrent instances for a plugin - by default this is set to
/// the output of `std::thread::available_parallelism`
pub max_instances: usize,
}

impl PoolBuilder {
/// Create a `PoolBuilder` with default values
pub fn new() -> Self {
Self::default()
}

/// Set the max number of parallel instances
pub fn with_max_instances(mut self, n: usize) -> Self {
self.max_instances = n;
self
}

/// Create a new `Pool` with the given configuration
pub fn build(self) -> Pool {
Pool::new_from_builder(self)
}
}

impl Default for PoolBuilder {
fn default() -> Self {
PoolBuilder {
max_instances: std::thread::available_parallelism()
.expect("available parallelism")
.into(),
}
}
}

/// `PoolPlugin` is used by the pool to track the number of live instances of a particular plugin
#[derive(Clone, Debug)]
pub struct PoolPlugin(std::rc::Rc<std::cell::RefCell<Plugin>>);

impl PoolPlugin {
fn new(plugin: Plugin) -> Self {
Self(std::rc::Rc::new(std::cell::RefCell::new(plugin)))
}

/// Access the underlying plugin
pub fn plugin(&self) -> std::cell::RefMut<Plugin> {
self.0.borrow_mut()
}

/// Helper to call a plugin function on the underlying plugin
pub fn call<'a, Input: ToBytes<'a>, Output: FromBytesOwned>(
&self,
name: impl AsRef<str>,
input: Input,
) -> Result<Output, Error> {
self.plugin().call(name.as_ref(), input)
}

/// Helper to get the underlying plugin's ID
pub fn id(&self) -> uuid::Uuid {
self.plugin().id
}
}

type PluginSource = dyn Fn() -> Result<Plugin, Error>;

struct PoolInner<Key: std::fmt::Debug + Clone + std::hash::Hash + Eq = String> {
plugins: HashMap<Key, Box<PluginSource>>,
instances: HashMap<Key, Vec<PoolPlugin>>,
}

/// `Pool` manages threadsafe access to a limited number of instances of multiple plugins
#[derive(Clone)]
pub struct Pool<Key: std::fmt::Debug + Clone + std::hash::Hash + Eq = String> {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Im not sure what the Key parameter is good for. It may be useful to store multiple different plugin types in the same pool, but this is tricky as there is currently no way to get a list of all existing keys. Anyway the pool itself has very little overhead, so it makes more sense for users to write HashMap<Key, Pool> if needed. In my case Im using Vec<Pool<()>> because there is no need for any hashmap.

config: PoolBuilder,
inner: std::sync::Arc<std::sync::Mutex<PoolInner<Key>>>,
}

unsafe impl<T: std::fmt::Debug + Clone + std::hash::Hash + Eq> Send for Pool<T> {}
unsafe impl<T: std::fmt::Debug + Clone + std::hash::Hash + Eq> Sync for Pool<T> {}

impl<T: std::fmt::Debug + Clone + std::hash::Hash + Eq> Default for Pool<T> {
fn default() -> Self {
Self::new_from_builder(PoolBuilder::default())
}
}

impl<Key: std::fmt::Debug + Clone + std::hash::Hash + Eq> Pool<Key> {
/// Create a new pool with the defailt configuration
pub fn new() -> Self {
Self::default()
}

/// Create a new pool configured using a `PoolBuilder`
pub fn new_from_builder(builder: PoolBuilder) -> Self {
Pool {
config: builder,
inner: std::sync::Arc::new(std::sync::Mutex::new(PoolInner {
plugins: Default::default(),
instances: Default::default(),
})),
}
}

/// Add a plugin using a callback function
pub fn add<F: 'static + Fn() -> Result<Plugin, Error>>(&self, key: Key, source: F) {
let mut pool = self.inner.lock().unwrap();
if !pool.instances.contains_key(&key) {
pool.instances.insert(key.clone(), vec![]);
}

pool.plugins.insert(key, Box::new(source));
}

/// Add a plugin using a `PluginBuilder`
pub fn add_builder(&self, key: Key, source: PluginBuilder<'static>) {
let mut pool = self.inner.lock().unwrap();
if !pool.instances.contains_key(&key) {
pool.instances.insert(key.clone(), vec![]);
}

pool.plugins
.insert(key, Box::new(move || source.clone().build()));
}

fn find_available(&self, key: &Key) -> Result<Option<PoolPlugin>, Error> {
let mut pool = self.inner.lock().unwrap();
if let Some(entry) = pool.instances.get_mut(key) {
for instance in entry.iter() {
if std::rc::Rc::strong_count(&instance.0) == 1 {
return Ok(Some(instance.clone()));
}
}
}
Ok(None)
}

/// Get the number of live instances for a plugin
pub fn count(&self, key: &Key) -> usize {
self.inner
.lock()
.unwrap()
.instances
.get(key)
.map(|x| x.len())
.unwrap_or_default()
}

/// Get access to a plugin, this will create a new instance if needed (and allowed by the specified
/// max_instances). `Ok(None)` is returned if the timeout is reached before an available plugin could be
/// acquired
pub fn get(
&self,
key: &Key,
timeout: std::time::Duration,
) -> Result<Option<PoolPlugin>, Error> {
let start = std::time::Instant::now();
let max = self.config.max_instances;
if let Some(avail) = self.find_available(key)? {
return Ok(Some(avail));
}

{
let mut pool = self.inner.lock().unwrap();
if pool.instances.get(key).map(|x| x.len()).unwrap_or_default() < max {
if let Some(source) = pool.plugins.get(key) {
let plugin = source()?;
let instance = PoolPlugin::new(plugin);
let v = pool.instances.get_mut(key).unwrap();
v.push(instance);
return Ok(Some(v.last().unwrap().clone()));
}
}
}

loop {
if let Ok(Some(x)) = self.find_available(key) {
return Ok(Some(x));
}
if std::time::Instant::now() - start > timeout {
return Ok(None);
}

std::thread::sleep(std::time::Duration::from_millis(100));
}
}

/// Access a plugin in a callback function. This calls `Pool::get` then the provided
/// callback. `Ok(None)` is returned if the timeout is reached before an available
/// plugin could be acquired
pub fn with_plugin<T>(
&self,
key: &Key,
timeout: std::time::Duration,
f: impl FnOnce(&mut Plugin) -> Result<T, Error>,
) -> Result<Option<T>, Error> {
if let Some(plugin) = self.get(key, timeout)? {
return f(&mut plugin.plugin()).map(Some);
}
Ok(None)
}
}
1 change: 1 addition & 0 deletions runtime/src/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod issues;
mod kernel;
mod pool;
mod runtime;
48 changes: 48 additions & 0 deletions runtime/src/tests/pool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use crate::*;

fn run_thread(p: Pool<String>, i: u64) -> std::thread::JoinHandle<()> {
std::thread::spawn(move || {
std::thread::sleep(std::time::Duration::from_millis(i));
let s: String = p
.get(&"test".to_string(), std::time::Duration::from_secs(1))
.unwrap()
.unwrap()
.call("count_vowels", "abc")
.unwrap();
println!("{}", s);
})
}

#[test]
fn test_threads() {
for i in 1..=3 {
let data = include_bytes!("../../../wasm/code.wasm");
let pool: Pool<String> = PoolBuilder::new().with_max_instances(i).build();

let test = "test".to_string();
pool.add_builder(
test.clone(),
extism::PluginBuilder::new(extism::Manifest::new([extism::Wasm::data(data)]))
.with_wasi(true),
);

let mut threads = vec![];
threads.push(run_thread(pool.clone(), 1000));
threads.push(run_thread(pool.clone(), 1000));
threads.push(run_thread(pool.clone(), 1000));
threads.push(run_thread(pool.clone(), 1000));
threads.push(run_thread(pool.clone(), 1000));
threads.push(run_thread(pool.clone(), 1000));
threads.push(run_thread(pool.clone(), 500));
threads.push(run_thread(pool.clone(), 500));
threads.push(run_thread(pool.clone(), 500));
threads.push(run_thread(pool.clone(), 500));
threads.push(run_thread(pool.clone(), 500));
threads.push(run_thread(pool.clone(), 0));

for t in threads {
t.join().unwrap();
}
assert!(pool.count(&test) <= i);
}
}
Loading