Coverage for jetgp/wdegp/optimizer.py: 82%

586 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2026-04-03 15:37 -0500

1import numpy as np 

2from scipy.linalg import cho_solve, cho_factor 

3from line_profiler import profile 

4import jetgp.utils as gen_utils 

5from jetgp.hyperparameter_optimizers import OPTIMIZERS 

6from jetgp.utils import matern_kernel_grad_builder 

7 

8 

class Optimizer:
    """
    Optimizer class for fitting the hyperparameters of a weighted derivative-enhanced GP model (wDEGP)
    by minimizing the negative log marginal likelihood (NLL).

    Supports DEGP, DDEGP, and GDDEGP modes (selected via the model's
    ``submodel_type`` attribute).

    Attributes
    ----------
    model : object
        Instance of a weighted derivative-enhanced GP model (wDEGP) with attributes:
        x_train, y_train, n_order, n_bases, der_indices, index, bounds, submodel_type, etc.
    utils : module
        Submodel-specific kernel utility module chosen by ``_setup_utils``.
    """

22 

23 def __init__(self, model): 

24 """ 

25 Parameters 

26 ---------- 

27 model : object 

28 An instance of a wDEGP model containing training data, hyperparameter bounds, 

29 and other model-specific structures required for kernel computation. 

30 """ 

31 self.model = model 

32 

33 # Import the appropriate utils module based on submodel_type 

34 self._setup_utils() 

35 

36 # Precompute kernel plans (structural info that never changes) 

37 self._kernel_plans = None # lazily initialized on first NLL call 

38 self._deriv_buf = None 

39 self._deriv_buf_shape = None 

40 self._deriv_factors = None 

41 self._deriv_factors_key = None 

42 self._K_bufs = None # per-submodel pre-allocated K buffers 

43 self._dK_bufs = None # per-submodel pre-allocated dK buffers 

44 self._W_proj_buf = None 

45 self._W_proj_shape = None 

46 

47 def _get_deriv_buf(self, phi, n_bases, order): 

48 from math import comb 

49 ndir = comb(n_bases + order, order) 

50 shape = (ndir, phi.shape[0], phi.shape[1]) 

51 if self._deriv_buf is None or self._deriv_buf_shape != shape: 

52 self._deriv_buf = np.zeros(shape, dtype=np.float64) 

53 self._deriv_buf_shape = shape 

54 return self._deriv_buf 

55 

56 def _expand_derivs(self, phi, n_bases, deriv_order): 

57 """Expand OTI derivatives, using fast struct path if available.""" 

58 if hasattr(phi, 'get_all_derivs_fast'): 

59 buf = self._get_deriv_buf(phi, n_bases, deriv_order) 

60 factors = self._get_deriv_factors(n_bases, deriv_order) 

61 return phi.get_all_derivs_fast(factors, buf) 

62 return phi.get_all_derivs(n_bases, deriv_order) 

63 

    @staticmethod
    def _enum_factors(max_basis, ordi):
        """Yield one multiplicity factor per derivative direction of order ``ordi``.

        Directions are combinations-with-repetition of basis indices
        ``1..max_basis`` enumerated in non-decreasing order; each factor is the
        product of factorials of the repetition counts of the combination.
        The yield order must match the direction layout consumed by
        ``_get_deriv_factors`` / ``get_all_derivs_fast``.
        """
        from math import factorial
        from collections import Counter  # used only on the ordi == 2 fast path
        if ordi == 1:
            # First-order directions never repeat a basis: factor is always 1.
            for _ in range(max_basis):
                yield 1.0
            return
        for last in range(1, max_basis + 1):
            if ordi == 2:
                # Fast path: pairs (i, last) with i <= last; factor is 2 when
                # i == last (repeated basis), else 1.
                for i in range(1, last + 1):
                    counts = Counter((i, last))
                    f = 1
                    for c in counts.values():
                        f *= factorial(c)
                    yield float(f)
            else:
                # General case: extend every (ordi-1)-combination whose largest
                # element is <= last by appending `last`.
                for _, prefix_counts in Optimizer._enum_factors_with_counts(last, ordi - 1):
                    counts = dict(prefix_counts)
                    counts[last] = counts.get(last, 0) + 1
                    f = 1
                    for c in counts.values():
                        f *= factorial(c)
                    yield float(f)

88 

    @staticmethod
    def _enum_factors_with_counts(max_basis, ordi):
        """Recursive companion of ``_enum_factors``.

        Yields ``(factor, counts)`` pairs, where ``counts`` maps each basis
        index of the combination to its repetition count; the yield order is
        the same non-decreasing enumeration used by ``_enum_factors``.
        """
        from math import factorial
        from collections import Counter  # NOTE: imported for parity with _enum_factors; not used below
        if ordi == 1:
            # Base case: single-basis combinations, factor 1.
            for i in range(1, max_basis + 1):
                yield 1.0, {i: 1}
            return
        for last in range(1, max_basis + 1):
            # Append `last` to every shorter combination capped at `last`.
            for _, prefix_counts in Optimizer._enum_factors_with_counts(last, ordi - 1):
                counts = dict(prefix_counts)
                counts[last] = counts.get(last, 0) + 1
                f = 1
                for c in counts.values():
                    f *= factorial(c)
                yield float(f), counts

105 

106 def _get_deriv_factors(self, n_bases, order): 

107 key = (n_bases, order) 

108 if self._deriv_factors is not None and self._deriv_factors_key == key: 

109 return self._deriv_factors 

110 factors = [1.0] 

111 for ordi in range(1, order + 1): 

112 factors.extend(self._enum_factors(n_bases, ordi)) 

113 self._deriv_factors = np.array(factors, dtype=np.float64) 

114 self._deriv_factors_key = key 

115 return self._deriv_factors 

116 

117 def _setup_utils(self): 

118 """Set up the correct utils module based on submodel_type.""" 

