Coverage for jetgp/full_ddegp/wddegp_utils.py: 69%

401 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2026-04-03 15:09 -0500

1import numpy as np 

2import pyoti.sparse as oti 

3import pyoti.core as coti 

4from line_profiler import profile 

5import numba 

6 

7 

8# ============================================================================= 

9# Numba-accelerated helper functions for efficient matrix slicing 

10# ============================================================================= 

11 

@numba.jit(nopython=True, cache=True)
def extract_rows(content_full, row_indices, n_cols):
    """
    Gather selected rows of a matrix into a new, dense array.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix to gather from.
    row_indices : ndarray of int64
        Rows to copy, in the order they should appear in the output.
    n_cols : int
        Column count of ``content_full``.

    Returns
    -------
    out : ndarray of shape (len(row_indices), n_cols)
        Copy of the requested rows.
    """
    out = np.empty((len(row_indices), n_cols))
    for out_row, src_row in enumerate(row_indices):
        for col in range(n_cols):
            out[out_row, col] = content_full[src_row, col]
    return out

38 

39 

@numba.jit(nopython=True, cache=True)
def extract_cols(content_full, col_indices, n_rows):
    """
    Gather selected columns of a matrix into a new, dense array.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix to gather from.
    col_indices : ndarray of int64
        Columns to copy, in the order they should appear in the output.
    n_rows : int
        Row count of ``content_full``.

    Returns
    -------
    out : ndarray of shape (n_rows, len(col_indices))
        Copy of the requested columns.
    """
    out = np.empty((n_rows, len(col_indices)))
    for row in range(n_rows):
        for out_col, src_col in enumerate(col_indices):
            out[row, out_col] = content_full[row, src_col]
    return out

65 

66 

@numba.jit(nopython=True, cache=True)
def extract_submatrix(content_full, row_indices, col_indices):
    """
    Gather the submatrix at the given row/column indices.

    Explicit-loop replacement for the (more expensive) ``np.ix_`` fancy
    indexing on the same index sets.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix to gather from.
    row_indices : ndarray of int64
        Rows to copy, in output order.
    col_indices : ndarray of int64
        Columns to copy, in output order.

    Returns
    -------
    out : ndarray of shape (len(row_indices), len(col_indices))
        Copy of the requested submatrix.
    """
    out = np.empty((len(row_indices), len(col_indices)))
    for out_r, src_r in enumerate(row_indices):
        for out_c, src_c in enumerate(col_indices):
            out[out_r, out_c] = content_full[src_r, src_c]
    return out

95 

96 

@numba.jit(nopython=True, cache=True, parallel=False)
def extract_and_assign(content_full, row_indices, col_indices, K,
                       row_start, col_start, sign):
    """
    Copy a signed submatrix of ``content_full`` directly into ``K``.

    Fuses gather and placement in one pass so no temporary submatrix is
    allocated. ``K`` is modified in place; nothing is returned.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix to gather from.
    row_indices : ndarray of int64
        Rows of ``content_full`` to copy.
    col_indices : ndarray of int64
        Columns of ``content_full`` to copy.
    K : ndarray
        Destination matrix, written in place.
    row_start : int
        Row in ``K`` where the block starts.
    col_start : int
        Column in ``K`` where the block starts.
    sign : float
        Multiplier applied to every copied entry (+1.0 or -1.0).
    """
    for i, src_r in enumerate(row_indices):
        dest_r = row_start + i
        for j, src_c in enumerate(col_indices):
            K[dest_r, col_start + j] = sign * content_full[src_r, src_c]

127 

128 

@numba.jit(nopython=True, cache=True)
def extract_rows_and_assign(content_full, row_indices, K,
                            row_start, col_start, n_cols, sign):
    """
    Copy signed rows of ``content_full`` directly into a block of ``K``.

    ``K`` is modified in place; nothing is returned.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix to gather rows from.
    row_indices : ndarray of int64
        Rows of ``content_full`` to copy.
    K : ndarray
        Destination matrix, written in place.
    row_start : int
        Row in ``K`` where the block starts.
    col_start : int
        Column in ``K`` where the block starts.
    n_cols : int
        Number of columns to copy from each selected row.
    sign : float
        Multiplier applied to every copied entry (+1.0 or -1.0).
    """
    for i, src_r in enumerate(row_indices):
        dest_r = row_start + i
        for c in range(n_cols):
            K[dest_r, col_start + c] = sign * content_full[src_r, c]

157 

158 

