Coverage for jetgp/full_degp/grad_check.py: 0%

97 statements  

coverage.py v7.10.7, created at 2026-04-03 15:09 -0500

1""" 

2Finite-difference gradient check for DEGP optimizer. 

3Tests: get_all_derivs_fast vs get_all_derivs, rbf_kernel_fast vs rbf_kernel, 

4and overall nll_and_grad correctness. 

5""" 

6import numpy as np 

7from jetgp.full_degp.degp import degp 

8 

9np.random.seed(42) 

10n, d = 15, 2 

11X = np.random.rand(n, d) 

12y = (np.sin(X[:, 0]) * np.cos(X[:, 1])).reshape(-1, 1) 

13dy_dx1 = (np.cos(X[:, 0]) * np.cos(X[:, 1])).reshape(-1, 1) 

14dy_dx2 = (-np.sin(X[:, 0]) * np.sin(X[:, 1])).reshape(-1, 1) 

15y_train = [y, dy_dx1, dy_dx2] 

16der_indices = [ 

17 [[[1, 1]], [[2, 1]]], # first-order 

18] 

19model = degp(X, y_train, n_order=1, n_bases=d, der_indices=der_indices, 

20 normalize=True, kernel='SE', kernel_type='anisotropic') 

21opt = model.optimizer 

22 
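# Hyperparameter vector layout (inferred from the tests below, not from degp's
# documentation): x0 has d + 2 entries in log10 space; Test 2 maps x0[:d] to
# per-dimension kernel scale parameters via ell = 10.0 ** x0[:d], the kernel is
# evaluated with x0[:-1], and the last entry is presumably the (log10) noise term.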

# --- Test 1: get_all_derivs_fast vs get_all_derivs ---
print("=" * 60)
print("Test 1: get_all_derivs_fast vs get_all_derivs")
print("=" * 60)
x0 = np.array([0.1, -0.2, 0.5, -3.0])
diffs = model.differences_by_dim
oti = model.kernel_factory.oti
phi = model.kernel_func(diffs, x0[:-1])
n_bases = phi.get_active_bases()[-1]
deriv_order = 2 * model.n_order
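# deriv_order = 2 * n_order, presumably because covariances between derivative
# observations take kernel derivatives from both arguments: cov(dy/dx_i, dy/dx_j)
# involves second cross-derivatives of the kernel, so first-order training data
# (n_order = 1) needs kernel derivatives up to order 2.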

slow = phi.get_all_derivs(n_bases, deriv_order)
if hasattr(phi, 'get_all_derivs_fast'):
    factors = opt._get_deriv_factors(n_bases, deriv_order)
    buf = np.zeros_like(slow)
    fast = phi.get_all_derivs_fast(factors, buf)
    diff = np.max(np.abs(slow - fast))
    print(f" max |slow - fast| = {diff:.2e}")
    if diff > 1e-10:
        print(" >>> get_all_derivs_fast DISAGREES with get_all_derivs!")
        # Show where they differ (slice_diff, not d, to avoid clobbering the global d)
        for i in range(slow.shape[0]):
            slice_diff = np.max(np.abs(slow[i] - fast[i]))
            if slice_diff > 1e-10:
                print(f" slice [{i}]: max diff = {slice_diff:.2e}, slow max={np.max(np.abs(slow[i])):.2e}, fast max={np.max(np.abs(fast[i])):.2e}")
    else:
        print(" OK")
else:
    print(" get_all_derivs_fast not available, skipping")

# --- Test 2: Check if fused_scale_sq_mul matches manual computation ---
print()
print("=" * 60)
print("Test 2: fused_scale_sq_mul vs manual")
print("=" * 60)
phi = model.kernel_func(diffs, x0[:-1])  # fresh phi
if hasattr(phi, 'fused_scale_sq_mul'):
    ell = 10.0 ** x0[:d]
    ln10 = np.log(10.0)
    for dim in range(d):
        # Manual: -ln10 * ell[dim]^2 * diffs[dim]^2 * phi
        d_sq = oti.mul(diffs[dim], diffs[dim])
        manual = oti.mul(-ln10 * ell[dim] ** 2, oti.mul(d_sq, phi))
        manual_derivs = manual.get_all_derivs(n_bases, deriv_order)

        # Fused path
        phi2 = model.kernel_func(diffs, x0[:-1])  # fresh phi (buffer corruption!)
        dphi_buf = oti.zeros(phi2.shape)
        dphi_buf.fused_scale_sq_mul(diffs[dim], phi2, -ln10 * ell[dim] ** 2)
        fused_derivs = dphi_buf.get_all_derivs(n_bases, deriv_order)

        diff = np.max(np.abs(manual_derivs - fused_derivs))
        print(f" dim {dim}: max |manual - fused| = {diff:.2e}", end="")
        if diff > 1e-10:
            print(" <--- MISMATCH")
        else:
            print(" OK")
else:
    print(" fused_scale_sq_mul not available, skipping")
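# The manual path above spells out the expected result,
#     -ln(10) * ell[dim]**2 * diffs[dim]**2 * phi,
# presumably the derivative of phi with respect to the log10 hyperparameter
# x0[dim], built from two oti.mul calls. fused_scale_sq_mul is expected to
# produce the same values in a single pass into the preallocated dphi_buf,
# which is why element-wise agreement within 1e-10 is the pass criterion.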

# --- Test 3: Overall FD gradient check ---
print()
print("=" * 60)
print("Test 3: Analytic gradient vs finite differences")
print("=" * 60)
for x0 in [np.array([0.1, -0.2, 0.5, -3.0]),
           np.array([-0.5, 0.3, 0.0, -5.0]),
           np.array([0.0, 0.0, 0.0, -4.0])]:
    nll, grad_analytic = opt.nll_and_grad(x0)
    eps = 1e-5
    grad_fd = np.zeros_like(x0)
    for i in range(len(x0)):
        xp = x0.copy(); xp[i] += eps
        xm = x0.copy(); xm[i] -= eps
        grad_fd[i] = (opt.nll_wrapper(xp) - opt.nll_wrapper(xm)) / (2 * eps)
    print(f"x0: {x0}, nll: {nll:.4f}")
    for i in range(len(x0)):
        rel = abs(grad_analytic[i] - grad_fd[i]) / (abs(grad_fd[i]) + 1e-12)
        flag = ' <--- BAD' if rel > 1e-3 else ''
        print(f" x[{i}] analytic={grad_analytic[i]:>12.4e} fd={grad_fd[i]:>12.4e} rel={rel:.2e}{flag}")
    print()
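# Both Test 3 above and Test 4 below use a second-order central difference,
#     dNLL/dx[i] ~= (nll(x0 + eps*e_i) - nll(x0 - eps*e_i)) / (2 * eps),
# whose truncation error is O(eps^2). With eps = 1e-5 in float64, relative
# errors in the 1e-6 to 1e-9 range are typical for a correct gradient, so the
# 1e-3 flag threshold leaves ample margin.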

# --- Test 4: Force slow path and re-check gradient ---
print("=" * 60)
print("Test 4: Gradient with forced slow path (get_all_derivs only)")
print("=" * 60)
# Monkey-patch _expand_derivs to force slow path
original_expand = opt._expand_derivs
def slow_expand(phi, n_bases, deriv_order):
    return phi.get_all_derivs(n_bases, deriv_order)
opt._expand_derivs = slow_expand

x0 = np.array([0.1, -0.2, 0.5, -3.0])
nll, grad_analytic = opt.nll_and_grad(x0)
eps = 1e-5
grad_fd = np.zeros_like(x0)
for i in range(len(x0)):
    xp = x0.copy(); xp[i] += eps
    xm = x0.copy(); xm[i] -= eps
    grad_fd[i] = (opt.nll_wrapper(xp) - opt.nll_wrapper(xm)) / (2 * eps)
print(f"x0: {x0}, nll: {nll:.4f}")
for i in range(len(x0)):
    rel = abs(grad_analytic[i] - grad_fd[i]) / (abs(grad_fd[i]) + 1e-12)
    flag = ' <--- BAD' if rel > 1e-3 else ''
    print(f" x[{i}] analytic={grad_analytic[i]:>12.4e} fd={grad_fd[i]:>12.4e} rel={rel:.2e}{flag}")

# Restore
opt._expand_derivs = original_expand
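# A minimal sketch (not part of the original script) of the same finite-difference
# check factored into a reusable helper; max_rel_grad_error is a hypothetical name,
# and it relies only on the opt.nll_and_grad / opt.nll_wrapper calls already used above.
def max_rel_grad_error(opt, x0, eps=1e-5):
    """Worst relative disagreement between analytic and central-difference gradients."""
    _, grad_analytic = opt.nll_and_grad(x0)
    grad_analytic = np.asarray(grad_analytic)
    grad_fd = np.zeros_like(x0)
    for i in range(len(x0)):
        xp = x0.copy(); xp[i] += eps
        xm = x0.copy(); xm[i] -= eps
        grad_fd[i] = (opt.nll_wrapper(xp) - opt.nll_wrapper(xm)) / (2 * eps)
    return np.max(np.abs(grad_analytic - grad_fd) / (np.abs(grad_fd) + 1e-12))

# Example usage (same tolerance as the flag above):
#     assert max_rel_grad_error(opt, np.array([0.1, -0.2, 0.5, -3.0])) < 1e-3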