119 submodel_type = getattr(self.model, 'submodel_type', 'degp') 

120 

121 if submodel_type == 'degp': 

122 from jetgp.wdegp import wdegp_utils 

123 self.utils = wdegp_utils 

124 self._uses_signs = True 

125 elif submodel_type == 'ddegp': 

126 from jetgp.full_ddegp import wddegp_utils 

127 self.utils = wddegp_utils 

128 self._uses_signs = True 

129 elif submodel_type == 'gddegp': 

130 from jetgp.full_gddegp import wgddegp_utils 

131 self.utils = wgddegp_utils 

132 self._uses_signs = False 

133 else: 

134 # Default to degp 

135 from jetgp.wdegp import wdegp_utils 

136 self.utils = wdegp_utils 

137 self._uses_signs = True 

138 

139 def _ensure_kernel_plans(self, n_bases): 

140 """Lazily precompute kernel plans for all submodels (once per n_bases).""" 

141 if self._kernel_plans is not None and self._kernel_plans_n_bases == n_bases: 

142 return 

143 if not hasattr(self.utils, 'precompute_kernel_plan'): 

144 self._kernel_plans = None 

145 return 

146 plans = [] 

147 index = self.model.derivative_locations 

148 for i in range(len(index)): 

149 plan = self.utils.precompute_kernel_plan( 

150 self.model.n_order, n_bases, 

151 self.model.flattened_der_indices[i], 

152 self.model.powers[i], 

153 index[i], 

154 ) 

155 plans.append(plan) 

156 self._kernel_plans = plans 

157 self._kernel_plans_n_bases = n_bases 

158 # Reset buffers when plans change 

159 self._K_bufs = None 

160 self._dK_bufs = None 

161 

162 def _ensure_kernel_bufs(self, n_rows_func): 

163 """Pre-allocate reusable K and dK buffers for each submodel.""" 

164 if self._kernel_plans is None: 

165 return 

166 if self._K_bufs is not None: 

167 return # already allocated 

168 self._K_bufs = [] 

169 self._dK_bufs = [] 

170 for plan in self._kernel_plans: 

171 total = n_rows_func + plan['n_pts_with_derivs'] 

172 self._K_bufs.append(np.empty((total, total))) 

173 self._dK_bufs.append(np.empty((total, total))) 

174 if 'row_offsets_abs' not in plan: 

175 plan['row_offsets_abs'] = plan['row_offsets'] + n_rows_func 

176 plan['col_offsets_abs'] = plan['col_offsets'] + n_rows_func 

177 

178 @profile 

179 def negative_log_marginal_likelihood( 

180 self, 

181 x0, 

182 x_train, 

183 y_train, 

184 n_order, 

185 n_bases, 

186 der_indices, 

187 index, 

188 ): 

189 """ 

190 Computes the negative log marginal likelihood (NLL) for a given hyperparameter vector. 

191 

192 NLL = 0.5 * y^T (K^-1) y + 0.5 * log|K| + 0.5*N*log(2*pi) 

193 

194 Parameters 

195 ---------- 

196 x0 : ndarray 

197 Log-scaled hyperparameter vector, where the last entry is log10(sigma_n). 

198 x_train : list of ndarrays 

199 Input training points (unused inside loop, included for general interface). 

200 y_train : list of ndarrays 

201 List of function and derivative training values for each submodel. 

202 n_order : int 

203 Maximum order of derivatives used. 

204 n_bases : int 

205 Number of Taylor bases used in the expansion. 

206 der_indices : list 

207 Multi-index derivative information. 

208 index : list of lists 

209 Indices partitioning the training data into submodels (derivative_locations). 

210 

211 Returns 

212 ------- 

213 float 

214 The computed negative log marginal likelihood. 

215 """ 

216 ell = x0[:-1] 

217 sigma_n = x0[-1] 

218 llhood = 0 

219 # ell[0] = 0 

220 # ell[1] = 0 

221 # ell[2] = 0 

222 # sigma_n = -16 

223 diffs = self.model.differences_by_dim 

224 phi = self.model.kernel_func(diffs, ell) 

225 if self.model.n_order == 0: 

226 n_bases = 0 

227 phi_exp = phi.real 

228 phi_exp = phi_exp[np.newaxis, :, :] 

229 else: 

230 n_bases = phi.get_active_bases()[-1] 

231 

232 # Extract ALL derivative components 

233 deriv_order = 2 * n_order 

234 phi_exp = self._expand_derivs(phi, n_bases, deriv_order) 

235 

236 # Ensure kernel plans are precomputed 

237 self._ensure_kernel_plans(n_bases) 

238 use_fast = self._kernel_plans is not None 

239 

240 # Pre-reshape phi_exp to 3D once 

241 if use_fast: 

242 base_shape = phi.shape 

243 self._ensure_kernel_bufs(base_shape[0]) 

244 phi_exp_3d = phi_exp.reshape(phi_exp.shape[0], base_shape[0], base_shape[1]) 

245 

246 for i in range(len(index)): 

247 y_train_sub = y_train[i] 

248 

249 if use_fast: 

250 K = self.utils.rbf_kernel_fast(phi_exp_3d, self._kernel_plans[i], out=self._K_bufs[i]) 

251 else: 

252 K = self.utils.rbf_kernel( 

253 phi, phi_exp, n_order, n_bases, 

254 self.model.flattened_der_indices[i], 

255 self.model.powers[i], index=index[i] 

256 ) 

257 

258 K += (10 ** sigma_n) ** 2 * np.eye(len(K)) 

259 

260 try: 

261 L, low = cho_factor(K) 

262 alpha = cho_solve( 

263 (L, low), 

264 y_train_sub 

265 ) 

266 

267 data_fit = 0.5 * np.dot(y_train_sub.flatten(), alpha.flatten()) 