@numba.jit(nopython=True, cache=True)
def extract_cols_and_assign(content_full, col_indices, K,
                            row_start, col_start, n_rows, sign):
    """
    Copy signed columns of ``content_full`` directly into a block of ``K``.

    ``K`` is modified in place; nothing is returned.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix to gather columns from.
    col_indices : ndarray of int64
        Columns of ``content_full`` to copy.
    K : ndarray
        Destination matrix, written in place.
    row_start : int
        Row in ``K`` where the block starts.
    col_start : int
        Column in ``K`` where the block starts.
    n_rows : int
        Number of rows to copy from each selected column.
    sign : float
        Multiplier applied to every copied entry (+1.0 or -1.0).
    """
    for r in range(n_rows):
        dest_r = row_start + r
        for j, src_c in enumerate(col_indices):
            K[dest_r, col_start + j] = sign * content_full[r, src_c]

186 

187 

188# ============================================================================= 

189# Difference computation functions 

190# ============================================================================= 

191 

def differences_by_dim_func(X1, X2, rays, n_order, oti_module, return_deriv=True):
    """
    Compute dimension-wise pairwise differences between X1 and X2,
    including hypercomplex perturbations in the directions specified by `rays`.

    This optimized version pre-calculates the perturbation and uses a single
    efficient loop for subtraction, avoiding broadcasting issues with OTI arrays.

    Parameters
    ----------
    X1 : ndarray of shape (n1, d)
        First set of input points with n1 samples in d dimensions.
    X2 : ndarray of shape (n2, d)
        Second set of input points with n2 samples in d dimensions.
    rays : ndarray of shape (d, n_rays)
        Directional vectors for derivative computation.
    n_order : int
        The base order used to construct hypercomplex units.
        When return_deriv=True, uses order 2*n_order.
        When return_deriv=False, uses order n_order.
        When n_order == 0, no perturbation is added at all.
    oti_module : module
        OTI backend providing ``array``, ``zeros`` and ``e`` constructors
        (e.g. ``pyoti.sparse``); all intermediate matrices are built with it.
    return_deriv : bool, optional (default=True)
        If True, use order 2*n_order for hypercomplex units (needed for
        derivative-derivative blocks in training kernel).
        If False, use order n_order (sufficient for prediction without
        derivative outputs).

    Returns
    -------
    differences_by_dim : list of length d
        A list where each element is an array of shape (n1, n2), containing
        the differences between corresponding dimensions of X1 and X2,
        augmented with directional hypercomplex perturbations.
    """
    # Promote inputs to OTI arrays so the subtractions below carry
    # hypercomplex parts.
    X1 = oti_module.array(X1)
    X2 = oti_module.array(X2)
    n1, d = X1.shape
    n2, _ = X2.shape
    n_rays = rays.shape[1]

    differences_by_dim = []

    # Case 1: n_order == 0 (no hypercomplex perturbation) — plain pairwise
    # differences, one (n1, n2) matrix per input dimension.
    if n_order == 0:
        for k in range(d):
            diffs_k = oti_module.zeros((n1, n2))
            for i in range(n1):
                diffs_k[i, :] = X1[i, k] - X2[:, k].T
            differences_by_dim.append(diffs_k)
        return differences_by_dim

    # Determine the order for hypercomplex units based on return_deriv
    if return_deriv:
        hc_order = 2 * n_order
    else:
        hc_order = n_order

    # Pre-calculate the perturbation vector using directional rays:
    # one hypercomplex unit per ray, combined via rays so perts[k] is the
    # perturbation attached to dimension k.
    e_bases = [oti_module.e(i + 1, order=hc_order) for i in range(n_rays)]
    perts = np.dot(rays, e_bases)

    # Case 2: return_deriv=False (prediction without derivative outputs)
    if not return_deriv:
        for k in range(d):
            # Add the pre-calculated perturbation for the current dimension to all points in X1
            X1_k_tagged = X1[:, k] + perts[k]
            X2_k = X2[:, k]

            # Pre-allocate the result matrix for this dimension
            diffs_k = oti_module.zeros((n1, n2))

            # Use an efficient single loop for subtraction.
            # NOTE(review): the [i, 0] / [:, 0] indexing assumes column slices
            # of an OTI array keep a trailing singleton axis — confirm against
            # the pyoti array API.
            for i in range(n1):
                diffs_k[i, :] = X1_k_tagged[i, 0] - X2_k[:, 0].T

            differences_by_dim.append(diffs_k)

    # Case 3: return_deriv=True (training kernel with derivative-derivative blocks)
    else:
        for k in range(d):
            X2_k = X2[:, k]

            # Pre-allocate the result matrix for this dimension
            diffs_k = oti_module.zeros((n1, n2))

            # Compute differences without perturbation first
            for i in range(n1):
                diffs_k[i, :] = X1[i, k] - X2_k[:, 0].T

            # Add perturbation to the entire matrix (more efficient)
            differences_by_dim.append(diffs_k + perts[k])

    return differences_by_dim

