diff --git a/Cargo.lock b/Cargo.lock index e0b0a71..1c5f7eb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,18 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -728,6 +740,15 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", +] + [[package]] name = "hashbrown" version = "0.15.2" @@ -953,6 +974,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "id-arena" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005" + [[package]] name = "ident_case" version = "1.0.1" @@ -1111,6 +1138,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "leb128" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "884e2677b40cc8c339eaefcb701c32ef1fd2493d71118dc0ca4b6a736c93bd67" + [[package]] name = "lebe" version = "0.5.2" @@ -1338,6 +1371,8 @@ dependencies = [ "sha1", "sha2", "thiserror 2.0.12", + "url", + "waki", "wasm-bindgen", "wasm-bindgen-futures", ] @@ -1425,6 +1460,16 @@ dependencies = [ "yansi", ] +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + [[package]] name = "proc-macro2" version = "1.0.94" @@ -1867,6 +1912,12 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +[[package]] +name = "semver" +version = "1.0.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" + [[package]] name = "serde" version = "1.0.219" @@ -2030,6 +2081,15 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "spdx" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3e17e880bafaeb362a7b751ec46bdc5b61445a188f80e0606e68167cd540fa3" +dependencies = [ + "smallvec", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -2050,9 +2110,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.100" +version = "2.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0" +checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" dependencies = [ "proc-macro2", "quote", @@ -2353,6 +2413,12 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "unindent" version = "0.2.4" @@ -2411,6 +2477,32 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "waki" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "068839b80fb420247a7f1423fe30c4e768c0e8f9a16a62a85ecd68ef6a999f92" +dependencies = [ + "anyhow", + "http", + "serde", + "serde_urlencoded", + "url", + "waki-macros", + "wit-bindgen", +] + +[[package]] +name = "waki-macros" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89c18865d2c6c414628037585c1f2dda43a578d2f9ba1500fbb2927a5a3087ba" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "want" version = "0.3.1" @@ -2432,7 +2524,7 @@ version = "0.14.2+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" dependencies = [ - "wit-bindgen-rt", + "wit-bindgen-rt 0.39.0", ] [[package]] @@ -2508,6 +2600,32 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-encoder" +version = "0.219.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8aa79bcd666a043b58f5fa62b221b0b914dd901e6f620e8ab7371057a797f3e1" +dependencies = [ + "leb128", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.219.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1ef51bd442042a2a7b562dddb6016ead52c4abab254c376dcffc83add2c9c34" +dependencies = [ + "anyhow", + "indexmap 2.9.0", + "serde", + "serde_derive", + "serde_json", + "spdx", + "wasm-encoder", + "wasmparser", +] + [[package]] name = "wasm-streams" version = "0.4.2" @@ -2521,6 +2639,19 @@ dependencies = [ "web-sys", ] +[[package]] +name = "wasmparser" +version = "0.219.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5220ee4c6ffcc0cb9d7c47398052203bc902c8ef3985b0c8134118440c0b2921" +dependencies = [ + "ahash", + "bitflags 2.9.0", + "hashbrown 0.14.5", + "indexmap 2.9.0", + "semver", +] + [[package]] name = "web-sys" version = "0.3.77" @@ -2790,6 +2921,36 @@ dependencies = [ "memchr", ] +[[package]] +name = "wit-bindgen" +version = "0.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e11ad55616555605a60a8b2d1d89e006c2076f46c465c892cc2c153b20d4b30" +dependencies = [ + "wit-bindgen-rt 0.34.0", + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "163cee59d3d5ceec0b256735f3ab0dccac434afb0ec38c406276de9c5a11e906" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rt" +version = "0.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "744845cde309b8fa32408d6fb67456449278c66ea4dcd96de29797b302721f02" +dependencies = [ + "bitflags 2.9.0", +] + [[package]] name = "wit-bindgen-rt" version = "0.39.0" @@ -2799,6 +2960,74 @@ dependencies = [ "bitflags 2.9.0", ] +[[package]] +name = "wit-bindgen-rust" +version = "0.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6919521fc7807f927a739181db93100ca7ed03c29509b84d5f96b27b2e49a9a" +dependencies = [ + "anyhow", + "heck", + "indexmap 2.9.0", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c967731fc5d50244d7241ecfc9302a8929db508eea3c601fbc5371b196ba38a5" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.219.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b8479a29d81c063264c3ab89d496787ef78f8345317a2dcf6dece0f129e5fcd" +dependencies = [ + "anyhow", + "bitflags 2.9.0", + "indexmap 2.9.0", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.219.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca004bb251010fe956f4a5b9d4bf86b4e415064160dd6669569939e8cbf2504f" +dependencies = [ + "anyhow", + "id-arena", + "indexmap 2.9.0", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + [[package]] name = "writeable" version = "0.6.1" diff --git a/Cargo.toml b/Cargo.toml index 23fa1ac..4020e40 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,10 @@ wasm-bindgen = { version = "0.2.100", optional = true, features = [ ] } serde-wasm-bindgen = { version = "0.6.5", optional = true } wasm-bindgen-futures = { version = "0.4.42", optional = true } +url = "2.5" + +[target.'cfg(target_os = "wasi")'.dependencies] +waki = "0.3" [dev-dependencies] pretty_assertions = "1.4.1" diff --git a/src/registry.rs b/src/registry.rs index d1ffd2e..9ec0cf9 100644 --- a/src/registry.rs +++ b/src/registry.rs @@ -81,7 +81,7 @@ pub fn load_harmony_encoding(name: HarmonyEncodingName) -> anyhow::Result anyhow::Result { match name { HarmonyEncodingName::HarmonyGptOss => { @@ -116,7 +116,46 @@ pub async fn load_harmony_encoding(name: HarmonyEncodingName) -> anyhow::Result< FormattingToken::EndMessageDoneSampling, FormattingToken::EndMessageAssistantToTool, ]), - conversation_has_function_tools: Arc::new(AtomicBool::new(false)), + }) + } + } +} + +#[cfg(all(target_arch = "wasm32", target_os = "wasi"))] +pub fn load_harmony_encoding(name: HarmonyEncodingName) -> anyhow::Result { + match name { + HarmonyEncodingName::HarmonyGptOss => { + let n_ctx = 1_048_576; // 2^20 + let max_action_length = 524_288; // 2^19 + let encoding_ext = tiktoken_ext::Encoding::O200kHarmony; + Ok(HarmonyEncoding { + name: name.to_string(), + n_ctx, + tokenizer: Arc::new(encoding_ext.load()?), + tokenizer_name: encoding_ext.name().to_owned(), + max_message_tokens: n_ctx - max_action_length, + max_action_length, + format_token_mapping: make_mapping([ + (FormattingToken::Start, "<|start|>"), + (FormattingToken::Message, "<|message|>"), + (FormattingToken::EndMessage, "<|end|>"), + (FormattingToken::EndMessageDoneSampling, "<|return|>"), + (FormattingToken::Refusal, "<|refusal|>"), + (FormattingToken::ConstrainedFormat, "<|constrain|>"), + (FormattingToken::Channel, "<|channel|>"), + (FormattingToken::EndMessageAssistantToTool, "<|call|>"), + (FormattingToken::BeginUntrusted, "<|untrusted|>"), + (FormattingToken::EndUntrusted, "<|end_untrusted|>"), + ]), + stop_formatting_tokens: HashSet::from([ + FormattingToken::EndMessageDoneSampling, + FormattingToken::EndMessageAssistantToTool, + FormattingToken::EndMessage, + ]), + stop_formatting_tokens_for_assistant_actions: HashSet::from([ + FormattingToken::EndMessageDoneSampling, + FormattingToken::EndMessageAssistantToTool, + ]), }) } } diff --git a/src/tiktoken_ext/public_encodings.rs b/src/tiktoken_ext/public_encodings.rs index ab9c435..3545b59 100644 --- a/src/tiktoken_ext/public_encodings.rs +++ b/src/tiktoken_ext/public_encodings.rs @@ -104,7 +104,7 @@ impl Encoding { .load() } - #[cfg(target_arch = "wasm32")] + #[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))] pub async fn load_from_name(name: impl AsRef) -> Result { let name = name.as_ref(); Self::from_name(name) @@ -113,6 +113,14 @@ impl Encoding { .await } + #[cfg(all(target_arch = "wasm32", target_os = "wasi"))] + pub fn load_from_name(name: impl AsRef) -> Result { + let name = name.as_ref(); + Self::from_name(name) + .ok_or_else(|| LoadError::UnknownEncodingName(name.to_string()))? + .load() + } + pub fn name(&self) -> &'static str { match self { Self::O200kBase => "o200k_base", @@ -202,7 +210,7 @@ impl Encoding { } } - #[cfg(target_arch = "wasm32")] + #[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))] pub async fn load(&self) -> Result { let url = self.public_vocab_file_url(); let vocab_bytes = download_or_find_cached_file_bytes(&url, Some(self.expected_hash())) @@ -237,6 +245,40 @@ impl Encoding { } } + #[cfg(all(target_arch = "wasm32", target_os = "wasi"))] + pub fn load(&self) -> Result { + let url = self.public_vocab_file_url(); + let vocab_bytes = download_or_find_cached_file_bytes(&url, Some(self.expected_hash())) + .map_err(LoadError::DownloadOrLoadVocabFile)?; + + match self { + Self::O200kHarmony => { + let mut specials: Vec<(String, Rank)> = self + .special_tokens() + .iter() + .map(|(s, r)| ((*s).to_string(), *r)) + .collect(); + specials.extend((200014..=201088).map(|id| (format!("<|reserved_{id}|>"), id))); + load_encoding_from_bytes(&vocab_bytes, None, specials, &self.pattern()) + } + Self::O200kBase => { + let mut specials: Vec<(String, Rank)> = self + .special_tokens() + .iter() + .map(|(s, r)| ((*s).to_string(), *r)) + .collect(); + specials.extend((199998..=201088).map(|id| (format!("<|reserved_{id}|>"), id))); + load_encoding_from_bytes(&vocab_bytes, None, specials, &self.pattern()) + } + _ => load_encoding_from_bytes( + &vocab_bytes, + None, + self.special_tokens().iter().cloned(), + &self.pattern(), + ), + } + } + fn public_vocab_file_url(&self) -> String { let base = tiktoken_base_url(); match self { @@ -390,6 +432,29 @@ where load_tiktoken_vocab(reader, expected_hash) } +#[cfg(any(target_arch = "wasm32"))] +pub fn load_encoding_from_bytes( + bytes: &[u8], + expected_hash: Option<&str>, + special_tokens: S, + pattern: &str, +) -> Result +where + S: IntoIterator, + TS: Into, +{ + let cursor = std::io::Cursor::new(bytes); + let reader = std::io::BufReader::new(cursor); + let encoder = load_tiktoken_vocab(reader, expected_hash) + .map_err(LoadError::InvalidTiktokenVocabFile)?; + CoreBPE::new( + encoder, + special_tokens.into_iter().map(|(k, v)| (k.into(), v)), + pattern, + ) + .map_err(LoadError::CoreBPECreationFailed) +} + pub fn load_encoding_from_file( file_path: P, expected_hash: Option<&str>, @@ -440,7 +505,7 @@ fn download_or_find_cached_file( Ok(cache_path) } -#[cfg(target_arch = "wasm32")] +#[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))] async fn download_or_find_cached_file_bytes( url: &str, expected_hash: Option<&str>, @@ -459,6 +524,25 @@ async fn download_or_find_cached_file_bytes( Ok(bytes) } +#[cfg(all(target_arch = "wasm32", target_os = "wasi"))] +fn download_or_find_cached_file_bytes( + url: &str, + expected_hash: Option<&str>, +) -> Result, RemoteVocabFileError> { + let bytes = load_remote_file_bytes(url)?; + if let Some(expected_hash) = expected_hash { + let computed_hash = format!("{:x}", Sha256::digest(&bytes)); + if computed_hash != expected_hash { + return Err(RemoteVocabFileError::HashMismatch { + file_url: url.to_string(), + expected_hash: expected_hash.to_string(), + computed_hash, + }); + } + } + Ok(bytes) +} + fn resolve_cache_dir() -> Result { // we use a different env var and a different default dir name to avoid // conflicts with the python tiktoken package, while sharing a cache dir @@ -534,17 +618,27 @@ fn load_remote_file(url: &str, destination: &Path) -> Result Result { + Err(RemoteVocabFileError::FailedToDownloadOrLoadVocabFile( + Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "Downloading files is not supported in wasm32, use async version", + )), + )) +} + +#[cfg(all(target_arch = "wasm32", target_os = "wasi"))] fn load_remote_file(_url: &str, _destination: &Path) -> Result { Err(RemoteVocabFileError::FailedToDownloadOrLoadVocabFile( Box::new(std::io::Error::new( std::io::ErrorKind::Other, - "Downloading files is not supported in wasm32", + "Synchronous file downloading is not supported in WASI, use async version", )), )) } -#[cfg(target_arch = "wasm32")] +#[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))] async fn load_remote_file_bytes(url: &str) -> Result, RemoteVocabFileError> { use reqwest::Client; @@ -562,6 +656,46 @@ async fn load_remote_file_bytes(url: &str) -> Result, RemoteVocabFileErr Ok(bytes.to_vec()) } +#[cfg(all(target_arch = "wasm32", target_os = "wasi"))] +fn load_remote_file_bytes(url: &str) -> Result, RemoteVocabFileError> { + use waki::Client; + + // Create a waki client and make the request + let client = Client::new(); + let response = client + .get(url) + .send() + .map_err(|e| RemoteVocabFileError::FailedToDownloadOrLoadVocabFile( + Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + format!("Request failed: {}", e) + )) + ))?; + + // Check if the response was successful + let status_code = response.status_code(); + if !(200..300).contains(&status_code) { + return Err(RemoteVocabFileError::FailedToDownloadOrLoadVocabFile( + Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + format!("HTTP request failed with status: {}", status_code) + )) + )); + } + + // Get the response body as bytes + let bytes = response + .body() + .map_err(|e| RemoteVocabFileError::FailedToDownloadOrLoadVocabFile( + Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + format!("Failed to read response body: {}", e) + )) + ))?; + + Ok(bytes) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/wasm_module.rs b/src/wasm_module.rs index 1b96a4f..5f11587 100644 --- a/src/wasm_module.rs +++ b/src/wasm_module.rs @@ -333,6 +333,7 @@ pub enum StreamState { Content, } +#[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))] #[wasm_bindgen] pub async fn load_harmony_encoding( name: &str, @@ -344,8 +345,26 @@ pub async fn load_harmony_encoding( let parsed: HarmonyEncodingName = name .parse::() .map_err(|e| JsValue::from_str(&e.to_string()))?; - let encoding = - inner_load_harmony_encoding(parsed).map_err(|e| JsValue::from_str(&e.to_string()))?; + let encoding = inner_load_harmony_encoding(parsed) + .await + .map_err(|e| JsValue::from_str(&e.to_string()))?; + Ok(JsHarmonyEncoding { inner: encoding }) +} + +#[cfg(all(target_arch = "wasm32", target_os = "wasi"))] +#[wasm_bindgen] +pub fn load_harmony_encoding( + name: &str, + base_url: Option, +) -> Result { + if let Some(base) = base_url { + crate::tiktoken_ext::set_tiktoken_base_url(base); + } + let parsed: HarmonyEncodingName = name + .parse::() + .map_err(|e| JsValue::from_str(&e.to_string()))?; + let encoding = inner_load_harmony_encoding(parsed) + .map_err(|e| JsValue::from_str(&e.to_string()))?; Ok(JsHarmonyEncoding { inner: encoding }) }