268 log_det = np.sum(np.log(np.diag(L))) 

269 const = 0.5 * len(y_train_sub) * np.log(2 * np.pi) 

270 

271 llhood += data_fit + log_det + const 

272 except np.linalg.LinAlgError: 

273 llhood += 1e6 # Penalize badly conditioned matrices 

274 

275 return llhood 

276 

277 def nll_wrapper(self, x0): 

278 """ 

279 Wrapper for NLL function to fit PSO optimizer interface. 

280 

281 Parameters 

282 ---------- 

283 x0 : ndarray 

284 Hyperparameter vector. 

285 

286 Returns 

287 ------- 

288 float 

289 Computed NLL value. 

290 """ 

291 return self.negative_log_marginal_likelihood( 

292 x0, 

293 self.model.x_train, 

294 self.model.y_train_normalized, 

295 self.model.n_order, 

296 self.model.n_bases, 

297 self.model.der_indices, 

298 self.model.derivative_locations, 

299 ) 

300 

    def nll_grad(self, x0):
        """Analytic gradient of the NLL w.r.t. log10-scaled hyperparameters.

        Parameters
        ----------
        x0 : ndarray
            Log-scaled hyperparameter vector; ``x0[-1]`` is log10(sigma_n) and
            ``x0[-2]`` is used as the log10 signal amplitude (see ``grad[-2]``).

        Returns
        -------
        ndarray
            Gradient vector with the same length as ``x0``. Returns all zeros
            if any submodel's Cholesky factorization fails.
        """
        ln10 = np.log(10.0)

        kernel = self.model.kernel
        kernel_type = self.model.kernel_type
        D = len(self.model.differences_by_dim)
        sigma_n_sq = (10.0 ** x0[-1]) ** 2
        diffs = self.model.differences_by_dim
        oti = self.model.kernel_factory.oti
        index = self.model.derivative_locations

        # Evaluate the kernel on pairwise differences (OTI object when
        # derivatives are tracked; plain array path when n_order == 0).
        phi = self.model.kernel_func(diffs, x0[:-1])
        if self.model.n_order == 0:
            n_bases = 0
            phi_exp = phi.real[np.newaxis, :, :]
        else:
            n_bases = phi.get_active_bases()[-1]
            deriv_order = 2 * self.model.n_order
            phi_exp = self._expand_derivs(phi, n_bases, deriv_order)

        # Ensure kernel plans are precomputed
        self._ensure_kernel_plans(n_bases)
        use_fast = self._kernel_plans is not None

        # Pre-reshape phi_exp to 3D once
        if use_fast:
            base_shape = phi.shape
            self._ensure_kernel_bufs(base_shape[0])
            phi_exp_3d = phi_exp.reshape(phi_exp.shape[0], base_shape[0], base_shape[1])

        # Build per-submodel W matrices: W = K^-1 - alpha alpha^T, so that
        # dNLL/dtheta = 0.5 * tr(W dK/dtheta) for each hyperparameter.
        W_list = []
        for i in range(len(index)):
            y_train_sub = self.model.y_train_normalized[i]
            if use_fast:
                K = self.utils.rbf_kernel_fast(phi_exp_3d, self._kernel_plans[i], out=self._K_bufs[i])
            else:
                K = self.utils.rbf_kernel(
                    phi, phi_exp, self.model.n_order, n_bases,
                    self.model.flattened_der_indices[i],
                    self.model.powers[i], index=index[i]
                )
            # In-place noise-variance add on the diagonal.
            K.flat[::K.shape[0] + 1] += sigma_n_sq
            try:
                L, low = cho_factor(K)
                alpha_v = cho_solve((L, low), y_train_sub)
                N = len(y_train_sub)
                K_inv = cho_solve((L, low), np.eye(N))
                W_list.append(K_inv - np.outer(alpha_v, alpha_v))
            except Exception:
                # Factorization failed: report a zero gradient rather than
                # raising, so the outer optimizer can continue.
                return np.zeros(len(x0))

        grad = np.zeros(len(x0))

        # Precompute W projected into phi_exp space (sum over submodels), so
        # each per-hyperparameter contraction in _gc is a single vdot.
        W_proj = None
        if use_fast and self.model.n_order > 0:
            from math import comb
            ndir = comb(n_bases + deriv_order, deriv_order)
            proj_shape = (ndir, base_shape[0], base_shape[1])
            if self._W_proj_buf is None or self._W_proj_shape != proj_shape:
                self._W_proj_buf = np.empty(proj_shape)
                self._W_proj_shape = proj_shape
            W_proj = self._W_proj_buf
            W_proj[:] = 0.0
            tmp_proj = np.empty_like(W_proj)
            for i in range(len(index)):
                plan = self._kernel_plans[i]
                row_off = plan.get('row_offsets_abs', plan['row_offsets'] + base_shape[0])
                col_off = plan.get('col_offsets_abs', plan['col_offsets'] + base_shape[1])
                args = [
                    W_list[i], tmp_proj, base_shape[0], base_shape[1],
                    plan['fd_flat_indices'], plan['df_flat_indices'],
                    plan['dd_flat_indices'],
                    plan['idx_flat'], plan['idx_offsets'], plan['index_sizes'],
                ]
                # DEGP/DDEGP plans carry a 'signs' array; GDDEGP plans do not.
                if self._uses_signs:
                    args.append(plan['signs'])
                args.extend([plan['n_deriv_types'], row_off, col_off])
                self.utils._project_W_to_phi_space(*args)
                W_proj += tmp_proj

        def _gc(dphi):
            # Gradient contribution 0.5 * sum_i tr(W_i dK_i) for one
            # hyperparameter, given dphi = d(phi)/d(hyperparameter).
            if self.model.n_order == 0:
                dphi_exp = dphi.real[np.newaxis, :, :]
            else:
                dphi_exp = self._expand_derivs(dphi, n_bases, deriv_order)
            if W_proj is not None:
                # Fastest path: contract directly in projected phi space.
                dphi_3d = dphi_exp.reshape(W_proj.shape)
                return 0.5 * np.vdot(W_proj, dphi_3d)
            elif use_fast:
                dphi_3d = dphi_exp.reshape(dphi_exp.shape[0], base_shape[0], base_shape[1])
                total = 0.0
                for i in range(len(index)):
                    dK = self.utils.rbf_kernel_fast(dphi_3d, self._kernel_plans[i], out=self._dK_bufs[i])
                    total += np.vdot(W_list[i], dK)
                return 0.5 * total
            else:
                total = 0.0
                for i in range(len(index)):
                    dK = self.utils.rbf_kernel(
                        dphi, dphi_exp,
                        self.model.n_order, n_bases,
                        self.model.flattened_der_indices[i],
                        self.model.powers[i],
                        index=index[i],
                    )
                    total += np.vdot(W_list[i], dK)
                return 0.5 * total

        # d/d(log10 sigma_f): K scales with the amplitude squared, so dK = 2*ln10*K.
        grad[-2] = _gc(oti.mul(2.0 * ln10, phi))
        # d/d(log10 sigma_n): dK = 2*ln10*sigma_n^2 * I, leaving ln10*sigma_n^2*tr(W).
        grad[-1] = ln10 * sigma_n_sq * sum(np.trace(W) for W in W_list)

        if kernel == 'SE':
            # Squared-exponential: lengthscale gradient per dimension.
            # NOTE(review): diffs appear to be pre-scaled so that ell multiplies
            # (not divides) them in the derivative expressions — confirm against
            # kernel_func before modifying these formulas.
            if kernel_type == 'anisotropic':
                ell = 10.0 ** x0[:D]
                if hasattr(phi, 'fused_scale_sq_mul'):
                    # Fused path: dphi_buf = coef * diffs[d]^2 * phi in one call.
                    dphi_buf = oti.zeros(phi.shape)
                    for d in range(D):
                        dphi_buf.fused_scale_sq_mul(diffs[d], phi, -ln10 * ell[d] ** 2)
                        grad[d] = _gc(dphi_buf)
                else:
                    for d in range(D):
                        grad[d] = _gc(oti.mul(-ln10 * ell[d] ** 2,
                                              oti.mul(oti.mul(diffs[d], diffs[d]), phi)))
            else:
                ell = 10.0 ** float(x0[0])
                if hasattr(phi, 'fused_sum_sq'):
                    sum_sq = oti.zeros(phi.shape)
                    sum_sq.fused_sum_sq(diffs)
                else:
                    sum_sq = oti.mul(diffs[0], diffs[0])
                    for d in range(1, D):
                        sum_sq = oti.sum(sum_sq, oti.mul(diffs[d], diffs[d]))
                grad[0] = _gc(oti.mul(-ln10 * ell ** 2, oti.mul(sum_sq, phi)))

        elif kernel == 'RQ':
            # Rational quadratic. NOTE(review): the isotropic branch
            # parameterizes alpha as exp(x0[1]) while the anisotropic branch
            # uses 10**x0[D]; alpha_factor below is consistent with each
            # parameterization's chain rule — confirm this asymmetry is intended.
            if kernel_type == 'anisotropic':
                ell = 10.0 ** x0[:D]; alpha_rq = 10.0 ** float(x0[D]); alpha_idx = D
            else:
                ell = np.full(D, 10.0 ** float(x0[0]))
                alpha_rq = np.exp(float(x0[1])); alpha_idx = 1
            if hasattr(phi, 'fused_sqdist'):
                r2 = oti.zeros(phi.shape)
                ell_sq = np.ascontiguousarray(ell ** 2, dtype=np.float64)
                r2.fused_sqdist(diffs, ell_sq)
            else:
                r2 = oti.mul(ell[0], diffs[0]); r2 = oti.mul(r2, r2)
                for d in range(1, D):
                    td = oti.mul(ell[d], diffs[d]); r2 = oti.sum(r2, oti.mul(td, td))
            # base = 1 + r^2 / (2*alpha); phi_over_base = phi / base.
            base = oti.sum(1.0, oti.mul(r2, 1.0 / (2.0 * alpha_rq)))
            inv_base = oti.pow(base, -1)
            phi_over_base = oti.mul(phi, inv_base)
            if kernel_type == 'anisotropic':
                if hasattr(phi, 'fused_scale_sq_mul'):
                    dphi_buf = oti.zeros(phi.shape)
                    for d in range(D):
                        dphi_buf.fused_scale_sq_mul(diffs[d], phi_over_base, -ln10 * ell[d] ** 2)
                        grad[d] = _gc(dphi_buf)
                else:
                    for d in range(D):
                        grad[d] = _gc(oti.mul(-ln10 * ell[d] ** 2,
                                              oti.mul(oti.mul(diffs[d], diffs[d]), phi_over_base)))
            else:
                if hasattr(phi, 'fused_sum_sq'):
                    sum_sq = oti.zeros(phi.shape)
                    sum_sq.fused_sum_sq(diffs)
                else:
                    sum_sq = oti.mul(diffs[0], diffs[0])
                    for d in range(1, D):
                        sum_sq = oti.sum(sum_sq, oti.mul(diffs[d], diffs[d]))
                grad[0] = _gc(oti.mul(-ln10 * ell[0] ** 2, oti.mul(sum_sq, phi_over_base)))
            # Gradient w.r.t. alpha: phi * (1 - 1/base - log(base)) times the
            # chain-rule factor for the chosen alpha parameterization.
            log_base = oti.log(base)
            term = oti.sub(oti.sub(1.0, inv_base), log_base)
            alpha_factor = ln10 * alpha_rq if kernel_type == 'anisotropic' else alpha_rq
            grad[alpha_idx] = _gc(oti.mul(alpha_factor, oti.mul(phi, term)))

        elif kernel == 'SineExp':
            # Periodic (exp-sine-squared style) kernel: pip = pi / period.
            if kernel_type == 'anisotropic':
                ell = 10.0 ** x0[:D]; p = 10.0 ** x0[D:2*D]
                pip = np.pi / p; p_start = D
            else:
                ell = np.full(D, 10.0 ** float(x0[0]))
                pip = np.full(D, np.pi / 10.0 ** float(x0[1])); p_start = 1
            sin_d = [oti.sin(oti.mul(pip[d], diffs[d])) for d in range(D)]
            cos_d = [oti.cos(oti.mul(pip[d], diffs[d])) for d in range(D)]
            if kernel_type == 'anisotropic':
                # Lengthscale gradient per dimension.
                if hasattr(phi, 'fused_scale_sq_mul'):
                    dphi_buf = oti.zeros(phi.shape)
                    for d in range(D):
                        dphi_buf.fused_scale_sq_mul(sin_d[d], phi, -4.0 * ln10 * ell[d] ** 2)
                        grad[d] = _gc(dphi_buf)
                else:
                    for d in range(D):
                        grad[d] = _gc(oti.mul(-4.0 * ln10 * ell[d] ** 2,
                                              oti.mul(oti.mul(sin_d[d], sin_d[d]), phi)))
                # Period gradient per dimension (sin*cos*diff term).
                for d in range(D):
                    sc = oti.mul(sin_d[d], oti.mul(cos_d[d], diffs[d]))
                    grad[p_start + d] = _gc(oti.mul(4.0 * ln10 * ell[d] ** 2 * pip[d],
                                                    oti.mul(sc, phi)))
            else:
                if hasattr(phi, 'fused_sum_sq'):
                    ss = oti.zeros(phi.shape)
                    ss.fused_sum_sq(sin_d)
                else:
                    ss = oti.mul(sin_d[0], sin_d[0])
                    for d in range(1, D):
                        ss = oti.sum(ss, oti.mul(sin_d[d], sin_d[d]))
                grad[0] = _gc(oti.mul(-4.0 * ln10 * ell[0] ** 2, oti.mul(ss, phi)))
                scd = oti.mul(sin_d[0], oti.mul(cos_d[0], diffs[0]))
                for d in range(1, D):
                    scd = oti.sum(scd, oti.mul(sin_d[d], oti.mul(cos_d[d], diffs[d])))
                grad[p_start] = _gc(oti.mul(4.0 * ln10 * ell[0] ** 2 * pip[0],
                                            oti.mul(scd, phi)))

        elif kernel == 'Matern':
            kf = self.model.kernel_factory
            # Build (and cache on the factory) the prebuilt Matern derivative
            # function f'(r) for the factory's nu (defaulting to 1.5).
            if not hasattr(kf, '_matern_grad_prebuild'):
                kf._matern_grad_prebuild = matern_kernel_grad_builder(getattr(kf, "nu", 1.5), oti_module=oti)
            ell = (10.0 ** x0[:D] if kernel_type == 'anisotropic'
                   else np.full(D, 10.0 ** float(x0[0])))
            sigma_f_sq = (10.0 ** float(x0[-2])) ** 2
            # Small offset keeps r > 0 so 1/r below is finite at zero distance.
            _eps = 1e-10
            if hasattr(phi, 'fused_sqdist'):
                r2 = oti.zeros(phi.shape)
                ell_sq = np.ascontiguousarray(ell ** 2, dtype=np.float64)
                r2.fused_sqdist(diffs, ell_sq)
            else:
                r2 = oti.mul(ell[0], diffs[0]); r2 = oti.mul(r2, r2)
                for d in range(1, D):
                    td = oti.mul(ell[d], diffs[d]); r2 = oti.sum(r2, oti.mul(td, td))
            r_oti = oti.sqrt(oti.sum(r2, _eps ** 2))
            f_prime_r = kf._matern_grad_prebuild(r_oti)
            inv_r = oti.pow(r_oti, -1)
            # base_matern = sigma_f^2 * f'(r) / r, shared by all dimensions.
            base_matern = oti.mul(sigma_f_sq, oti.mul(f_prime_r, inv_r))
            if kernel_type == 'anisotropic':
                if hasattr(phi, 'fused_scale_sq_mul'):
                    dphi_buf = oti.zeros(phi.shape)
                    for d in range(D):
                        dphi_buf.fused_scale_sq_mul(diffs[d], base_matern, ln10 * ell[d] ** 2)
                        grad[d] = _gc(dphi_buf)
                else:
                    for d in range(D):
                        d_sq = oti.mul(diffs[d], diffs[d])
                        dphi_d = oti.mul(ln10 * ell[d] ** 2, oti.mul(d_sq, base_matern))
                        grad[d] = _gc(dphi_d)
            else:
                if hasattr(phi, 'fused_sum_sq'):
                    sum_dsq = oti.zeros(phi.shape)
                    sum_dsq.fused_sum_sq(diffs)
                else:
                    sum_dsq = oti.mul(diffs[0], diffs[0])
                    for d in range(1, D):
                        sum_dsq = oti.sum(sum_dsq, oti.mul(diffs[d], diffs[d]))
                dphi_e = oti.mul(ln10 * ell[0] ** 2, oti.mul(sum_dsq, base_matern))
                grad[0] = _gc(dphi_e)

        return grad