284 

285 

286# ============================================================================= 

287# Derivative mapping utilities 

288# ============================================================================= 

289 

def deriv_map(nbases, order):
    """
    Build a lookup table from (order, index_within_order) to a single
    flattened index covering every derivative component.

    Returns a list with one sub-list per order 0..order; entry
    ``map[ordi][idx]`` is the running flat position of that component.
    """
    map_deriv = []
    flat = 0
    for ordi in range(order + 1):
        n_dirs = coti.ndir_order(nbases, ordi)
        # Components of this order occupy the next n_dirs consecutive slots.
        map_deriv.append(list(range(flat, flat + n_dirs)))
        flat += n_dirs
    return map_deriv

305 

306 

def transform_der_indices(der_indices, der_map):
    """
    Translate user-facing derivative specifications into their internal
    (index, order) pairs and the corresponding flattened indices.

    Returns
    -------
    (flat_indices, order_pairs) : tuple of lists
        ``flat_indices[i]`` is ``der_map[order][idx]`` for specification i;
        ``order_pairs[i]`` is the raw ``coti.imdir`` result for it.
    """
    flat_indices = []
    order_pairs = []
    for spec in der_indices:
        pair = coti.imdir(spec)
        order_pairs.append(pair)
        idx, order = pair
        flat_indices.append(der_map[order][idx])
    return flat_indices, order_pairs

320 

321 

322# ============================================================================= 

323# RBF Kernel Assembly Functions (Optimized with Numba) 

324# ============================================================================= 

325 

def rbf_kernel(
    phi,
    phi_exp,
    n_order,
    n_bases,
    der_indices,
    powers,
    index=-1
):
    """
    Assembles the full DD-GP covariance matrix using an efficient, pre-computed
    derivative array and block-wise matrix filling.

    This version uses Numba-accelerated functions for efficient matrix slicing,
    replacing expensive np.ix_ operations.

    Parameters
    ----------
    phi : OTI array
        Base kernel matrix from kernel_func(differences, length_scales).
    phi_exp : ndarray
        Expanded derivative array from phi.get_all_derivs().
    n_order : int
        Maximum derivative order considered.
    n_bases : int
        Total number of bases (function value + derivative terms).
    der_indices : list of lists
        Multi-index derivative structures for each derivative component.
    powers : list of int
        Powers of (-1) applied to each term (for symmetry or sign conventions).
        Must have at least len(der_indices) + 1 entries: signs[0] is used for
        the function blocks, signs[j + 1] for derivative type j.
    index : list of lists
        Specifies which training point indices have each derivative type.
        NOTE(review): the default ``-1`` cannot actually be used — the body
        iterates over ``index`` — so callers must always pass a list of lists.

    Returns
    -------
    K : ndarray
        Full kernel matrix with function values and derivative blocks.
    """
    dh = coti.get_dHelp()

    # Create maps to translate derivative specifications to flat indices
    der_map = deriv_map(n_bases, 2 * n_order)
    der_indices_tr, der_ind_order = transform_der_indices(der_indices, der_map)

    # Determine Block Sizes and Pre-allocate Matrix
    n_rows_func, n_cols_func = phi.shape
    n_deriv_types = len(der_indices)
    n_pts_with_derivs_cols = sum(len(order_indices) for order_indices in index)
    n_pts_with_derivs_rows = sum(len(order_indices) for order_indices in index)
    total_rows = n_rows_func + n_pts_with_derivs_rows
    total_cols = n_cols_func + n_pts_with_derivs_cols

    K = np.zeros((total_rows, total_cols))
    base_shape = (n_rows_func, n_cols_func)

    # Pre-compute signs (avoid repeated exponentiation)
    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)

    # Convert index lists to numpy arrays for numba
    index_arrays = [np.asarray(idx, dtype=np.int64) for idx in index]

    # Block (0,0): Function-Function (K_ff) — zeroth derivative component.
    content_full = phi_exp[0].reshape(base_shape)
    K[:n_rows_func, :n_cols_func] = content_full * signs[0]

    # First Block-Column: Derivative-Function (K_df) — one row block per
    # derivative type, restricted to the training points that carry it.
    row_offset = n_rows_func
    for i in range(n_deriv_types):
        flat_idx = der_indices_tr[i]
        content_full = phi_exp[flat_idx].reshape(base_shape)

        current_indices = index_arrays[i]
        n_pts_this_order = len(current_indices)

        # Use numba for efficient row extraction and assignment
        extract_rows_and_assign(content_full, current_indices, K,
                                row_offset, 0, n_cols_func, signs[0])
        row_offset += n_pts_this_order

    # First Block-Row: Function-Derivative (K_fd) — mirror of K_df, but each
    # column block carries the per-type sign signs[j + 1].
    col_offset = n_cols_func
    for j in range(n_deriv_types):
        flat_idx = der_indices_tr[j]
        content_full = phi_exp[flat_idx].reshape(base_shape)

        current_indices = index_arrays[j]
        n_pts_this_order = len(current_indices)

        # Use numba for efficient column extraction and assignment
        extract_cols_and_assign(content_full, current_indices, K,
                                0, col_offset, n_rows_func, signs[j + 1])
        col_offset += n_pts_this_order

    # Inner Blocks: Derivative-Derivative (K_dd)
    row_offset = n_rows_func
    for i in range(n_deriv_types):
        col_offset = n_cols_func

        row_indices = index_arrays[i]
        n_pts_row = len(row_indices)

        for j in range(n_deriv_types):
            col_indices = index_arrays[j]
            n_pts_col = len(col_indices)

            # Multiply derivative indices (via the coti direction helper) to
            # find the flat component for the combined derivative.
            imdir1 = der_ind_order[j]
            imdir2 = der_ind_order[i]
            new_idx, new_ord = dh.mult_dir(imdir1[0], imdir1[1], imdir2[0], imdir2[1])
            flat_idx = der_map[new_ord][new_idx]
            content_full = phi_exp[flat_idx].reshape(base_shape)

            # Use numba for efficient submatrix extraction and assignment (replaces np.ix_)
            extract_and_assign(content_full, row_indices, col_indices, K,
                               row_offset, col_offset, signs[j + 1])

            col_offset += n_pts_col

        row_offset += n_pts_row

    return K

