Coverage for jetgp/full_gddegp/gddegp.py: 88%

168 statements  

coverage.py v7.10.7, created at 2026-04-02 21:01 -0500

import numpy as np
import jetgp.utils as utils
from jetgp.kernel_funcs.kernel_funcs import KernelFactory, get_oti_module
from jetgp.full_gddegp.optimizer import Optimizer
from jetgp.full_gddegp import gddegp_utils
from scipy.linalg import cho_solve, cho_factor, solve_triangular
import warnings

class gddegp:
    """
    Global Directional Derivative-Enhanced Gaussian Process (GDDEGP) model.

    Supports point-wise directional derivatives with unique rays per point,
    hypercomplex representation, and automatic normalization. Includes methods
    for training, prediction, and uncertainty quantification using kernel methods.

    Parameters
    ----------
    x_train : ndarray
        Training input data of shape (n_samples, n_features).
    y_train : list or ndarray
        Training targets or list of directional derivatives.
    n_order : int
        Maximum derivative order.
    rays_list : list of ndarray
        List of ray arrays. rays_list[i] has shape (d, len(derivative_locations[i])).
    der_indices : list of lists
        Derivative multi-indices corresponding to each derivative term.
    derivative_locations : list of lists, optional
        Which training points have which derivatives. If omitted while n_order > 0,
        every derivative is assumed to be observed at every training point and a
        warning is issued.
    n_bases : int, optional
        Override the OTI space size. By default ``2 * n_direction_types`` (inferred
        from ``der_indices``). Pass explicitly when training on function values only
        (``der_indices=[]``) and you still want to predict directional derivatives:
        set ``n_bases = 2 * n_prediction_direction_types``.
    normalize : bool, default=True
        Whether to normalize inputs and outputs.
    sigma_data : float or array-like, optional
        Observation noise standard deviation or diagonal noise values.
    kernel : str, default='SE'
        Kernel type ('SE', 'RQ', 'Matern', etc.).
    kernel_type : str, default='anisotropic'
        Kernel anisotropy ('anisotropic' or 'isotropic').
    smoothness_parameter : float, optional
        Smoothness parameter for Matern kernel.
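
    Examples
    --------
    A minimal construction sketch, hedged: the synthetic data below and the
    function-values-only configuration (``n_order=0`` with empty ray and
    derivative specifications) are illustrative assumptions, not taken from this
    module. Derivative-enhanced training additionally supplies ``rays_list``,
    ``der_indices``, and ``derivative_locations`` in the formats described above.

    >>> import numpy as np
    >>> x_train = np.random.rand(20, 2)
    >>> y_train = np.sin(x_train[:, 0]) + np.cos(x_train[:, 1])
    >>> model = gddegp(x_train, y_train, n_order=0, rays_list=[], der_indices=None)
    >>> params = model.optimize_hyperparameters()
    >>> f_mean = model.predict(np.random.rand(5, 2), params)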

46 """ 

47 

48 def __init__(self, x_train, y_train, n_order, rays_list, der_indices, 

49 derivative_locations=None, n_bases=None, normalize=True, 

50 sigma_data=None, kernel="SE", kernel_type="anisotropic", 

51 smoothness_parameter=None): 

52 

53 if n_order > 0 and derivative_locations is None: 

54 import warnings 

55 # Count total number of derivative components across all orders 

56 n_derivs = sum(len(order_derivs) for order_derivs in der_indices) 

57 n_train = len(x_train) 

58 derivative_locations = [[i for i in range(n_train)] for _ in range(n_derivs)] 

59 warnings.warn( 

60 f"derivative_locations not provided. Assuming all {n_derivs} derivative(s) " 

61 f"are available at all {n_train} training point(s).", 

62 UserWarning 

63 ) 

64 

65 elif der_indices is None and n_order == 0: 

66 der_indices = [] 

67 derivative_locations = [] 

68 

69 self.x_train = x_train 

70 self.y_train = y_train 

71 self.sigma_data = sigma_data 

72 self.n_order = n_order 

73 self.max_order = n_order 

74 self.rays_list = rays_list 

75 self.dim = x_train.shape[1] 

76 self.num_points = x_train.shape[0] 

77 self.kernel = kernel 

78 self.kernel_type = kernel_type 

79 self.normalize = normalize 

80 self.derivative_locations = derivative_locations 

81 self.der_indices = der_indices 

82 

        # Flatten derivative indices first so we can size the OTI module correctly.
        # GDDEGP needs 2 OTI bases per direction type (one odd tag for X1, one even
        # tag for X2), so n_bases = 2 * n_direction_types by default.
        # An explicit n_bases can be passed to support function-only training
        # (der_indices=[]) while still reserving OTI space for derivative predictions.
        self.flattened_der_indices = utils.flatten_der_indices(der_indices)
        if n_bases is not None:
            self.n_bases = n_bases
        else:
            self.n_bases = 2 * len(self.flattened_der_indices)
        self.oti = get_oti_module(self.n_bases, n_order)
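        # Illustrative sizing (assumed example, not from this file): if der_indices
        # flattens to two direction types, n_bases defaults to 2 * 2 = 4 OTI bases.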

        if normalize:
            self.y_train, self.mu_y, self.sigma_y, self.sigmas_x, self.mus_x, sigma_data = \
                utils.normalize_y_data_directional(
                    x_train, y_train, sigma_data, self.flattened_der_indices)
            self.rays_list = utils.normalize_directions_2(self.sigmas_x, self.rays_list)
            self.x_train = utils.normalize_x_data_train(x_train)
        else:
            self.x_train = x_train
            self.y_train = utils.reshape_y_train(y_train)

        self.differences_by_dim = gddegp_utils.differences_by_dim_func(
            self.x_train, self.x_train,
            self.rays_list, self.rays_list,
            self.derivative_locations, self.derivative_locations,
            n_order, self.oti, return_deriv=True
        )

        self.sigma_data = (
            np.zeros((self.y_train.shape[0], self.y_train.shape[0]))
            if sigma_data is None else 10 * np.diag(sigma_data)
        )

        self.kernel_factory = KernelFactory(
            dim=self.dim,
            normalize=self.normalize,
            n_order=self.max_order,
            differences_by_dim=self.differences_by_dim,
            smoothness_parameter=smoothness_parameter,
            oti_module=self.oti
        )
        self.kernel_func = self.kernel_factory.create_kernel(
            kernel_name=self.kernel,
            kernel_type=self.kernel_type
        )
        self.bounds = self.kernel_factory.bounds
        self.optimizer = Optimizer(self)

    def optimize_hyperparameters(self, *args, **kwargs):
        """
        Run the optimizer to find the best kernel hyperparameters.
        Returns the optimized hyperparameter vector (also stored in ``self.params``).

136 """ 

137 self.params = self.optimizer.optimize_hyperparameters(*args, **kwargs) 

138 return self.params 

    def predict(self, X_test, params, rays_predict=None, calc_cov=False,
                return_deriv=False, derivs_to_predict=None):
        """
        Predict posterior mean and optional variance at test points.

        Parameters
        ----------
        X_test : ndarray
            Test input points of shape (n_test, n_features).
        params : ndarray
            Log-scaled kernel hyperparameters.
        rays_predict : list of ndarray, optional
            Rays at test points for derivative predictions. Each ray array has
            shape (n_features, n_test).
        calc_cov : bool, default=False
            Whether to compute predictive variance.
        return_deriv : bool, default=False
            Whether to return derivative predictions.
        derivs_to_predict : list, optional
            Specific derivatives to predict. These may include derivatives not present
            in the training set: the cross-covariance K_* is constructed from kernel
            derivatives and does not require the requested derivative to have been
            observed during training. Each entry must be a valid derivative spec within
            n_bases and n_order. If None, defaults to all derivatives used in training.

        Returns
        -------
        f_mean : ndarray
            Predictive mean, reshaped to (n_outputs, n_test) with one row per
            predicted output.
        f_var : ndarray, optional
            Predictive variance with the same shape (only if calc_cov=True).
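
        Examples
        --------
        A hedged sketch of derivative prediction; the test-point count, ray layout,
        and the use of ``model`` and ``params`` from a previously trained GDDEGP are
        illustrative assumptions:

        >>> X_test = np.random.rand(5, model.dim)
        >>> # one ray array per derivative, shape (dim, n_test), here along the first axis
        >>> rays = [np.vstack([np.ones((1, 5)), np.zeros((model.dim - 1, 5))])]
        >>> mean, var = model.predict(X_test, params, rays_predict=rays,
        ...                           calc_cov=True, return_deriv=True,
        ...                           derivs_to_predict=model.flattened_der_indices)
        >>> n_outputs, n_test = mean.shape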

170 """ 

171 

        n_predict = X_test.shape[0]

        # Handle missing rays_predict when derivatives are requested
        if return_deriv and rays_predict is None:
            n_rays = len(self.flattened_der_indices)

            warnings.warn(
                f"No rays_predict provided for derivative predictions. "
                f"Predictions will be made along coordinate axes: "
                f"[1,0,0,...], [0,1,0,...], etc. for {n_rays} directional derivative(s).",
                UserWarning
            )

            # Construct coordinate axis rays for each entry in flattened_der_indices
            # Each ray array has shape (dim, n_predict)
            rays_predict = []
            for i in range(n_rays):
                # Cycle through coordinate axes if more rays than dimensions
                axis_idx = i % self.dim
                ray_array = np.zeros((self.dim, n_predict))
                ray_array[axis_idx, :] = 1.0
                rays_predict.append(ray_array)

        # Warn if rays provided but not needed
        if not return_deriv and rays_predict is not None:
            warnings.warn(
                "rays_predict was provided but return_deriv=False. "
                "The provided rays will be ignored.",
                UserWarning
            )

        # Validate rays_predict structure when predicting derivatives
        if return_deriv and rays_predict is not None:
            # Check number of rays doesn't exceed training rays
            if len(self.rays_list) > 0 and len(rays_predict) > len(self.rays_list):
                raise ValueError(
                    f"Number of prediction rays ({len(rays_predict)}) exceeds the number of "
                    f"training rays ({len(self.rays_list)}). rays_predict must have at most "
                    f"{len(self.rays_list)} ray array(s)."
                )

            # Check shape of each ray array
            for i, ray_array in enumerate(rays_predict):
                if not isinstance(ray_array, np.ndarray):
                    raise TypeError(
                        f"Ray array {i} must be a numpy ndarray, got {type(ray_array).__name__}."
                    )

                if ray_array.ndim != 2:
                    raise ValueError(
                        f"Ray array {i} must be 2-dimensional, got {ray_array.ndim} dimensions."
                    )

                if ray_array.shape[0] != self.dim:
                    raise ValueError(
                        f"Ray array {i} has {ray_array.shape[0]} rows, expected {self.dim} "
                        f"(one per spatial dimension)."
                    )

                if ray_array.shape[1] != n_predict:
                    raise ValueError(
                        f"Ray array {i} has {ray_array.shape[1]} columns, expected {n_predict} "
                        f"(one per test point)."
                    )

        length_scales = params[:-1]
        sigma_n = params[-1]
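        # params is log-scaled (see docstring); the noise level is applied below as
        # (10 ** sigma_n) ** 2 on the diagonal of the training kernel matrix.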

        # Set up derivative prediction configuration
        if return_deriv:
            if derivs_to_predict is not None:
                common_derivs = derivs_to_predict
            else:
                if self.n_order == 0:
                    raise ValueError(
                        "derivs_to_predict must be provided when predicting derivatives "
                        "from a model trained with n_order=0 (no derivative training data)."
                    )
                common_derivs = self.flattened_der_indices
                print(
                    f"Note: derivs_to_predict is None. Predictions will include all derivatives "
                    f"used in training: {self.flattened_der_indices}"
                )

            # Determine prediction order from requested derivatives
            required_order = max(
                sum(pair[1] for pair in deriv_spec)
                for deriv_spec in common_derivs
            )
            predict_order = max(required_order, self.n_order)

            if predict_order > self.n_order:
                predict_oti = get_oti_module(self.n_bases, predict_order)
                smoothness_param = getattr(self.kernel_factory, 'alpha', None)
                predict_kernel_factory = KernelFactory(
                    dim=self.dim,
                    normalize=self.normalize,
                    differences_by_dim=self.differences_by_dim,
                    n_order=predict_order,
                    smoothness_parameter=smoothness_param,
                    oti_module=predict_oti
                )
                predict_kernel_func = predict_kernel_factory.create_kernel(
                    kernel_name=self.kernel, kernel_type=self.kernel_type
                )
            else:
                predict_oti = self.oti
                predict_kernel_func = self.kernel_func
        else:
            common_derivs = []
            predict_order = self.n_order
            predict_oti = self.oti
            predict_kernel_func = self.kernel_func

        # Build training kernel matrix
        phi_train = self.kernel_func(self.differences_by_dim, length_scales)
        if self.n_order == 0:
            self.n_bases = 0
            phi_exp_train = phi_train.real
            phi_exp_train = phi_exp_train[np.newaxis, :, :]
        else:
            # Take the max so that an explicitly-set n_bases (e.g. for function-only
            # training or for predicting more directions than were observed) is never
            # silently reduced to the number of bases active in the training kernel.
            active = phi_train.get_active_bases()
            self.n_bases = max(self.n_bases, active[-1] if active else 0)
            phi_exp_train = phi_train.get_all_derivs(self.n_bases, 2 * self.n_order)

        # Placeholder for powers (GDDEGP doesn't use sign powers like DEGP/DDEGP)
        powers = [0] * (len(self.flattened_der_indices) + 1)

        K = gddegp_utils.rbf_kernel(
            phi_train, phi_exp_train, self.n_order, self.n_bases,
            self.flattened_der_indices,
            index=self.derivative_locations
        )
        K += (10 ** sigma_n) ** 2 * np.eye(K.shape[0])
        K += self.sigma_data ** 2
        self.K_train = K

        # Solve linear system
        try:
            cho_solve_failed = False
            L, low = cho_factor(K, lower=True)
            alpha = cho_solve((L, low), self.y_train)
        except Exception:
            cho_solve_failed = True
            alpha = np.linalg.solve(K, self.y_train)
            print('Warning: Cholesky decomposition failed via scipy, using standard np solve instead.')

        # Normalize test inputs and rays
        rays_test = rays_predict

        if self.normalize:
            X_test = utils.normalize_x_data_test(X_test, self.sigmas_x, self.mus_x)

        if not return_deriv:
            rays_test = None
            derivative_locations_test = None
        else:
            derivative_locations_test = [
                list(range(X_test.shape[0])) for _ in range(len(common_derivs))]
            if self.normalize:
                rays_test = utils.normalize_directions_2(self.sigmas_x, rays_test)

        # Compute train-test differences
        diff_x_train_x_test = gddegp_utils.differences_by_dim_func(
            self.x_train, X_test,
            self.rays_list, rays_test,
            self.derivative_locations, derivative_locations_test,
            predict_order, predict_oti, return_deriv=return_deriv
        )

        # Compute train-test kernel
        phi_train_test = predict_kernel_func(diff_x_train_x_test, length_scales)
        if predict_order > 0:
            if return_deriv:
                phi_exp_train_test = phi_train_test.get_all_derivs(self.n_bases, 2 * predict_order)
            else:
                phi_exp_train_test = phi_train_test.get_all_derivs(self.n_bases, predict_order)
        else:
            phi_exp_train_test = phi_train_test.real
            phi_exp_train_test = phi_exp_train_test[np.newaxis, :, :]
        K_s = gddegp_utils.rbf_kernel_predictions(
            phi_train_test, phi_exp_train_test, predict_order, self.n_bases,
            self.flattened_der_indices,
            return_deriv=return_deriv,
            index=self.derivative_locations,
            common_derivs=common_derivs
        )

        # Compute posterior mean
        f_mean = K_s @ alpha

        # Denormalize predictions
        if self.normalize:
            if return_deriv:
                f_mean = utils.transform_predictions_directional(
                    f_mean, self.mu_y, self.sigma_y, self.sigmas_x,
                    common_derivs, X_test)
            else:
                f_mean = self.mu_y + f_mean * self.sigma_y

        # Reshape predictions
        f_mean = f_mean.reshape(-1, 1)
        n = X_test.shape[0]
        m = f_mean.shape[0]
        num_derivs = m // n
        reshaped_mean = f_mean.reshape(num_derivs, n)

        if not calc_cov:
            return reshaped_mean

        # Compute test-test differences
        diff_x_test_x_test = gddegp_utils.differences_by_dim_func(
            X_test, X_test,
            rays_test, rays_test,
            derivative_locations_test, derivative_locations_test,
            predict_order, predict_oti, return_deriv=return_deriv
        )

        # Compute test-test kernel
        phi_test_test = predict_kernel_func(diff_x_test_x_test, length_scales)
        if predict_order > 0:
            phi_exp_test_test = phi_test_test.get_all_derivs(self.n_bases, 2 * predict_order)
        else:
            phi_exp_test_test = phi_test_test.real
            phi_exp_test_test = phi_exp_test_test[np.newaxis, :, :]

        K_ss = gddegp_utils.rbf_kernel_predictions(
            phi_test_test, phi_exp_test_test, predict_order, self.n_bases,
            self.flattened_der_indices,
            return_deriv=return_deriv,
            index=derivative_locations_test,
            common_derivs=common_derivs,
            calc_cov=True,
        )

        # Compute predictive covariance
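        # Standard GP posterior covariance: K_ss - K_s K^{-1} K_s^T. On the Cholesky
        # path below, v = L^{-1} K_s^T, so v.T @ v equals K_s K^{-1} K_s^T.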

        if cho_solve_failed:
            v_fallback = np.linalg.solve(K, K_s.T)
            f_cov = K_ss - K_s @ v_fallback
        else:
            v = solve_triangular(L, K_s.T, lower=low)
            f_cov = K_ss - v.T @ v

        # Transform covariance
        if self.normalize:
            if return_deriv:
                f_var = utils.transform_cov_directional(
                    f_cov, self.sigma_y, self.sigmas_x,
                    common_derivs, X_test)
            else:
                f_var = self.sigma_y ** 2 * np.diag(np.abs(f_cov))
        else:
            f_var = np.diag(np.abs(f_cov))

        reshaped_var = f_var.reshape(num_derivs, n)
        return reshaped_mean, reshaped_var