Coverage for jetgp/full_degp/grad_check.py: 0%
97 statements
1"""
2Finite-difference gradient check for DEGP optimizer.
3Tests: get_all_derivs_fast vs get_all_derivs, rbf_kernel_fast vs rbf_kernel,
4and overall nll_and_grad correctness.
5"""
import numpy as np
from jetgp.full_degp.degp import degp

np.random.seed(42)
n, d = 15, 2
X = np.random.rand(n, d)
y = (np.sin(X[:, 0]) * np.cos(X[:, 1])).reshape(-1, 1)
dy_dx1 = (np.cos(X[:, 0]) * np.cos(X[:, 1])).reshape(-1, 1)
dy_dx2 = (-np.sin(X[:, 0]) * np.sin(X[:, 1])).reshape(-1, 1)
y_train = [y, dy_dx1, dy_dx2]
der_indices = [
    [[[1, 1]], [[2, 1]]],  # first-order
]
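# Assumed convention: each inner [dim, order] pair requests the order-th
# partial derivative with respect to input dimension dim (1-indexed), so
# [[1, 1]] and [[2, 1]] correspond to dy_dx1 and dy_dx2 above.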
model = degp(X, y_train, n_order=1, n_bases=d, der_indices=der_indices,
             normalize=True, kernel='SE', kernel_type='anisotropic')
opt = model.optimizer

# --- Test 1: get_all_derivs_fast vs get_all_derivs ---
print("=" * 60)
print("Test 1: get_all_derivs_fast vs get_all_derivs")
print("=" * 60)
x0 = np.array([0.1, -0.2, 0.5, -3.0])
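# Hyperparameter layout (inferred from the usage below): the first d entries
# are log10 lengthscale-type kernel parameters (exponentiated as 10**x0[:d]),
# x0[:-1] is passed to the kernel, and the last entry is presumably the
# log10 noise variance.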
diffs = model.differences_by_dim
oti = model.kernel_factory.oti
phi = model.kernel_func(diffs, x0[:-1])
n_bases = phi.get_active_bases()[-1]
deriv_order = 2 * model.n_order
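# The kernel must be differentiated to twice the data's derivative order:
# covariances between order-p derivative observations involve p kernel
# derivatives in each argument, hence 2 * n_order in total.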

slow = phi.get_all_derivs(n_bases, deriv_order)
if hasattr(phi, 'get_all_derivs_fast'):
    factors = opt._get_deriv_factors(n_bases, deriv_order)
    buf = np.zeros_like(slow)
    fast = phi.get_all_derivs_fast(factors, buf)
    diff = np.max(np.abs(slow - fast))
    print(f"  max |slow - fast| = {diff:.2e}")
    if diff > 1e-10:
        print("  >>> get_all_derivs_fast DISAGREES with get_all_derivs!")
        # Show where they differ (don't reuse `d`, which holds the input dim)
        for i in range(slow.shape[0]):
            slice_diff = np.max(np.abs(slow[i] - fast[i]))
            if slice_diff > 1e-10:
                print(f"    slice [{i}]: max diff = {slice_diff:.2e}, "
                      f"slow max={np.max(np.abs(slow[i])):.2e}, "
                      f"fast max={np.max(np.abs(fast[i])):.2e}")
    else:
        print("  OK")
else:
    print("  get_all_derivs_fast not available, skipping")

# --- Test 2: Check if fused_scale_sq_mul matches manual computation ---
print()
print("=" * 60)
print("Test 2: fused_scale_sq_mul vs manual")
print("=" * 60)
phi = model.kernel_func(diffs, x0[:-1])  # fresh phi
if hasattr(phi, 'fused_scale_sq_mul'):
    ell = 10.0 ** x0[:d]
    ln10 = np.log(10.0)
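    # Why -ln10 * ell**2: assuming phi = exp(-0.5 * sum_j ell_j**2 * diff_j**2)
    # with ell_j = 10**x0[j], the chain rule through the log10 parametrization
    # gives d(phi)/d(x0[j]) = -ln(10) * ell_j**2 * diff_j**2 * phi.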
    for dim in range(d):
        # Manual: -ln10 * ell[dim]^2 * diffs[dim]^2 * phi
        d_sq = oti.mul(diffs[dim], diffs[dim])
        manual = oti.mul(-ln10 * ell[dim] ** 2, oti.mul(d_sq, phi))
        manual_derivs = manual.get_all_derivs(n_bases, deriv_order)

        # Fused path: rebuild phi, since the manual path above may have
        # reused and corrupted its internal buffers
        phi2 = model.kernel_func(diffs, x0[:-1])
        dphi_buf = oti.zeros(phi2.shape)
        dphi_buf.fused_scale_sq_mul(diffs[dim], phi2, -ln10 * ell[dim] ** 2)
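        # Assumed semantics of the fused call above:
        # dphi_buf <- scale * diffs[dim]**2 * phi2, i.e. the same quantity as
        # the three-temporary manual path, computed in a single pass.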
        fused_derivs = dphi_buf.get_all_derivs(n_bases, deriv_order)

        diff = np.max(np.abs(manual_derivs - fused_derivs))
        print(f"  dim {dim}: max |manual - fused| = {diff:.2e}", end="")
        if diff > 1e-10:
            print("  <--- MISMATCH")
        else:
            print("  OK")
else:
    print("  fused_scale_sq_mul not available, skipping")

# --- Test 3: Overall FD gradient check ---
print()
print("=" * 60)
print("Test 3: Analytic gradient vs finite differences")
print("=" * 60)
for x0 in [np.array([0.1, -0.2, 0.5, -3.0]),
           np.array([-0.5, 0.3, 0.0, -5.0]),
           np.array([0.0, 0.0, 0.0, -4.0])]:
    nll, grad_analytic = opt.nll_and_grad(x0)
    eps = 1e-5
    grad_fd = np.zeros_like(x0)
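    # Second-order central differences: (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps)
    # has O(eps**2) truncation error; eps = 1e-5 sits near the usual sweet spot
    # (~cbrt(machine epsilon)) balancing truncation against round-off.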
    for i in range(len(x0)):
        xp = x0.copy(); xp[i] += eps
        xm = x0.copy(); xm[i] -= eps
        grad_fd[i] = (opt.nll_wrapper(xp) - opt.nll_wrapper(xm)) / (2 * eps)
    print(f"x0: {x0}, nll: {nll:.4f}")
    for i in range(len(x0)):
        rel = abs(grad_analytic[i] - grad_fd[i]) / (abs(grad_fd[i]) + 1e-12)
        flag = ' <--- BAD' if rel > 1e-3 else ''
        print(f"  x[{i}] analytic={grad_analytic[i]:>12.4e} fd={grad_fd[i]:>12.4e} rel={rel:.2e}{flag}")
    print()

# --- Test 4: Force slow path and re-check gradient ---
print("=" * 60)
print("Test 4: Gradient with forced slow path (get_all_derivs only)")
print("=" * 60)
# Monkey-patch _expand_derivs to force the slow path
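# Assigning a plain function to the instance attribute shadows the bound
# method, so `self` is not passed; the replacement must therefore take
# exactly the method's remaining arguments.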
original_expand = opt._expand_derivs
def slow_expand(phi, n_bases, deriv_order):
    return phi.get_all_derivs(n_bases, deriv_order)
opt._expand_derivs = slow_expand

x0 = np.array([0.1, -0.2, 0.5, -3.0])
nll, grad_analytic = opt.nll_and_grad(x0)
eps = 1e-5
grad_fd = np.zeros_like(x0)
for i in range(len(x0)):
    xp = x0.copy(); xp[i] += eps
    xm = x0.copy(); xm[i] -= eps
    grad_fd[i] = (opt.nll_wrapper(xp) - opt.nll_wrapper(xm)) / (2 * eps)
print(f"x0: {x0}, nll: {nll:.4f}")
for i in range(len(x0)):
    rel = abs(grad_analytic[i] - grad_fd[i]) / (abs(grad_fd[i]) + 1e-12)
    flag = ' <--- BAD' if rel > 1e-3 else ''
    print(f"  x[{i}] analytic={grad_analytic[i]:>12.4e} fd={grad_fd[i]:>12.4e} rel={rel:.2e}{flag}")

# Restore
opt._expand_derivs = original_expand