447 

448 

@numba.jit(nopython=True, cache=True)
def _assemble_kernel_numba(phi_exp_3d, K, n_rows_func, n_cols_func,
                           fd_flat_indices, df_flat_indices, dd_flat_indices,
                           idx_flat, idx_offsets, idx_sizes,
                           signs, n_deriv_types, row_offsets, col_offsets):
    """Fused numba kernel for entire K matrix assembly.

    Writes all four block families (K_ff, K_fd, K_df, K_dd) of the DD-GP
    covariance matrix into ``K`` in place, reading derivative components from
    ``phi_exp_3d[flat_component, row, col]``. Per-type point indices are
    packed into ``idx_flat`` with ``idx_offsets``/``idx_sizes`` delimiting
    each type's slice; ``row_offsets``/``col_offsets`` are absolute block
    starts within ``K``. ``signs[0]`` multiplies function-row blocks,
    ``signs[j + 1]`` the blocks of derivative type ``j``. No return value.
    """
    s0 = signs[0]
    # Block (0,0): function-function values, component 0.
    for r in range(n_rows_func):
        for c in range(n_cols_func):
            K[r, c] = phi_exp_3d[0, r, c] * s0
    # First block-row: function-derivative columns (K_fd), selected columns
    # of component fd_flat_indices[j], scaled by signs[j + 1].
    for j in range(n_deriv_types):
        fi = fd_flat_indices[j]
        sj = signs[j + 1]
        co = col_offsets[j]
        off_j = idx_offsets[j]
        sz_j = idx_sizes[j]
        for r in range(n_rows_func):
            for k in range(sz_j):
                ci = idx_flat[off_j + k]
                K[r, co + k] = phi_exp_3d[fi, r, ci] * sj
    # First block-column: derivative-function rows (K_df), selected rows of
    # component df_flat_indices[i], scaled by signs[0].
    for i in range(n_deriv_types):
        fi = df_flat_indices[i]
        ro = row_offsets[i]
        off_i = idx_offsets[i]
        sz_i = idx_sizes[i]
        for k in range(sz_i):
            ri = idx_flat[off_i + k]
            for c in range(n_cols_func):
                K[ro + k, c] = phi_exp_3d[fi, ri, c] * s0
    # Inner blocks: derivative-derivative (K_dd), submatrix of the combined
    # component dd_flat_indices[i, j], scaled by signs[j + 1].
    for i in range(n_deriv_types):
        ro = row_offsets[i]
        off_i = idx_offsets[i]
        sz_i = idx_sizes[i]
        for j in range(n_deriv_types):
            fi = dd_flat_indices[i, j]
            sj = signs[j + 1]
            co = col_offsets[j]
            off_j = idx_offsets[j]
            sz_j = idx_sizes[j]
            for ki in range(sz_i):
                ri = idx_flat[off_i + ki]
                for kj in range(sz_j):
                    ci = idx_flat[off_j + kj]
                    K[ro + ki, co + kj] = phi_exp_3d[fi, ri, ci] * sj

