Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ members = [
# "crates/aprender-test",
]
exclude = [
"tools/ccpa-sft-export",
"fuzz",
# Old workspace root shells (no package, just held sub-crates):
"crates/aprender-viz-ttop",
Expand Down
1 change: 1 addition & 0 deletions crates/aprender-train/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ serde_json = "1.0"
toml = "0.8"
thiserror = "2.0"
rand = "0.9"
rayon = { workspace = true } # PMAT-FINETUNE-CONSTRUCT: parallel F32 dequant in Transformer::from_apr
clap = { version = "4.5", features = ["derive"] }
clap_complete = "4.5" # Shell completion generation
dirs = "5.0" # Cache directory resolution for HF pipeline
Expand Down
181 changes: 165 additions & 16 deletions crates/aprender-train/src/transformer/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,24 @@ impl Transformer {
/// weights contain NaN/Inf values.
pub fn from_apr(apr_path: impl AsRef<Path>, config: &TransformerConfig) -> Result<Self> {
use aprender::serialization::apr::AprReader;
use rayon::prelude::*;

// PMAT-FINETUNE-CONSTRUCT: phase-timed construction (was a 15+min unprofiled wall
// before any training step). Set APR_FROM_APR_TIMING=1 to see per-phase ms.
let timing = std::env::var("APR_FROM_APR_TIMING").is_ok();
let t0 = std::time::Instant::now();

let apr_path = apr_path.as_ref();
let reader = AprReader::open(apr_path).map_err(|e| {
Error::ConfigError(format!("Failed to open APR file '{}': {e}", apr_path.display()))
})?;
if timing {
eprintln!(
"[from_apr-timing] open+header: {:?} ({} tensors)",
t0.elapsed(),
reader.tensors.len()
);
}

// Build weight map from APR tensors — map GGUF names to HF convention (PMAT-489)
let is_gguf_names = reader.tensors.iter().any(|t| t.name == "token_embd.weight");
Expand All @@ -165,34 +178,63 @@ impl Transformer {
"[PMAT-489] Detected GGUF tensor names in APR file, mapping to HF convention"
);
}
let mut weights = HashMap::new();
for desc in &reader.tensors {
let data = reader.read_tensor_as_f32(&desc.name).map_err(|e| {
Error::ConfigError(format!("Failed to read tensor '{}': {e}", desc.name))
})?;
let mapped_name = if is_gguf_names {
super::weights::mapping::map_weight_name(
&desc.name,
super::weights::Architecture::Gguf,
)
} else {
desc.name.clone()
};
weights.insert(mapped_name, Tensor::from_vec(data, false));
// PMAT-FINETUNE-CONSTRUCT: dequant every tensor to F32 IN PARALLEL. Previously this
// was a serial loop over hundreds of tensors (each a Q4K/Q6K/F16 dequant of up to
// ~vocab*hidden f32s) — the dominant term in the construction wall. Rayon-parallel
// dequant scales with cores; the AprReader.read_tensor_as_f32 path is read-only
// (offsets into the mmap'd/owned byte buffer) so it is data-parallel-safe.
let t_dequant = std::time::Instant::now();
let dequanted: Vec<(String, Vec<f32>)> = reader
.tensors
.par_iter()
.map(|desc| {
let data = reader.read_tensor_as_f32(&desc.name).map_err(|e| {
Error::ConfigError(format!("Failed to read tensor '{}': {e}", desc.name))
})?;
let mapped_name = if is_gguf_names {
super::weights::mapping::map_weight_name(
&desc.name,
super::weights::Architecture::Gguf,
)
} else {
desc.name.clone()
};
Ok::<_, Error>((mapped_name, data))
})
.collect::<Result<Vec<_>>>()?;
let mut weights = HashMap::with_capacity(dequanted.len());
for (name, data) in dequanted {
weights.insert(name, Tensor::from_vec(data, false));
}
if timing {
eprintln!("[from_apr-timing] parallel dequant->F32: {:?}", t_dequant.elapsed());
}

// Same validation pipeline as from_safetensors
let t_val = std::time::Instant::now();
validate_weights(&weights, config.num_hidden_layers)?;
Self::validate_weight_shapes(&weights, config)?;
Self::validate_weight_values(&weights)?;
if timing {
eprintln!("[from_apr-timing] validate (struct+shape+nan/inf): {:?}", t_val.elapsed());
}

Self::from_params(config, &weights).ok_or_else(|| {
let t_params = std::time::Instant::now();
let model = Self::from_params(config, &weights).ok_or_else(|| {
Error::ConfigError(
"Failed to construct Transformer from APR weights \
(from_params returned None after validation passed)"
.into(),
)
})
});
if timing {
eprintln!(
"[from_apr-timing] from_params(build tensors): {:?} | TOTAL {:?}",
t_params.elapsed(),
t0.elapsed()
);
}
model
}

/// Validate that all weight tensor shapes match the config dimensions
Expand Down Expand Up @@ -689,6 +731,113 @@ mod tests {
assert_eq!(logits.len(), 2 * config.vocab_size);
}

/// FALSIFY-FINETUNE-CONSTRUCT-001 — `Transformer::from_apr` parallel F32
/// dequant loads EVERY tensor correctly (no dropped/garbled weight from the
/// rayon `par_iter` collect that replaced the serial dequant loop).
///
/// PMAT-FINETUNE-CONSTRUCT: the construction path was changed from a serial
/// `for desc in &reader.tensors { read_tensor_as_f32 }` loop to a
/// `par_iter().map(read_tensor_as_f32).collect()`. `read_tensor_as_f32` is a
/// pure read into an owned/mmap'd byte buffer (no shared mutable state), so
/// the parallel collect MUST yield the identical (name -> f32 data) map.
///
/// Falsifier: build a tiny but complete Qwen2-shaped APR with KNOWN F32 weight
/// values, load it through `from_apr`, and assert the round-tripped weights are
/// bit-exact and the model produces finite logits. If the parallel collect ever
/// drops a tensor, mis-maps a name, or corrupts data, validate_weights /
/// validate_weight_values / the value assertions below go RED.
#[test]
fn falsify_finetune_construct_001_from_apr_parallel_dequant_loads_all_weights() {
use aprender::serialization::apr::AprWriter;

let config = TransformerConfig::tiny(); // hidden=64, heads=2, kv=2, inter=256, layers=2, vocab=1000
let hidden = config.hidden_size;
let q_dim = config.q_dim();
let kv_hidden = config.num_kv_heads * config.head_dim();
let inter = config.intermediate_size;
let vocab = config.vocab_size;

// Distinct, exactly-representable f32 fill per tensor so a swap/corruption
// (not just a drop) is also caught. 0.5 is bit-exact under any dequant.
let mut w = AprWriter::new();
w.add_tensor_f32(
"model.embed_tokens.weight",
vec![vocab, hidden],
&vec![0.5; vocab * hidden],
);
w.add_tensor_f32("model.norm.weight", vec![hidden], &vec![1.0; hidden]);
w.add_tensor_f32("lm_head.weight", vec![vocab, hidden], &vec![0.25; vocab * hidden]);
for i in 0..config.num_hidden_layers {
let p = format!("model.layers.{i}");
w.add_tensor_f32(
&format!("{p}.input_layernorm.weight"),
vec![hidden],
&vec![1.0; hidden],
);
w.add_tensor_f32(
&format!("{p}.post_attention_layernorm.weight"),
vec![hidden],
&vec![1.0; hidden],
);
w.add_tensor_f32(
&format!("{p}.self_attn.q_proj.weight"),
vec![q_dim, hidden],
&vec![0.01; q_dim * hidden],
);
w.add_tensor_f32(
&format!("{p}.self_attn.k_proj.weight"),
vec![kv_hidden, hidden],
&vec![0.02; kv_hidden * hidden],
);
w.add_tensor_f32(
&format!("{p}.self_attn.v_proj.weight"),
vec![kv_hidden, hidden],
&vec![0.03; kv_hidden * hidden],
);
w.add_tensor_f32(
&format!("{p}.self_attn.o_proj.weight"),
vec![hidden, q_dim],
&vec![0.04; hidden * q_dim],
);
w.add_tensor_f32(
&format!("{p}.mlp.gate_proj.weight"),
vec![inter, hidden],
&vec![0.05; inter * hidden],
);
w.add_tensor_f32(
&format!("{p}.mlp.up_proj.weight"),
vec![inter, hidden],
&vec![0.06; inter * hidden],
);
w.add_tensor_f32(
&format!("{p}.mlp.down_proj.weight"),
vec![hidden, inter],
&vec![0.07; hidden * inter],
);
}

let dir = std::env::temp_dir().join(format!("apr_from_apr_falsify_{}", std::process::id()));
std::fs::create_dir_all(&dir).unwrap();
let apr_path = dir.join("tiny.apr");
std::fs::write(&apr_path, w.to_bytes().unwrap()).unwrap();

// Exercises the parallel dequant + validation construction path.
let model = Transformer::from_apr(&apr_path, &config)
.expect("from_apr must load every tensor via the parallel dequant collect");

// Structural: forward produces finite logits of the right shape — only
// possible if ALL per-layer weights were loaded (a dropped tensor would
// have failed validate_weights before we get here).
let logits = model.forward(&[1u32, 2, 3]);
assert_eq!(logits.len(), 3 * vocab, "all weights present => correct logit shape");
assert!(
logits.data().iter().all(|v| v.is_finite()),
"parallel-dequanted weights must produce finite logits"
);

let _ = std::fs::remove_dir_all(&dir);
}

#[test]
fn test_from_params_returns_none_on_missing() {
let config = TransformerConfig::tiny();
Expand Down
Loading
Loading