Per-checkpoint artifact generation: embeddings, GAT attention weights,
teacher↔student CKA, loss-landscape grids, fusion-policy traces. Driven
by AnalysisConfig and dispatched directly through
graphids.exp.runtime.run_stage to Analyzer(spec).run().
Distinct from graphids.analysis, which owns
cross-run statistical comparison from the MLflow catalog (no torch,
login-safe).
graphids/core/artifacts/
├── analyzer.py    orchestrates load → compute → save loop, writes manifest
├── _dispatch.py   ARTIFACTS table: each row is the only place compute + I/O meet
├── compute.py     pure compute fns + frozen result dataclasses (no fs)
├── io.py          every read (val data, teacher ckpt, fusion eval) + every write
└── __init__.py
The compute / I/O split is structural: every compute_* function takes
pre-loaded models and pre-built val_data, returns a frozen dataclass,
and never touches the filesystem. io.save_* consumes those dataclasses
and writes; io.load_* reads. The dispatch table is the single seam
between the two — adding an artifact means one new row in ARTIFACTS,
one new compute fn, and one new save fn.
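
The split described above can be sketched in a few lines. This is a minimal illustration, not the package's code: the row layout (ArtifactRow), the toy artifact (NormResult, compute_norm, save_norm), and the dict standing in for the filesystem are all hypothetical names chosen for the example.

```python
import json
import math
from dataclasses import asdict, dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class NormResult:
    """Frozen result container, in the spirit of the compute-side dataclasses."""
    mean_norm: float


def compute_norm(vectors: list[list[float]]) -> NormResult:
    """Pure compute: takes pre-loaded data, never touches the filesystem."""
    norms = [math.sqrt(sum(v * v for v in vec)) for vec in vectors]
    return NormResult(mean_norm=sum(norms) / len(norms))


def save_norm(result: NormResult, sink: dict[str, str]) -> None:
    """I/O side: serializes the dataclass (a dict stands in for the disk)."""
    sink["norm.json"] = json.dumps(asdict(result))


@dataclass(frozen=True)
class ArtifactRow:
    """Hypothetical row layout: the one place a compute fn meets its save fn."""
    name: str
    compute: Callable[..., Any]
    save: Callable[[Any, dict[str, str]], None]


# Adding an artifact = one new row + one compute fn + one save fn.
ARTIFACT_ROWS = (ArtifactRow("norm", compute_norm, save_norm),)


def run_all(vectors: list[list[float]], sink: dict[str, str]) -> None:
    for row in ARTIFACT_ROWS:
        row.save(row.compute(vectors), sink)
```

Because the compute function never sees the sink, it can be unit-tested with in-memory inputs alone.
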
io.load_val_data goes through CANBusSource → state.get_or_build,
the same path GraphDataModule.setup takes during training.
val_fraction, scaler strategy, and cache digest live on the source
dataclass; the analyzer picks up changes there automatically, with no
parallel declaration.
io.load_teacher and the student ckpt load in Analyzer.run both go
through safe_load_checkpoint, the canonical "ckpt → module" registry.
io.load_fusion_eval wraps the same FusionDataModule that training and
eval use.
Analyzer.run() writes analysis_manifest.json next to the artifacts:
the rendered analysis identity, expected outputs (derived from
expected_outputs(spec)), and which actually exist on disk after the
run. Useful as provenance when an analyze run was submitted via SLURM
and the output dir is the only artifact left.
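
A manifest of this kind could look roughly like the sketch below. The field names (identity, expected, present) and the write_manifest helper are assumptions for illustration; the actual schema written by Analyzer.run() is not shown in this section.

```python
import json
from pathlib import Path


def write_manifest(out_dir: Path, identity: str, expected: list[str]) -> dict:
    """Record what the run should have produced and what is actually on disk."""
    manifest = {
        "identity": identity,
        "expected": expected,
        # Checked after the run, so a partial failure shows up as a gap here.
        "present": [name for name in expected if (out_dir / name).exists()],
    }
    path = out_dir / "analysis_manifest.json"
    path.write_text(json.dumps(manifest, indent=2))
    return manifest
```

Comparing expected against present is enough to tell, from the output directory alone, which artifacts a finished job failed to write.
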
def __init__(self, spec: AnalysisConfig):
    self.spec = spec
    if not Path(spec.ckpt_path).exists():
        raise FileNotFoundError(f"Checkpoint not found: {spec.ckpt_path}")
    if spec.cka and not Path(spec.cka_teacher_ckpt).exists():
        raise FileNotFoundError(
            f"Teacher checkpoint not found: {spec.cka_teacher_ckpt}"
        )
Pure compute primitives — no filesystem, no MLflow, no logging side-effects.
Each compute_* returns a frozen dataclass (or plain dict, for CKA's
single layer→score mapping) that io.save_* knows how to serialize.
The analyzer wraps the whole batch in eval_mode, so no compute
function re-enters it.
@torch.no_grad()
def compute_attention(
    model: torch.nn.Module,
    val_data: list,
    device: torch.device,
    *,
    max_samples: int = 50,
) -> AttentionResult | None:
    """Per-sample per-layer GAT attention weights. ``None`` if model lacks them."""
    if getattr(model, "conv_type", None) != "gat":
        return None
    loader = PyGDataLoader(val_data[:max_samples], batch_size=1)
    out: dict[str, np.ndarray] = {}
    sample_idx = 0
    for batch in loader:
        batch = batch.clone().to(device)
        _xs, attention_weights = model(batch, return_attention_weights=True)
        prefix = f"sample_{sample_idx}"
        out[f"{prefix}_label"] = batch.y[0].cpu().numpy()
        for layer_idx, alpha in enumerate(attention_weights):
            out[f"{prefix}_layer_{layer_idx}_alpha"] = alpha.cpu().numpy()
        sample_idx += 1
    return AttentionResult(weights=out, n_samples=sample_idx)
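
The weights dict uses flat keys of the form sample_{i}_label and sample_{i}_layer_{j}_alpha. A consumer that wants them grouped per sample could do something like the following; group_attention_keys is a hypothetical helper for illustration, not part of the package.

```python
import re
from collections import defaultdict

_ALPHA_KEY = re.compile(r"^sample_(\d+)_layer_(\d+)_alpha$")


def group_attention_keys(keys):
    """Map sample index -> {layer index -> original flat key}."""
    grouped = defaultdict(dict)
    for key in keys:
        match = _ALPHA_KEY.match(key)
        if match:  # label keys and anything else are skipped
            sample, layer = int(match.group(1)), int(match.group(2))
            grouped[sample][layer] = key
    return {s: dict(sorted(layers.items())) for s, layers in sorted(grouped.items())}
```
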
Full cross-matrix linear CKA between all teacher and student layers.
Returns keys teacher_{i}_student_{j} for every combination, giving an
n_teacher × n_student matrix. Previously used min(n_teacher, n_student)
and only compared corresponding pairs — this silently dropped teacher
layers that had no student counterpart (bug: 3-layer teacher × 2-layer
student only produced 2 values instead of 6).
def compute_cka(
    student: torch.nn.Module,
    teacher: torch.nn.Module,
    val_data: list,
    device: torch.device,
    *,
    max_samples: int = 500,
) -> dict[str, float]:
    """Full cross-matrix linear CKA between all teacher and student layers.

    Returns keys teacher_{i}_student_{j} for every combination, giving an
    n_teacher × n_student matrix. Previously used min(n_teacher, n_student)
    and only compared corresponding pairs; this silently dropped teacher
    layers that had no student counterpart (bug: 3-layer teacher × 2-layer
    student only produced 2 values instead of 6).
    """
    student_reps = _collect_reps(student, val_data, device, max_samples)
    teacher_reps = _collect_reps(teacher, val_data, device, max_samples)
    return {
        f"teacher_{i}_student_{j}": _linear_cka(teacher_reps[i], student_reps[j])
        for i in range(len(teacher_reps))
        for j in range(len(student_reps))
    }
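
The body of _linear_cka is not shown here. As a reference point, the standard linear CKA on column-centered feature matrices X (n × p) and Y (n × q) is ‖YᵀX‖²_F / (‖XᵀX‖_F · ‖YᵀY‖_F). The sketch below implements that formula in plain Python, with nested lists standing in for tensors; it is an assumption about what the helper computes, not a copy of it. The cka_matrix helper mirrors the full cross-matrix key layout.

```python
import math


def _center(rows):
    """Subtract each column's mean (rows = samples, columns = features)."""
    n = len(rows)
    means = [sum(col) / n for col in zip(*rows)]
    return [[v - m for v, m in zip(row, means)] for row in rows]


def _cross_fro_sq(a, b):
    """Squared Frobenius norm of a^T b for row-major a (n×p) and b (n×q)."""
    total = 0.0
    for col_a in zip(*a):
        for col_b in zip(*b):
            dot = sum(x * y for x, y in zip(col_a, col_b))
            total += dot * dot
    return total


def linear_cka(x, y):
    """Linear CKA between two representations of the same n samples."""
    xc, yc = _center(x), _center(y)
    num = _cross_fro_sq(xc, yc)
    den = math.sqrt(_cross_fro_sq(xc, xc) * _cross_fro_sq(yc, yc))
    return num / den if den > 0 else 0.0  # constant features: define as 0


def cka_matrix(teacher_reps, student_reps):
    """Every teacher layer against every student layer, keyed like compute_cka."""
    return {
        f"teacher_{i}_student_{j}": linear_cka(t, s)
        for i, t in enumerate(teacher_reps)
        for j, s in enumerate(student_reps)
    }
```

With three teacher layers and two student layers, cka_matrix returns six entries, which is the behavior the docstring above describes as the fix.
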
def compute_fusion_policy(module, td: TensorDict, labels: torch.Tensor) -> PolicyResult:
    """Run fusion module on pre-extracted val states; return alphas + Q-values + labels."""
    from graphids.core.models.fusion.base import flatten_features

    result = module.predict(td)
    q_vals: np.ndarray | None = None
    if hasattr(module, "q_values"):
        flat_obs = flatten_features(result["td_norm"])
        q_vals = module.q_values(flat_obs).cpu().numpy()
    # RL models (Bandit/DQN) return alphas (mixing weights); non-RL models
    # (MLP/MoE/WeightedAvg) return only fused_scores, so use those as the signal.
    alpha_tensor = result.get("alphas", result["fused_scores"])
    return PolicyResult(
        alphas=alpha_tensor.detach().cpu().numpy(),
        labels=labels.cpu().numpy(),
        q_values=q_vals,
    )
def compute_landscape(
    model: torch.nn.Module,
    model_type: str,
    val_data: list,
    device: torch.device,
    hparams,
    *,
    resolution: int = 51,
    scale: float = 1.0,
    seed: int = 42,
    max_graphs: int = 500,
    dataset: str = "",
) -> LandscapeResult:
    """Loss on a ``resolution × resolution`` grid of filter-normalized perturbations.

    ``KeyError`` on unknown ``model_type``: dispatch's ``applies_to`` should
    filter callers; reaching this with an unsupported type is a routing bug.
    """
    loss_fn = _LOSS_FN[model_type]
    if len(val_data) > max_graphs:
        rng = np.random.default_rng(seed)
        idx = rng.choice(len(val_data), max_graphs, replace=False)
        data = [val_data[i] for i in idx]
    else:
        data = val_data
    dataloader = PyGDataLoader(data, batch_size=min(256, len(data)))
    dir1 = _random_direction(model, seed)
    dir2 = _random_direction(model, seed + 1)
    base = [p.data.clone() for p in model.parameters()]
    alphas = np.linspace(-scale, scale, resolution)
    betas = np.linspace(-scale, scale, resolution)
    xs, ys, losses = [], [], []
    for a in alphas:
        for b in betas:
            _perturb_model(model, base, dir1, dir2, a, b)
            losses.append(loss_fn(model, dataloader, device, hparams))
            xs.append(float(a))
            ys.append(float(b))
    _perturb_model(model, base, dir1, dir2, 0.0, 0.0)  # restore
    return LandscapeResult(x=xs, y=ys, loss=losses, model_type=model_type, dataset=dataset)
def load_teacher(model_type: str, ckpt_path: str, device: torch.device) -> torch.nn.Module:
    """Load a teacher checkpoint for analysis."""
    teacher = safe_load_checkpoint(model_type, ckpt_path, map_location=device)
    teacher.eval()
    return teacher
def load_val_data(
    *,
    lake_root: str,
    dataset: str,
    vocab_scope: str,
    seed: int,
    representation_cfg: GraphRepresentationCfg,
) -> list:
    """Load the val split through the same source/cache path as training."""
    from graphids.core.data.datasets.can_bus import CANBusSource
    from graphids.core.data.state import get_or_build

    state = get_or_build(
        CANBusSource(
            name=dataset,
            lake_root=lake_root,
            seed=seed,
            vocab_scope=vocab_scope,
            representation_cfg=representation_cfg,
        )
    )
    val = list(state.val)
    log.info("data_loaded", n_val=len(val))
    return val