493 

494 

@numba.jit(nopython=True, cache=True)
def _project_W_to_phi_space(W, W_proj, n_rows_func, n_cols_func,
                            fd_flat_indices, df_flat_indices, dd_flat_indices,
                            idx_flat, idx_offsets, idx_sizes,
                            signs, n_deriv_types, row_offsets, col_offsets):
    """
    Reverse of _assemble_kernel_numba: project W from K-space back into
    phi_exp-space so that vdot(W, assemble(dphi_exp)) == vdot(W_proj, dphi_exp).

    ``W_proj`` has shape (n_components, n_rows_func, n_cols_func) and is
    zeroed then accumulated in place (``+=``, since several K blocks can map
    to the same phi_exp entry). Parameters mirror _assemble_kernel_numba;
    no return value.
    """
    # Zero the accumulator explicitly (numba-friendly triple loop).
    for d in range(W_proj.shape[0]):
        for r in range(W_proj.shape[1]):
            for c in range(W_proj.shape[2]):
                W_proj[d, r, c] = 0.0
    s0 = signs[0]
    # Adjoint of block (0,0): K_ff contributions go to component 0.
    for r in range(n_rows_func):
        for c in range(n_cols_func):
            W_proj[0, r, c] += s0 * W[r, c]
    # Adjoint of K_fd: scatter column blocks back to their components.
    for j in range(n_deriv_types):
        fi = fd_flat_indices[j]
        sj = signs[j + 1]
        co = col_offsets[j]
        off_j = idx_offsets[j]
        sz_j = idx_sizes[j]
        for r in range(n_rows_func):
            for k in range(sz_j):
                ci = idx_flat[off_j + k]
                W_proj[fi, r, ci] += sj * W[r, co + k]
    # Adjoint of K_df: scatter row blocks back to their components.
    for i in range(n_deriv_types):
        fi = df_flat_indices[i]
        ro = row_offsets[i]
        off_i = idx_offsets[i]
        sz_i = idx_sizes[i]
        for k in range(sz_i):
            ri = idx_flat[off_i + k]
            for c in range(n_cols_func):
                W_proj[fi, ri, c] += s0 * W[ro + k, c]
    # Adjoint of K_dd: scatter inner blocks to the combined components.
    for i in range(n_deriv_types):
        ro = row_offsets[i]
        off_i = idx_offsets[i]
        sz_i = idx_sizes[i]
        for j in range(n_deriv_types):
            fi = dd_flat_indices[i, j]
            sj = signs[j + 1]
            co = col_offsets[j]
            off_j = idx_offsets[j]
            sz_j = idx_sizes[j]
            for ki in range(sz_i):
                ri = idx_flat[off_i + ki]
                for kj in range(sz_j):
                    ci = idx_flat[off_j + kj]
                    W_proj[fi, ri, ci] += sj * W[ro + ki, co + kj]

547 