560 

    def nll_and_grad(self, x0):
        """Compute NLL and its gradient in a single pass, sharing one Cholesky per submodel.

        Parameters
        ----------
        x0 : ndarray
            Log-scaled hyperparameter vector; ``x0[-1]`` is log10(sigma_n).

        Returns
        -------
        tuple of (float, ndarray)
            The NLL value and its gradient. On a failed Cholesky factorization
            the NLL is penalized by 1e6 and a zero gradient is returned.
        """
        ln10 = np.log(10.0)

        kernel = self.model.kernel
        kernel_type = self.model.kernel_type
        D = len(self.model.differences_by_dim)
        sigma_n_sq = (10.0 ** x0[-1]) ** 2
        diffs = self.model.differences_by_dim
        oti = self.model.kernel_factory.oti
        index = self.model.derivative_locations

        # --- shared kernel computation (done ONCE) ---
        phi = self.model.kernel_func(diffs, x0[:-1])
        if self.model.n_order == 0:
            n_bases = 0
            phi_exp = phi.real[np.newaxis, :, :]
        else:
            n_bases = phi.get_active_bases()[-1]
            deriv_order = 2 * self.model.n_order
            phi_exp = self._expand_derivs(phi, n_bases, deriv_order)

        # Ensure kernel plans are precomputed
        self._ensure_kernel_plans(n_bases)
        use_fast = self._kernel_plans is not None

        # Pre-reshape phi_exp to 3D once
        if use_fast:
            base_shape = phi.shape
            self._ensure_kernel_bufs(base_shape[0])
            phi_exp_3d = phi_exp.reshape(phi_exp.shape[0], base_shape[0], base_shape[1])

        # --- single loop: compute NLL and W_list simultaneously ---
        llhood = 0.0
        W_list = []
        for i in range(len(index)):
            y_train_sub = self.model.y_train_normalized[i]

            if use_fast:
                K = self.utils.rbf_kernel_fast(phi_exp_3d, self._kernel_plans[i], out=self._K_bufs[i])
            else:
                K = self.utils.rbf_kernel(
                    phi, phi_exp, self.model.n_order, n_bases,
                    self.model.flattened_der_indices[i],
                    self.model.powers[i], index=index[i]
                )
            # In-place noise-variance add on the diagonal.
            K.flat[::K.shape[0] + 1] += sigma_n_sq

            try:
                L, low = cho_factor(K)
                alpha_v = cho_solve((L, low), y_train_sub)
                N = len(y_train_sub)

                # NLL contribution: 0.5*y^T K^-1 y + 0.5*log|K| + const.
                data_fit = 0.5 * np.dot(y_train_sub.flatten(), alpha_v.flatten())
                log_det = np.sum(np.log(np.diag(L)))
                const = 0.5 * N * np.log(2 * np.pi)
                llhood += data_fit + log_det + const

                # W matrix for gradient (reuse same Cholesky):
                # W = K^-1 - alpha alpha^T.
                K_inv = cho_solve((L, low), np.eye(N))
                W_list.append(K_inv - np.outer(alpha_v, alpha_v))
            except np.linalg.LinAlgError:
                # Penalize and bail out early with a zero gradient; remaining
                # submodels are not evaluated.
                llhood += 1e6
                return float(llhood), np.zeros(len(x0))

        # --- gradient from W_list (no second kernel build / Cholesky) ---
        grad = np.zeros(len(x0))
        n_sub = len(index)

        # Precompute W projected into phi_exp space (sum over submodels), so
        # each per-hyperparameter contraction in _gc is a single vdot.
        W_proj = None
        if use_fast and self.model.n_order > 0:
            from math import comb
            ndir = comb(n_bases + deriv_order, deriv_order)
            proj_shape = (ndir, base_shape[0], base_shape[1])
            if self._W_proj_buf is None or self._W_proj_shape != proj_shape:
                self._W_proj_buf = np.empty(proj_shape)
                self._W_proj_shape = proj_shape
            W_proj = self._W_proj_buf
            W_proj[:] = 0.0
            tmp_proj = np.empty_like(W_proj)
            for i in range(n_sub):
                plan = self._kernel_plans[i]
                row_off = plan.get('row_offsets_abs', plan['row_offsets'] + base_shape[0])
                col_off = plan.get('col_offsets_abs', plan['col_offsets'] + base_shape[1])
                args = [
                    W_list[i], tmp_proj, base_shape[0], base_shape[1],
                    plan['fd_flat_indices'], plan['df_flat_indices'],
                    plan['dd_flat_indices'],
                    plan['idx_flat'], plan['idx_offsets'], plan['index_sizes'],
                ]
                # DEGP/DDEGP plans carry a 'signs' array; GDDEGP plans do not.
                if self._uses_signs:
                    args.append(plan['signs'])
                args.extend([plan['n_deriv_types'], row_off, col_off])
                self.utils._project_W_to_phi_space(*args)
                W_proj += tmp_proj

        def _gc(dphi):
            # Precompute dphi_exp ONCE, reshape to 3D; returns the gradient
            # contribution 0.5 * sum_i tr(W_i dK_i) for one hyperparameter.
            if self.model.n_order == 0:
                dphi_exp = dphi.real[np.newaxis, :, :]
            else:
                dphi_exp = self._expand_derivs(dphi, n_bases, deriv_order)
            if W_proj is not None:
                # Fastest path: contract directly in projected phi space.
                dphi_3d = dphi_exp.reshape(W_proj.shape)
                return 0.5 * np.vdot(W_proj, dphi_3d)
            elif use_fast:
                dphi_3d = dphi_exp.reshape(dphi_exp.shape[0], base_shape[0], base_shape[1])
                total = 0.0
                for i in range(n_sub):
                    dK = self.utils.rbf_kernel_fast(dphi_3d, self._kernel_plans[i], out=self._dK_bufs[i])
                    total += np.vdot(W_list[i], dK)
                return 0.5 * total
            else:
                total = 0.0
                for i in range(n_sub):
                    dK = self.utils.rbf_kernel(
                        dphi, dphi_exp,
                        self.model.n_order, n_bases,
                        self.model.flattened_der_indices[i],
                        self.model.powers[i],
                        index=index[i],
                    )
                    total += np.vdot(W_list[i], dK)
                return 0.5 * total

        # d/d(log10 sigma_f): K scales with the amplitude squared, so dK = 2*ln10*K.
        grad[-2] = _gc(oti.mul(2.0 * ln10, phi))
        # d/d(log10 sigma_n): dK = 2*ln10*sigma_n^2 * I, leaving ln10*sigma_n^2*tr(W).
        grad[-1] = ln10 * sigma_n_sq * sum(np.trace(W) for W in W_list)

        if kernel == 'SE':
            # Squared-exponential lengthscale gradients (mirrors nll_grad).
            if kernel_type == 'anisotropic':
                ell = 10.0 ** x0[:D]
                if hasattr(phi, 'fused_scale_sq_mul'):
                    dphi_buf = oti.zeros(phi.shape)
                    for d in range(D):
                        dphi_buf.fused_scale_sq_mul(diffs[d], phi, -ln10 * ell[d] ** 2)
                        grad[d] = _gc(dphi_buf)
                else:
                    for d in range(D):
                        grad[d] = _gc(oti.mul(-ln10 * ell[d] ** 2,
                                              oti.mul(oti.mul(diffs[d], diffs[d]), phi)))
            else:
                ell = 10.0 ** float(x0[0])
                if hasattr(phi, 'fused_sum_sq'):
                    sum_sq = oti.zeros(phi.shape)
                    sum_sq.fused_sum_sq(diffs)
                else:
                    sum_sq = oti.mul(diffs[0], diffs[0])
                    for d in range(1, D):
                        sum_sq = oti.sum(sum_sq, oti.mul(diffs[d], diffs[d]))
                grad[0] = _gc(oti.mul(-ln10 * ell ** 2, oti.mul(sum_sq, phi)))

        elif kernel == 'RQ':
            # Rational quadratic. NOTE(review): isotropic alpha uses exp(x0[1])
            # while anisotropic uses 10**x0[D]; alpha_factor matches each
            # parameterization's chain rule — confirm the asymmetry is intended.
            if kernel_type == 'anisotropic':
                ell = 10.0 ** x0[:D]; alpha_rq = 10.0 ** float(x0[D]); alpha_idx = D
            else:
                ell = np.full(D, 10.0 ** float(x0[0]))
                alpha_rq = np.exp(float(x0[1])); alpha_idx = 1
            if hasattr(phi, 'fused_sqdist'):
                r2 = oti.zeros(phi.shape)
                ell_sq = np.ascontiguousarray(ell ** 2, dtype=np.float64)
                r2.fused_sqdist(diffs, ell_sq)
            else:
                r2 = oti.mul(ell[0], diffs[0]); r2 = oti.mul(r2, r2)
                for d in range(1, D):
                    td = oti.mul(ell[d], diffs[d]); r2 = oti.sum(r2, oti.mul(td, td))
            # base = 1 + r^2 / (2*alpha); phi_over_base = phi / base.
            base = oti.sum(1.0, oti.mul(r2, 1.0 / (2.0 * alpha_rq)))
            inv_base = oti.pow(base, -1)
            phi_over_base = oti.mul(phi, inv_base)
            if kernel_type == 'anisotropic':
                if hasattr(phi, 'fused_scale_sq_mul'):
                    dphi_buf = oti.zeros(phi.shape)
                    for d in range(D):
                        dphi_buf.fused_scale_sq_mul(diffs[d], phi_over_base, -ln10 * ell[d] ** 2)
                        grad[d] = _gc(dphi_buf)
                else:
                    for d in range(D):
                        grad[d] = _gc(oti.mul(-ln10 * ell[d] ** 2,
                                              oti.mul(oti.mul(diffs[d], diffs[d]), phi_over_base)))
            else:
                if hasattr(phi, 'fused_sum_sq'):
                    sum_sq = oti.zeros(phi.shape)
                    sum_sq.fused_sum_sq(diffs)
                else:
                    sum_sq = oti.mul(diffs[0], diffs[0])
                    for d in range(1, D):
                        sum_sq = oti.sum(sum_sq, oti.mul(diffs[d], diffs[d]))
                grad[0] = _gc(oti.mul(-ln10 * ell[0] ** 2, oti.mul(sum_sq, phi_over_base)))
            # Gradient w.r.t. alpha: phi * (1 - 1/base - log(base)) times the
            # chain-rule factor for the chosen alpha parameterization.
            log_base = oti.log(base)
            term = oti.sub(oti.sub(1.0, inv_base), log_base)
            alpha_factor = ln10 * alpha_rq if kernel_type == 'anisotropic' else alpha_rq
            grad[alpha_idx] = _gc(oti.mul(alpha_factor, oti.mul(phi, term)))

        elif kernel == 'SineExp':
            # Periodic (exp-sine-squared style) kernel: pip = pi / period.
            if kernel_type == 'anisotropic':
                ell = 10.0 ** x0[:D]; p = 10.0 ** x0[D:2*D]
                pip = np.pi / p; p_start = D
            else:
                ell = np.full(D, 10.0 ** float(x0[0]))
                pip = np.full(D, np.pi / 10.0 ** float(x0[1])); p_start = 1
            sin_d = [oti.sin(oti.mul(pip[d], diffs[d])) for d in range(D)]
            cos_d = [oti.cos(oti.mul(pip[d], diffs[d])) for d in range(D)]
            if kernel_type == 'anisotropic':
                # Lengthscale gradient per dimension.
                if hasattr(phi, 'fused_scale_sq_mul'):
                    dphi_buf = oti.zeros(phi.shape)
                    for d in range(D):
                        dphi_buf.fused_scale_sq_mul(sin_d[d], phi, -4.0 * ln10 * ell[d] ** 2)
                        grad[d] = _gc(dphi_buf)
                else:
                    for d in range(D):
                        grad[d] = _gc(oti.mul(-4.0 * ln10 * ell[d] ** 2,
                                              oti.mul(oti.mul(sin_d[d], sin_d[d]), phi)))
                # Period gradient per dimension (sin*cos*diff term).
                for d in range(D):
                    sc = oti.mul(sin_d[d], oti.mul(cos_d[d], diffs[d]))
                    grad[p_start + d] = _gc(oti.mul(4.0 * ln10 * ell[d] ** 2 * pip[d],
                                                    oti.mul(sc, phi)))
            else:
                if hasattr(phi, 'fused_sum_sq'):
                    ss = oti.zeros(phi.shape)
                    ss.fused_sum_sq(sin_d)
                else:
                    ss = oti.mul(sin_d[0], sin_d[0])
                    for d in range(1, D):
                        ss = oti.sum(ss, oti.mul(sin_d[d], sin_d[d]))
                grad[0] = _gc(oti.mul(-4.0 * ln10 * ell[0] ** 2, oti.mul(ss, phi)))
                scd = oti.mul(sin_d[0], oti.mul(cos_d[0], diffs[0]))
                for d in range(1, D):
                    scd = oti.sum(scd, oti.mul(sin_d[d], oti.mul(cos_d[d], diffs[d])))
                grad[p_start] = _gc(oti.mul(4.0 * ln10 * ell[0] ** 2 * pip[0],
                                            oti.mul(scd, phi)))

        elif kernel == 'Matern':
            kf = self.model.kernel_factory
            # Build (and cache on the factory) the prebuilt Matern derivative
            # function f'(r) for the factory's nu (defaulting to 1.5).
            if not hasattr(kf, '_matern_grad_prebuild'):
                kf._matern_grad_prebuild = matern_kernel_grad_builder(getattr(kf, "nu", 1.5), oti_module=oti)
            ell = (10.0 ** x0[:D] if kernel_type == 'anisotropic'
                   else np.full(D, 10.0 ** float(x0[0])))
            sigma_f_sq = (10.0 ** float(x0[-2])) ** 2
            # Small offset keeps r > 0 so 1/r below is finite at zero distance.
            _eps = 1e-10
            if hasattr(phi, 'fused_sqdist'):
                r2 = oti.zeros(phi.shape)
                ell_sq = np.ascontiguousarray(ell ** 2, dtype=np.float64)
                r2.fused_sqdist(diffs, ell_sq)
            else:
                r2 = oti.mul(ell[0], diffs[0]); r2 = oti.mul(r2, r2)
                for d in range(1, D):
                    td = oti.mul(ell[d], diffs[d]); r2 = oti.sum(r2, oti.mul(td, td))
            r_oti = oti.sqrt(oti.sum(r2, _eps ** 2))
            f_prime_r = kf._matern_grad_prebuild(r_oti)
            inv_r = oti.pow(r_oti, -1)
            # base_matern = sigma_f^2 * f'(r) / r, shared by all dimensions.
            base_matern = oti.mul(sigma_f_sq, oti.mul(f_prime_r, inv_r))
            if kernel_type == 'anisotropic':
                if hasattr(phi, 'fused_scale_sq_mul'):
                    dphi_buf = oti.zeros(phi.shape)
                    for d in range(D):
                        dphi_buf.fused_scale_sq_mul(diffs[d], base_matern, ln10 * ell[d] ** 2)
                        grad[d] = _gc(dphi_buf)
                else:
                    for d in range(D):
                        d_sq = oti.mul(diffs[d], diffs[d])
                        dphi_d = oti.mul(ln10 * ell[d] ** 2, oti.mul(d_sq, base_matern))
                        grad[d] = _gc(dphi_d)
            else:
                if hasattr(phi, 'fused_sum_sq'):
                    sum_dsq = oti.zeros(phi.shape)
                    sum_dsq.fused_sum_sq(diffs)
                else:
                    sum_dsq = oti.mul(diffs[0], diffs[0])
                    for d in range(1, D):
                        sum_dsq = oti.sum(sum_dsq, oti.mul(diffs[d], diffs[d]))
                dphi_e = oti.mul(ln10 * ell[0] ** 2, oti.mul(sum_dsq, base_matern))
                grad[0] = _gc(dphi_e)

        return float(llhood), grad

