I’m trying to manually reproduce the inference forward pass to understand exactly how quantized inference works. To do so, I trained and quantized a model in PyTorch using QAT, manually simulated the forward pass, and compared the outputs to those of PyTorch's own quantized inference.
The activations, however, start to diverge right after the first layer, as some channels differ by a few units (~5–15 difference):
🔸 PyTorch Quantized Model Output
INT8 logits : [[182, 122, 163, 129, 113, 114, 165, 105, 139, 139, 152, 113, 179, 148, 159, 132]]

🔹 My Golden Model Simulation
layer 0 unsigned logits : [[182, 135, 163, 147, 128, 128, 183, 128, 152, 128, 168, 130, 179, 164, 140, 128]]
This is the script I used for inference:
"""Compare a PyTorch QAT-quantized model with a hand-written integer forward pass.

The script loads the converted INT8 model, runs one sample through it while a
forward hook records the first quantized Linear layer's input, weights, bias
and output, and then re-computes every quantized Linear layer with NumPy
integer arithmetic ("golden model") so the two results can be compared.
"""
import numpy as np
import torch
from sklearn import datasets
from sklearn.preprocessing import StandardScaler

# ---------- Input (floating point) ----------
# The checkpoint is a fully pickled module (saved with torch.save(model, ...)).
# NOTE(review): unpickling needs the model class to be importable, and recent
# PyTorch releases may require weights_only=False for this call - confirm
# against the installed version.
model_int8 = torch.load("trained_qat_model2.pth", map_location="cpu")

# Exactly one sample is active; the others are kept for quick switching.
# input_raw = np.array([5.3, 3.7, 1.5, 0.2], dtype=np.float32)  # setosa
input_raw = np.array([7.0, 3.2, 4.7, 1.4], dtype=np.float32)  # versicolor
# input_raw = np.array([7.7, 3.8, 6.7, 2.2], dtype=np.float32)  # virginica

# ---------- StandardScaler ----------
# Refit the scaler on the full iris data and standardize the single sample.
iris = datasets.load_iris()
scaler = StandardScaler().fit(iris.data)
input_float = scaler.transform([input_raw])[0]

###############################################################
capture = {}


def dump_first_layer(mod, inp, out):
    """Forward hook: store the quantized input, parameters and output of `mod`.

    Args:
        mod: the hooked quantized Linear module.
        inp: tuple of positional inputs; inp[0] is the quantized activation.
        out: the module's quantized output tensor.

    The hook fires as soon as `mod` returns, so `out` is the requantized
    Linear output *before* any following activation module is applied.
    """
    x_q = inp[0]
    w_q = mod.weight()
    b = mod.bias()
    capture["x_int8"] = x_q.int_repr()                 # integer input codes
    capture["s_x"] = x_q.q_scale()
    capture["zp_x"] = x_q.q_zero_point()
    capture["w_int8"] = w_q.int_repr()                 # int8 weights
    capture["s_w"] = w_q.q_per_channel_scales()        # per-channel weight scales
    capture["zp_w"] = w_q.q_per_channel_zero_points()  # expected all zero (symmetric)
    capture["bias_fp32"] = b.clone()
    capture["y_int8"] = out.int_repr()                 # integer codes after requant
    capture["s_y"] = out.q_scale()
    capture["zp_y"] = out.q_zero_point()


# ---------- Pytorch inference ----------
hook = model_int8.fc1.register_forward_hook(dump_first_layer)
# Build the (1, 4) batch from a 2-D array instead of a list of ndarrays.
test_input = torch.tensor(input_float[np.newaxis, :], dtype=torch.float32)
try:
    logits_fp32 = model_int8(test_input)
finally:
    # Always detach the hook, even if the forward pass raises.
    hook.remove()

print()
print("🔸 PyTorch Quantized Model Output")
print("input :", capture["x_int8"])
print("INT8 logits :", (capture["y_int8"]).tolist())

###############################################################
# ---------- Simulate inference ----------
print()
print("🔹 Golden Model Simulation")

# ---------- real uint8 activations & first-layer params ----------
x_u8 = capture["x_int8"].numpy()   # 0...255
zp_x = int(capture["zp_x"])        # 128 for our symmetric model
s_in = float(capture["s_x"])

layers = [m for m in model_int8.modules()
          if isinstance(m, torch.nn.quantized.Linear)]
last_index = len(layers) - 1