def precompute_kernel_plan(n_order, n_bases, der_indices, powers, index):
    """
    Precompute the structural metadata consumed by ``rbf_kernel_fast``.

    Everything that depends only on the derivative specification — flat
    component indices, signs, packed point-index arrays and block offsets —
    is computed once here so the fused numba assembly kernel receives plain
    arrays.
    """
    dh = coti.get_dHelp()
    der_map = deriv_map(n_bases, 2 * n_order)
    der_indices_tr, der_ind_order = transform_der_indices(der_indices, der_map)

    n_deriv_types = len(der_indices)
    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)
    index_arrays = [np.asarray(idx, dtype=np.int64) for idx in index]

    index_sizes = np.array([arr.size for arr in index_arrays], dtype=np.int64)
    n_pts_with_derivs = int(index_sizes.sum())

    # Pack all per-type index arrays into one flat buffer; idx_offsets marks
    # where each type's slice begins.
    if n_deriv_types > 0:
        idx_flat = np.concatenate(index_arrays)
    else:
        idx_flat = np.array([], dtype=np.int64)
    idx_offsets = np.zeros(n_deriv_types, dtype=np.int64)
    if n_deriv_types > 1:
        idx_offsets[1:] = np.cumsum(index_sizes[:-1])

    # Block starts within the derivative part of K; rows and columns share
    # the same layout but are kept as independent arrays.
    row_offsets = idx_offsets.copy()
    col_offsets = idx_offsets.copy()

    # Flat component index for every derivative-derivative block (i, j).
    dd_flat_indices = np.empty((n_deriv_types, n_deriv_types), dtype=np.int64)
    for i in range(n_deriv_types):
        imdir_row = der_ind_order[i]
        for j in range(n_deriv_types):
            imdir_col = der_ind_order[j]
            new_idx, new_ord = dh.mult_dir(imdir_col[0], imdir_col[1],
                                           imdir_row[0], imdir_row[1])
            dd_flat_indices[i, j] = der_map[new_ord][new_idx]

    fd_flat_indices = np.array(der_indices_tr, dtype=np.int64)
    df_flat_indices = np.array(der_indices_tr, dtype=np.int64)

    return {
        'der_indices_tr': der_indices_tr,
        'signs': signs,
        'index_arrays': index_arrays,
        'index_sizes': index_sizes,
        'n_pts_with_derivs': n_pts_with_derivs,
        'dd_flat_indices': dd_flat_indices,
        'n_deriv_types': n_deriv_types,
        'idx_flat': idx_flat,
        'idx_offsets': idx_offsets,
        'row_offsets': row_offsets,
        'col_offsets': col_offsets,
        'fd_flat_indices': fd_flat_indices,
        'df_flat_indices': df_flat_indices,
    }

600 

601 

def rbf_kernel_fast(phi_exp_3d, plan, out=None):
    """
    Assemble the full kernel matrix from ``phi_exp_3d`` using a precomputed
    plan and the fused numba kernel.

    Parameters
    ----------
    phi_exp_3d : ndarray of shape (n_components, n_rows, n_cols)
        Stacked derivative components of the base kernel.
    plan : dict
        Output of ``precompute_kernel_plan``.
    out : ndarray, optional
        Preallocated destination; a fresh square matrix is created otherwise.

    Returns
    -------
    K : ndarray
        Assembled kernel matrix.
    """
    n_rows_func = phi_exp_3d.shape[1]
    n_cols_func = phi_exp_3d.shape[2]
    total = n_rows_func + plan['n_pts_with_derivs']
    K = out if out is not None else np.empty((total, total))

    # Absolute block offsets may be cached on the plan; otherwise shift the
    # relative ones past the function-value block.
    if 'row_offsets_abs' in plan:
        row_off = plan['row_offsets_abs']
        col_off = plan['col_offsets_abs']
    else:
        row_off = plan['row_offsets'] + n_rows_func
        col_off = plan['col_offsets'] + n_cols_func

    _assemble_kernel_numba(
        phi_exp_3d, K, n_rows_func, n_cols_func,
        plan['fd_flat_indices'], plan['df_flat_indices'], plan['dd_flat_indices'],
        plan['idx_flat'], plan['idx_offsets'], plan['index_sizes'],
        plan['signs'], plan['n_deriv_types'], row_off, col_off,
    )
    return K

626 

627 