836 

837 def optimize_hyperparameters( 

838 self, 

839 optimizer="pso", 

840 **kwargs 

841 ): 

842 """ 

843 Optimize the DEGP model hyperparameters using the specified optimizer. 

844 

845 Parameters: 

846 ---------- 

847 optimizer : str or callable, default="pso" 

848 Name of optimizer or callable. Available: 'pso', 'lbfgs', 'jade', etc. 

849 **kwargs : dict 

850 Additional arguments passed to the optimizer. 

851 

852 Returns: 

853 ------- 

854 best_x : ndarray 

855 The optimal set of hyperparameters found. 

856 """ 

857 

858 if isinstance(optimizer, str): 

859 if optimizer not in OPTIMIZERS: 

860 raise ValueError( 

861 f"Unknown optimizer '{optimizer}'. Available: {list(OPTIMIZERS.keys())}" 

862 ) 

863 optimizer_fn = OPTIMIZERS[optimizer] 

864 else: 

865 optimizer_fn = optimizer # allow passing a callable directly 

866 

867 bounds = self.model.bounds 

868 lb = [b[0] for b in bounds] 

869 ub = [b[1] for b in bounds] 

870 

871 if optimizer in ('lbfgs', 'jade', 'pso') and 'func_and_grad' not in kwargs and 'grad_func' not in kwargs: 

872 kwargs['func_and_grad'] = self.nll_and_grad 

873 

874 best_x, best_val = optimizer_fn(self.nll_wrapper, lb, ub, **kwargs) 

875 

876 self.model.opt_x0 = best_x 

877 self.model.opt_nll = best_val 

878 

879 return best_x