for i, layer in enumerate(layers):
    # ------- weight & per-output parameters -------
    w = layer.weight().int_repr().numpy().astype(np.int8)
    s_w = layer.weight().q_per_channel_scales().numpy()
    b_fp = layer.bias().detach().numpy()
    s_out = float(layer.scale)
    zp_y = int(layer.zero_point)

    # ------- 1. signed-domain input -------
    x_s8 = x_u8.astype(np.int16) - zp_x                # [-128,127]

    # ------- 2. int32 MAC -------
    # NOTE(review): this accumulates exactly in int32. If a mismatch larger
    # than ~1 LSB remains against PyTorch, check whether the backend kernel
    # saturates intermediate sums when activations use the full 8-bit range
    # (the model was trained with reduce_range=False) - TODO confirm.
    acc = x_s8.astype(np.int32) @ w.T.astype(np.int32)

    # ------- 3. integer bias before requant -------
    # NOTE(review): rounding the bias to int32 here may differ from the
    # backend's bias handling by about 1 LSB - confirm if bit-exactness matters.
    b_int = np.round(b_fp / (s_in * s_w)).astype(np.int32)
    acc += b_int

    # ------- 4. requant (float form; ±1 LSB accurate) -------
    mult = (s_in * s_w) / s_out
    out_s32 = np.round(acc * mult).astype(np.int32)

    # ------- 5. add zp_y and clamp to the uint8 range -------
    out_u8 = np.clip(out_s32 + zp_y, 0, 255).astype(np.uint8)

    # ------- Print (pre-activation, same point the forward hook observes) -------
    print(f"layer {i} unsigned logits :", out_u8.tolist())

    if layer is model_int8.fc1:
        # Direct check against the tensor captured by the hook on fc1.
        reference = capture["y_int8"].numpy().astype(np.int32)
        delta = np.abs(out_u8.astype(np.int32) - reference)
        print("layer 0 max |sim - torch| :", int(delta.max()))

    # ------- 6. ReLU for hidden layers only, applied AFTER the comparison -------
    # In the quantized domain ReLU clamps every code below the zero point.
    # The hook sees fc1's output before the separate ReLU module runs, so the
    # clamp must not be part of the value compared above.
    if i != last_index:
        next_u8 = np.maximum(out_u8, np.uint8(zp_y))
        print(f"layer {i} after ReLU :", next_u8.tolist())
    else:
        next_u8 = out_u8

    # ------- Feed next layer -------
    x_u8, zp_x, s_in = next_u8, zp_y, s_out
This is the script I used to train the model:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.quantization
from torch.utils.data import TensorDataset, DataLoader
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
#from quant_model import QuantizableFNN, xQuantizableFNN
from torch.quantization.observer import PerChannelMinMaxObserver, MovingAverageMinMaxObserver
from torch.ao.quantization import (
    MovingAverageMinMaxObserver,
    FakeQuantize,
    QConfig,
    default_per_channel_weight_fake_quant,
)


class QuantizableFNN(nn.Module):
    """Two-layer fully connected classifier wrapped in quant/dequant stubs.

    Layout: QuantStub -> Linear(4, 16) -> ReLU -> Linear(16, 3) -> DeQuantStub.
    The ReLU is a separate module; it is not fused with the first Linear.
    """

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.fc1 = nn.Linear(4, 16)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(16, 3)

    def forward(self, x):
        """Run the stubs and both Linear layers; return the dequantized logits."""
        quantized = self.quant(x)
        hidden = self.relu1(self.fc1(quantized))
        logits = self.fc2(hidden)
        return self.dequant(logits)


def _run_epoch(model, loader, optimizer, criterion):
    """Train `model` for one pass over `loader`.

    Returns:
        (avg_loss, acc): mean batch loss and training accuracy in percent.
    """
    hits = 0
    seen = 0
    loss_sum = 0.0
    for inputs, labels in loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accuracy bookkeeping on the batch that was just used for the update.
        predicted = torch.max(outputs.data, 1).indices
        hits += (predicted == labels).sum().item()
        seen += labels.size(0)
        loss_sum += loss.item()
    return loss_sum / len(loader), 100.0 * hits / seen


# Load and preprocess dataset: standardize all features, then hold out 20 %.
iris = datasets.load_iris()
features = StandardScaler().fit_transform(iris.data)
targets = iris.target
X_train, X_test, y_train, y_test = train_test_split(
    features, targets, test_size=0.2, random_state=42
)

train_set = TensorDataset(
    torch.tensor(X_train, dtype=torch.float32),
    torch.tensor(y_train, dtype=torch.long),
)
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)

# Train with QAT
# Custom qconfig: per-tensor symmetric activations, per-channel weights.
model_fp32 = QuantizableFNN()

# NOTE(review): reduce_range=False keeps the full 8-bit activation range.
# PyTorch's observer documentation says reducing the range is sometimes
# required to avoid instruction overflow - confirm the inference backend in
# use is unaffected.
act_fake_quant = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    qscheme=torch.per_tensor_symmetric,
    reduce_range=False,
)
model_fp32.qconfig = QConfig(
    activation=act_fake_quant,
    weight=default_per_channel_weight_fake_quant,
)
torch.quantization.prepare_qat(model_fp32, inplace=True)

optimizer = optim.Adam(model_fp32.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(100):
    avg_loss, acc = _run_epoch(model_fp32, train_loader, optimizer, criterion)
    print(f"Epoch {epoch+1:03d} | Loss: {avg_loss:.4f} | Accuracy: {acc:.2f}%")

# Convert and save: switch to eval mode, convert to INT8 modules, pickle the
# whole module to disk.
model_int8 = torch.quantization.convert(model_fp32.eval(), inplace=False)
torch.save(model_int8, "trained_qat_model2.pth")
print("✅ Quantized model saved as 'trained_qat_model2.pth'")