def rbf_kernel_predictions(
    phi,
    phi_exp,
    n_order,
    n_bases,
    der_indices,
    powers,
    return_deriv,
    index=-1,
    common_derivs=None,
    calc_cov=False,
    powers_predict=None
):
    """
    Constructs the RBF kernel matrix for predictions with directional derivative entries.

    This version uses Numba-accelerated functions for efficient matrix slicing.

    Parameters
    ----------
    phi : OTI array
        Base kernel matrix between test and training points.
    phi_exp : ndarray
        Expanded derivative array from phi.get_all_derivs().
    n_order : int
        Maximum derivative order.
    n_bases : int
        Number of input dimensions.
    der_indices : list
        Derivative specifications.
    powers : list of int
        Sign powers for each derivative type; signs[0] is applied to
        function-row blocks, signs[j + 1] to derivative type j.
    return_deriv : bool
        If True, predict derivatives at ALL test points.
    index : list of lists
        Training point indices for each derivative type. Anything that is
        not a non-empty list (including the default -1) disables the
        derivative-row blocks.
    common_derivs : list
        Common derivative indices to predict. Required (non-empty) for the
        prediction-derivative column blocks; when falsy, no such columns
        are produced.
    calc_cov : bool
        If True, computing covariance (use all indices for rows). With
        return_deriv=False this short-circuits to ``phi.real``.
    powers_predict : list of int, optional
        Sign powers for prediction derivatives; defaults to ``powers``.

    Returns
    -------
    K : ndarray
        Prediction kernel matrix.
    """
    # Plain covariance of function values only: the real part of phi is the
    # whole answer.
    if calc_cov and not return_deriv:
        return phi.real

    dh = coti.get_dHelp()

    n_rows_func, n_cols_func = phi.shape
    n_deriv_types = len(der_indices)
    n_deriv_types_pred = len(common_derivs) if common_derivs else 0

    # Pre-compute signs
    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)
    if powers_predict is not None:
        signs_predict = np.array([(-1.0) ** p for p in powers_predict], dtype=np.float64)
    else:
        signs_predict = signs

    if return_deriv:
        der_map = deriv_map(n_bases, 2 * n_order)
        # index_2 spans the last axis of phi_exp — used to place derivative
        # predictions at every test point.
        index_2 = np.arange(phi_exp.shape[-1], dtype=np.int64)
        if calc_cov:
            index_cov = np.arange(phi_exp.shape[-1], dtype=np.int64)
            n_deriv_types = n_deriv_types_pred
            # Equivalent to n_deriv_types * min(n_cols_func, len(index_2)).
            n_pts_with_derivs_rows = n_deriv_types * len([i for i in range(n_cols_func) if i < len(index_2)])
        else:
            n_pts_with_derivs_rows = sum(len(order_indices) for order_indices in index)
    else:
        # Without derivative outputs, order n_order suffices for the map.
        der_map = deriv_map(n_bases, n_order)
        index_2 = np.array([], dtype=np.int64)
        n_pts_with_derivs_rows = sum(len(order_indices) for order_indices in index)

    der_indices_tr, der_ind_order = transform_der_indices(der_indices, der_map)
    der_indices_tr_pred, der_ind_order_pred = transform_der_indices(common_derivs, der_map) if common_derivs else ([], [])
    # Equivalent to n_deriv_types_pred * min(n_cols_func, len(index_2)).
    n_pts_with_derivs_cols = n_deriv_types_pred * len([i for i in range(n_cols_func) if i < len(index_2)])

    total_rows = n_rows_func + n_pts_with_derivs_rows
    total_cols = n_cols_func + n_pts_with_derivs_cols

    K = np.zeros((total_rows, total_cols))
    base_shape = (n_rows_func, n_cols_func)

    # Convert index lists to numpy arrays for numba
    if index != -1 and isinstance(index, list) and len(index) > 0:
        index_arrays = [np.asarray(idx, dtype=np.int64) for idx in index]
    else:
        index_arrays = []

    # Block (0,0): Function-Function (K_ff)
    content_full = phi_exp[0].reshape(base_shape)
    K[:n_rows_func, :n_cols_func] = content_full * signs[0]

    if not return_deriv:
        # First Block-Column: Derivative-Function (K_df)
        row_offset = n_rows_func
        for i in range(n_deriv_types):
            if calc_cov:
                # NOTE(review): unreachable here — calc_cov with
                # return_deriv=False already returned phi.real above.
                row_indices = index_cov
            else:
                if not index_arrays:
                    break
                row_indices = index_arrays[i]
            n_pts_row = len(row_indices)

            flat_idx = der_indices_tr[i]
            content_full = phi_exp[flat_idx].reshape(base_shape)

            # Use numba for efficient row extraction
            extract_rows_and_assign(content_full, row_indices, K,
                                    row_offset, 0, n_cols_func, signs[0])
            row_offset += n_pts_row
        return K

    # --- return_deriv=True case ---

    # First Block-Row: Function-Derivative (K_fd) at all test points,
    # signed with the prediction sign convention.
    col_offset = n_cols_func
    for j in range(n_deriv_types_pred):
        col_indices = index_2
        n_pts_col = len(col_indices)

        flat_idx = der_indices_tr_pred[j]
        content_full = phi_exp[flat_idx].reshape(base_shape)

        # Use numba for efficient column extraction
        extract_cols_and_assign(content_full, col_indices, K,
                                0, col_offset, n_rows_func, signs_predict[j + 1])
        col_offset += n_pts_col

    # First Block-Column: Derivative-Function (K_df); under calc_cov the
    # prediction-side specs and all test indices are used instead.
    row_offset = n_rows_func
    for i in range(n_deriv_types):
        if calc_cov:
            row_indices = index_cov
            flat_idx = der_indices_tr_pred[i]
        else:
            if not index_arrays:
                break
            row_indices = index_arrays[i]
            flat_idx = der_indices_tr[i]
        n_pts_row = len(row_indices)

        content_full = phi_exp[flat_idx].reshape(base_shape)

        # Use numba for efficient row extraction
        extract_rows_and_assign(content_full, row_indices, K,
                                row_offset, 0, n_cols_func, signs[0])
        row_offset += n_pts_row

    # Inner Blocks: Derivative-Derivative (K_dd)
    row_offset = n_rows_func
    for i in range(n_deriv_types):
        if calc_cov:
            row_indices = index_cov
        else:
            if not index_arrays:
                break
            row_indices = index_arrays[i]
        n_pts_row = len(row_indices)

        col_offset = n_cols_func
        for j in range(n_deriv_types_pred):
            col_indices = index_2
            n_pts_col = len(col_indices)

            # Multiply the derivative indices to find the correct flat index
            imdir1 = der_ind_order_pred[j]
            imdir2 = der_ind_order_pred[i] if calc_cov else der_ind_order[i]
            new_idx, new_ord = dh.mult_dir(
                imdir1[0], imdir1[1], imdir2[0], imdir2[1])
            flat_idx = der_map[new_ord][new_idx]

            content_full = phi_exp[flat_idx].reshape(base_shape)

            # Use numba for efficient submatrix extraction and assignment (replaces np.ix_)
            extract_and_assign(content_full, row_indices, col_indices, K,
                               row_offset, col_offset, signs_predict[j + 1])
            col_offset += n_pts_col
        row_offset += n_pts_row

    return K

815 

816 

817# ============================================================================= 

818# Utility functions 

819# ============================================================================= 

820 

def determine_weights(diffs_by_dim, diffs_test, length_scales, kernel_func, sigma_n):
    """
    Compute interpolation weights for a batch of test points at once.

    Solves the equality-constrained (ordinary-kriging style) linear system
    ``[[K, 1], [1ᵀ, 0]] [w; μ] = [r; 1]`` for every test point simultaneously,
    where K is the training covariance and r the test–train covariances.

    Parameters
    ----------
    diffs_by_dim : list of ndarray
        Pairwise differences between training points (by dimension).
    diffs_test : list of ndarray
        Pairwise differences between test and training points (by dimension).
    length_scales : array-like
        Kernel hyperparameters.
    kernel_func : callable
        Kernel function; its result must expose a ``.real`` attribute.
    sigma_n : float
        Noise parameter. Accepted for API compatibility but currently
        unused by this routine.

    Returns
    -------
    weights_matrix : ndarray of shape (n_test, n_train)
        Interpolation weights for each test point (Lagrange multiplier
        discarded).
    """
    # Training covariance (shared by every test point) and test-train
    # covariances, keeping only the real parts.
    K = kernel_func(diffs_by_dim, length_scales).real
    r_all = kernel_func(diffs_test, length_scales).real
    n_train = K.shape[0]
    n_test = r_all.shape[0]

    # Saddle-point system matrix: K bordered by ones, zero corner.
    M = np.zeros((n_train + 1, n_train + 1))
    M[:n_train, :n_train] = K
    M[:n_train, n_train] = 1
    M[n_train, :n_train] = 1

    # Right-hand sides stacked column-wise; the constraint row is all ones.
    rhs = np.ones((n_train + 1, n_test))
    rhs[:n_train, :] = r_all.T

    # One factorization solves all test points simultaneously.
    solution = np.linalg.solve(M, rhs)

    # Drop the Lagrange-multiplier row and return one weight row per test point.
    return solution[:n_train, :].T

871 

872 

def to_list(x):
    """Recursively turn tuples (and nested tuples) into lists; everything
    else — including lists — is returned unchanged."""
    return [to_list(member) for member in x] if isinstance(x, tuple) else x

878 

879 

def to_tuple(item):
    """Recursively turn lists (and nested lists) into tuples; everything
    else is returned unchanged."""
    if not isinstance(item, list):
        return item
    return tuple(to_tuple(member) for member in item)

885 

886 

def find_common_derivatives(all_indices):
    """
    Find derivative indices common to all submodels.

    Parameters
    ----------
    all_indices : list of list
        One list of derivative-index specifications per submodel; nested
        lists are converted to tuples so they can be hashed.

    Returns
    -------
    set of tuple
        Specifications present in every submodel. Empty set when
        ``all_indices`` is empty (the original raised IndexError here).
    """
    if not all_indices:
        return set()
    sets = [set(to_tuple(elem) for elem in idx_list) for idx_list in all_indices]
    # set.intersection(*sets) == sets[0].intersection(*sets[1:]) but reads
    # uniformly for the single-submodel case as well.
    return set.intersection(*sets)