Coverage for jetgp/full_ddegp/ddegp_utils.py: 68%

376 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2026-04-03 15:09 -0500

1import numpy as np 

2import numba 

3import pyoti.core as coti 

4from line_profiler import profile 

5 

6 

7# ============================================================================= 

8# Numba-accelerated helper functions for efficient matrix slicing 

9# ============================================================================= 

10 

@numba.jit(nopython=True, cache=True)
def extract_rows(content_full, row_indices, n_cols):
    """
    Gather selected rows of a matrix into a newly allocated array.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix to read from.
    row_indices : ndarray of int64
        Indices of the rows to copy, in output order.
    n_cols : int
        Number of leading columns copied from each selected row.

    Returns
    -------
    result : ndarray of shape (len(row_indices), n_cols)
        Array satisfying ``result[i, j] == content_full[row_indices[i], j]``.
    """
    gathered = np.empty((len(row_indices), n_cols))
    for dst_row, src_row in enumerate(row_indices):
        for col in range(n_cols):
            gathered[dst_row, col] = content_full[src_row, col]
    return gathered

37 

38 

@numba.jit(nopython=True, cache=True)
def extract_cols(content_full, col_indices, n_rows):
    """
    Gather selected columns of a matrix into a newly allocated array.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix to read from.
    col_indices : ndarray of int64
        Indices of the columns to copy, in output order.
    n_rows : int
        Number of leading rows copied from each selected column.

    Returns
    -------
    result : ndarray of shape (n_rows, len(col_indices))
        Array satisfying ``result[i, j] == content_full[i, col_indices[j]]``.
    """
    gathered = np.empty((n_rows, len(col_indices)))
    for row in range(n_rows):
        for dst_col, src_col in enumerate(col_indices):
            gathered[row, dst_col] = content_full[row, src_col]
    return gathered

64 

65 

@numba.jit(nopython=True, cache=True)
def extract_submatrix(content_full, row_indices, col_indices):
    """
    Gather the submatrix addressed by a set of rows and a set of columns.

    Produces the same values as indexing with ``np.ix_(row_indices,
    col_indices)``, using explicit loops compiled in nopython mode.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix to read from.
    row_indices : ndarray of int64
        Row indices to select, in output order.
    col_indices : ndarray of int64
        Column indices to select, in output order.

    Returns
    -------
    result : ndarray of shape (len(row_indices), len(col_indices))
        Array satisfying
        ``result[i, j] == content_full[row_indices[i], col_indices[j]]``.
    """
    gathered = np.empty((len(row_indices), len(col_indices)))
    for dst_row, src_row in enumerate(row_indices):
        for dst_col, src_col in enumerate(col_indices):
            gathered[dst_row, dst_col] = content_full[src_row, src_col]
    return gathered

94 

95 

@numba.jit(nopython=True, cache=True, parallel=False)
def extract_and_assign(content_full, row_indices, col_indices, K,
                       row_start, col_start, sign):
    """
    Copy a signed submatrix of ``content_full`` into a block of ``K``.

    Gathering and writing happen in one pass, so no temporary submatrix is
    allocated. ``K`` is modified in place; nothing is returned.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix to read from.
    row_indices : ndarray of int64
        Row indices to select from ``content_full``.
    col_indices : ndarray of int64
        Column indices to select from ``content_full``.
    K : ndarray
        Destination matrix, written in place.
    row_start : int
        Row of ``K`` receiving the first selected row.
    col_start : int
        Column of ``K`` receiving the first selected column.
    sign : float
        Factor multiplied into every copied value (+1.0 or -1.0).
    """
    for row_pos, src_row in enumerate(row_indices):
        dst_row = row_start + row_pos
        for col_pos, src_col in enumerate(col_indices):
            K[dst_row, col_start + col_pos] = content_full[src_row, src_col] * sign

126 

127 

@numba.jit(nopython=True, cache=True)
def extract_rows_and_assign(content_full, row_indices, K,
                            row_start, col_start, n_cols, sign):
    """
    Copy signed, selected rows of ``content_full`` into a block of ``K``.

    ``K`` is modified in place; nothing is returned.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix to read from.
    row_indices : ndarray of int64
        Row indices to select from ``content_full``.
    K : ndarray
        Destination matrix, written in place.
    row_start : int
        Row of ``K`` receiving the first selected row.
    col_start : int
        Column of ``K`` receiving column 0 of ``content_full``.
    n_cols : int
        Number of leading columns copied from each selected row.
    sign : float
        Factor multiplied into every copied value (+1.0 or -1.0).
    """
    for row_pos, src_row in enumerate(row_indices):
        dst_row = row_start + row_pos
        for col in range(n_cols):
            K[dst_row, col_start + col] = content_full[src_row, col] * sign

156 

157 

@numba.jit(nopython=True, cache=True)
def extract_cols_and_assign(content_full, col_indices, K,
                            row_start, col_start, n_rows, sign):
    """
    Copy signed, selected columns of ``content_full`` into a block of ``K``.

    ``K`` is modified in place; nothing is returned.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix to read from.
    col_indices : ndarray of int64
        Column indices to select from ``content_full``.
    K : ndarray
        Destination matrix, written in place.
    row_start : int
        Row of ``K`` receiving row 0 of ``content_full``.
    col_start : int
        Column of ``K`` receiving the first selected column.
    n_rows : int
        Number of leading rows copied from each selected column.
    sign : float
        Factor multiplied into every copied value (+1.0 or -1.0).
    """
    for row in range(n_rows):
        dst_row = row_start + row
        for col_pos, src_col in enumerate(col_indices):
            K[dst_row, col_start + col_pos] = content_full[row, src_col] * sign

185 

186 

187# ============================================================================= 

188# Difference computation functions 

189# ============================================================================= 

190 

def differences_by_dim_func(X1, X2, rays, n_order, oti_module, return_deriv=True, index=-1):
    """
    Build, for every input dimension, the matrix of pairwise differences
    between the points of X1 and X2, optionally tagged with hypercomplex
    perturbations along the directions in `rays`.

    Parameters
    ----------
    X1 : ndarray of shape (n1, d)
        First set of points; converted with ``oti_module.array``.
    X2 : ndarray of shape (n2, d)
        Second set of points; converted with ``oti_module.array``.
    rays : ndarray of shape (d, n_rays)
        Direction vectors. ``rays.shape[1]`` is read even when
        ``n_order == 0``.
    n_order : int
        Base derivative order. With ``n_order == 0`` no perturbation is
        added at all.
    oti_module : module
        Object providing ``array``, ``zeros``, ``transpose`` and ``e``
        (presumably a PyOTI static module -- confirm against the caller).
    return_deriv : bool, optional (default=True)
        Selects the order passed to ``oti_module.e``: ``2 * n_order`` when
        True, ``n_order`` when False. It also selects whether the
        perturbation is added to the finished difference matrix (True) or to
        the X1 column before differencing (False).
    index : int, optional
        Not used by this function.

    Returns
    -------
    differences_by_dim : list of length d
        One (n1, n2) array (created by ``oti_module.zeros``) per dimension,
        holding ``X1[i, k] - X2[j, k]`` plus, when ``n_order != 0``, the
        k-th entry of ``np.dot(rays, e_bases)``.
    """
    X1 = oti_module.array(X1)
    X2 = oti_module.array(X2)
    n1, d = X1.shape
    n2, _ = X2.shape
    n_rays = rays.shape[1]

    per_dim = []

    # Order zero: plain differences, no hypercomplex units involved.
    if n_order == 0:
        for k in range(d):
            block = oti_module.zeros((n1, n2))
            for i in range(n1):
                block[i, :] = X1[i, k] - oti_module.transpose(X2[:, k])
            per_dim.append(block)
        return per_dim

    # One imaginary unit per ray, combined into one perturbation per
    # input dimension.
    hc_order = 2 * n_order if return_deriv else n_order
    e_bases = [oti_module.e(ray + 1, order=hc_order) for ray in range(n_rays)]
    perts = np.dot(rays, e_bases)

    if return_deriv:
        # Difference first, then add the perturbation to the whole matrix.
        for k in range(d):
            X2_k = X2[:, k]
            block = oti_module.zeros((n1, n2))
            for i in range(n1):
                block[i, :] = X1[i, k] - X2_k[:, 0].T
            per_dim.append(block + perts[k])
        return per_dim

    # Prediction without derivative outputs: tag the X1 column first, then
    # take the differences row by row.
    for k in range(d):
        X1_k_tagged = X1[:, k] + perts[k]
        X2_k = X2[:, k]
        block = oti_module.zeros((n1, n2))
        for i in range(n1):
            block[i, :] = X1_k_tagged[i, 0] - X2_k[:, 0].T
        per_dim.append(block)
    return per_dim

308 

309# ============================================================================= 

310# Derivative mapping utilities 

311# ============================================================================= 

312 

def deriv_map(nbases, order):
    """
    Number every (order, index_within_order) derivative component with one
    running flat index.

    Components are numbered consecutively from 0, order by order, with
    ``coti.ndir_order(nbases, ordi)`` components at order ``ordi``.

    Parameters
    ----------
    nbases : int
        Number of bases passed to ``coti.ndir_order``.
    order : int
        Highest order included (orders ``0..order``).

    Returns
    -------
    map_deriv : list of lists
        ``map_deriv[ordi][idx]`` is the flat index of component ``idx`` of
        order ``ordi``.
    """
    map_deriv = []
    next_flat = 0
    for ordi in range(order + 1):
        ndir = coti.ndir_order(nbases, ordi)
        block = [next_flat + offset for offset in range(ndir)]
        map_deriv.append(block)
        next_flat += len(block)
    return map_deriv

340 

341 

def transform_der_indices(der_indices, der_map):
    """
    Translate derivative specifications into flat indices and the pairs
    returned by ``coti.imdir``.

    Parameters
    ----------
    der_indices : list
        Derivative specifications; each is passed to ``coti.imdir``.
    der_map : list of lists
        Lookup table as produced by ``deriv_map``.

    Returns
    -------
    deriv_ind_transf : list
        ``der_map[order][idx]`` for every specification.
    deriv_ind_order : list
        The ``coti.imdir`` result for every specification; it is unpacked
        here as ``(idx, order)``.
    """
    flat_indices = []
    imdir_pairs = []
    for spec in der_indices:
        pair = coti.imdir(spec)
        dir_idx, dir_order = pair
        flat_indices.append(der_map[dir_order][dir_idx])
        imdir_pairs.append(pair)
    return flat_indices, imdir_pairs

369 

370 

371# ============================================================================= 

372# RBF Kernel Assembly Functions (Optimized with Numba) 

373# ============================================================================= 

374 

@profile
def rbf_kernel(
    phi,
    phi_exp,
    n_order,
    n_bases,
    der_indices,
    powers,
    index=-1
):
    """
    Assemble the covariance matrix with function-value and derivative blocks
    from a pre-expanded derivative array.

    The matrix has the block layout::

        [ K_ff  K_fd ]
        [ K_df  K_dd ]

    Each derivative type ``i`` contributes ``len(index[i])`` rows and columns,
    which allows different derivative types at different, non-contiguous
    subsets of points.

    Parameters
    ----------
    phi : OTI array
        Base kernel matrix; only ``phi.shape`` is read here.
    phi_exp : ndarray
        Expanded derivative array. ``phi_exp[k]`` is reshaped to
        ``phi.shape`` and holds flat derivative component ``k``.
    n_order : int
        Maximum derivative order; the flat-index map is built for orders up
        to ``2 * n_order``.
    n_bases : int
        Number of bases used for the flat-index map.
    der_indices : list of lists
        Derivative specification for each derivative type.
    powers : list of int
        Exponents of (-1). ``(-1) ** powers[0]`` scales the K_ff and K_df
        blocks; ``(-1) ** powers[j + 1]`` scales column block ``j`` of K_fd
        and K_dd.
    index : list of lists or int, optional (default=-1)
        ``index[i]`` lists the point indices that carry derivative type
        ``i``. If ``index`` is not a list, or is an empty list, no derivative
        blocks are built and only the K_ff block is returned.

    Returns
    -------
    K : ndarray
        Kernel matrix of shape ``(n_rows + n_d, n_cols + n_d)`` where
        ``n_d`` is the total number of entries over all ``index[i]``.

    Raises
    ------
    ValueError
        If ``index`` is a non-empty list whose length differs from
        ``len(der_indices)``.
    """
    # --- 1. Derivative bookkeeping ---
    dh = coti.get_dHelp()
    der_map = deriv_map(n_bases, 2 * n_order)
    der_indices_tr, der_ind_order = transform_der_indices(der_indices, der_map)

    n_rows_func, n_cols_func = phi.shape
    n_deriv_types = len(der_indices)
    base_shape = (n_rows_func, n_cols_func)

    # Pre-compute signs once instead of exponentiating per block.
    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)

    # The numba helpers need int64 arrays rather than Python lists.
    if isinstance(index, list) and len(index) > 0:
        index_arrays = [np.asarray(idx, dtype=np.int64) for idx in index]
    else:
        index_arrays = []

    # A surplus index list would leave all-zero rows/columns in K and a
    # missing one would fail later with a bare IndexError; reject both here.
    if index_arrays and len(index_arrays) != n_deriv_types:
        raise ValueError(
            "index must provide one list of point indices per derivative "
            f"type: got {len(index_arrays)} lists for {n_deriv_types} "
            "derivative types"
        )

    # --- 2. Allocate K and fill the function-function block ---
    n_deriv_pts = sum(len(idx) for idx in index_arrays)
    K = np.zeros((n_rows_func + n_deriv_pts, n_cols_func + n_deriv_pts))
    K[:n_rows_func, :n_cols_func] = phi_exp[0].reshape(base_shape) * signs[0]

    if not index_arrays:
        # No derivative indices provided: K is just K_ff.
        return K

    # --- 3. Border blocks: K_df (first block-column), K_fd (first block-row) ---
    # Both read the same derivative slab, so it is reshaped once per type.
    block_offsets = []
    offset = 0
    for i in range(n_deriv_types):
        pts = index_arrays[i]
        content_full = phi_exp[der_indices_tr[i]].reshape(base_shape)

        extract_rows_and_assign(content_full, pts, K,
                                n_rows_func + offset, 0, n_cols_func, signs[0])
        extract_cols_and_assign(content_full, pts, K,
                                0, n_cols_func + offset, n_rows_func, signs[i + 1])

        block_offsets.append(offset)
        offset += len(pts)

    # --- 4. Inner blocks: K_dd ---
    for i in range(n_deriv_types):
        row_indices = index_arrays[i]
        imdir_row = der_ind_order[i]
        for j in range(n_deriv_types):
            imdir_col = der_ind_order[j]
            # The product of the two directions identifies the slab holding
            # the mixed derivative.
            new_idx, new_ord = dh.mult_dir(imdir_col[0], imdir_col[1],
                                           imdir_row[0], imdir_row[1])
            content_full = phi_exp[der_map[new_ord][new_idx]].reshape(base_shape)

            extract_and_assign(content_full, row_indices, index_arrays[j], K,
                               n_rows_func + block_offsets[i],
                               n_cols_func + block_offsets[j],
                               signs[j + 1])

    return K

514 

515 

@numba.jit(nopython=True, cache=True)
def _assemble_kernel_numba(phi_exp_3d, K, n_rows_func, n_cols_func,
                           fd_flat_indices, df_flat_indices, dd_flat_indices,
                           idx_flat, idx_offsets, idx_sizes,
                           signs, n_deriv_types, row_offsets, col_offsets):
    """
    Fill the whole kernel matrix ``K`` in place from ``phi_exp_3d``.

    All four block families (ff, fd, df, dd) are written inside this one
    compiled function. The point indices of derivative type ``t`` are
    ``idx_flat[idx_offsets[t] : idx_offsets[t] + idx_sizes[t]]``; its block
    starts at row ``row_offsets[t]`` / column ``col_offsets[t]`` of ``K``.
    ``signs[0]`` scales the ff and df blocks, ``signs[t + 1]`` scales column
    block ``t`` of fd and dd.
    """
    base_sign = signs[0]

    # ff: function-function block in the top-left corner.
    for row in range(n_rows_func):
        for col in range(n_cols_func):
            K[row, col] = phi_exp_3d[0, row, col] * base_sign

    # fd: first block-row, one column block per derivative type.
    for t in range(n_deriv_types):
        slab = fd_flat_indices[t]
        col_sign = signs[t + 1]
        k_col0 = col_offsets[t]
        first = idx_offsets[t]
        count = idx_sizes[t]
        for row in range(n_rows_func):
            for p in range(count):
                src_col = idx_flat[first + p]
                K[row, k_col0 + p] = phi_exp_3d[slab, row, src_col] * col_sign

    # df: first block-column, one row block per derivative type.
    for t in range(n_deriv_types):
        slab = df_flat_indices[t]
        k_row0 = row_offsets[t]
        first = idx_offsets[t]
        count = idx_sizes[t]
        for p in range(count):
            src_row = idx_flat[first + p]
            for col in range(n_cols_func):
                K[k_row0 + p, col] = phi_exp_3d[slab, src_row, col] * base_sign

    # dd: derivative-derivative blocks for every (row type, column type) pair.
    for tr in range(n_deriv_types):
        k_row0 = row_offsets[tr]
        first_r = idx_offsets[tr]
        count_r = idx_sizes[tr]
        for tc in range(n_deriv_types):
            slab = dd_flat_indices[tr, tc]
            col_sign = signs[tc + 1]
            k_col0 = col_offsets[tc]
            first_c = idx_offsets[tc]
            count_c = idx_sizes[tc]
            for pr in range(count_r):
                src_row = idx_flat[first_r + pr]
                for pc in range(count_c):
                    src_col = idx_flat[first_c + pc]
                    K[k_row0 + pr, k_col0 + pc] = phi_exp_3d[slab, src_row, src_col] * col_sign

570 

571 

@numba.jit(nopython=True, cache=True)
def _project_W_to_phi_space(W, W_proj, n_rows_func, n_cols_func,
                            fd_flat_indices, df_flat_indices, dd_flat_indices,
                            idx_flat, idx_offsets, idx_sizes,
                            signs, n_deriv_types, row_offsets, col_offsets):
    """
    Scatter-add a K-shaped matrix ``W`` back onto the ``phi_exp`` layout.

    This walks the same (slab, row, col) <-> K-entry correspondence as
    ``_assemble_kernel_numba`` but in the opposite direction, so that
    ``vdot(W, assemble(dphi_exp)) == vdot(W_proj, dphi_exp)``. ``W_proj`` is
    zeroed first and then accumulated in place; contributions are added
    because several K entries can map to the same ``phi_exp`` entry.
    """
    n_slabs, n_r, n_c = W_proj.shape
    for slab in range(n_slabs):
        for row in range(n_r):
            for col in range(n_c):
                W_proj[slab, row, col] = 0.0

    base_sign = signs[0]

    # ff block.
    for row in range(n_rows_func):
        for col in range(n_cols_func):
            W_proj[0, row, col] += base_sign * W[row, col]

    # fd blocks.
    for t in range(n_deriv_types):
        slab = fd_flat_indices[t]
        col_sign = signs[t + 1]
        k_col0 = col_offsets[t]
        first = idx_offsets[t]
        count = idx_sizes[t]
        for row in range(n_rows_func):
            for p in range(count):
                dst_col = idx_flat[first + p]
                W_proj[slab, row, dst_col] += col_sign * W[row, k_col0 + p]

    # df blocks.
    for t in range(n_deriv_types):
        slab = df_flat_indices[t]
        k_row0 = row_offsets[t]
        first = idx_offsets[t]
        count = idx_sizes[t]
        for p in range(count):
            dst_row = idx_flat[first + p]
            for col in range(n_cols_func):
                W_proj[slab, dst_row, col] += base_sign * W[k_row0 + p, col]

    # dd blocks.
    for tr in range(n_deriv_types):
        k_row0 = row_offsets[tr]
        first_r = idx_offsets[tr]
        count_r = idx_sizes[tr]
        for tc in range(n_deriv_types):
            slab = dd_flat_indices[tr, tc]
            col_sign = signs[tc + 1]
            k_col0 = col_offsets[tc]
            first_c = idx_offsets[tc]
            count_c = idx_sizes[tc]
            for pr in range(count_r):
                dst_row = idx_flat[first_r + pr]
                for pc in range(count_c):
                    dst_col = idx_flat[first_c + pc]
                    W_proj[slab, dst_row, dst_col] += col_sign * W[k_row0 + pr, k_col0 + pc]

623 

624 

def precompute_kernel_plan(n_order, n_bases, der_indices, powers, index):
    """
    Precompute the structural data needed to assemble the kernel matrix, so
    it can be reused across calls that differ only in the ``phi_exp`` values.

    Parameters
    ----------
    n_order : int
        Maximum derivative order; the flat-index map covers ``2 * n_order``.
    n_bases : int
        Number of bases used for the flat-index map.
    der_indices : list
        Derivative specification for each derivative type.
    powers : list of int
        Exponents of (-1) turned into the ``signs`` array.
    index : list of lists
        ``index[i]`` lists the point indices carrying derivative type ``i``.

    Returns
    -------
    plan : dict
        Keys: ``der_indices_tr``, ``signs``, ``index_arrays``,
        ``index_sizes``, ``n_pts_with_derivs``, ``dd_flat_indices``,
        ``n_deriv_types``, ``idx_flat``, ``idx_offsets``, ``row_offsets``,
        ``col_offsets``, ``fd_flat_indices``, ``df_flat_indices``.
        ``row_offsets`` / ``col_offsets`` are relative to the end of the
        function-value block; the function-block size is added at call time.

    Raises
    ------
    ValueError
        If ``index`` does not contain exactly one entry per derivative type.
    """
    dh = coti.get_dHelp()
    der_map = deriv_map(n_bases, 2 * n_order)
    der_indices_tr, der_ind_order = transform_der_indices(der_indices, der_map)

    n_deriv_types = len(der_indices)
    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)
    index_arrays = [np.asarray(idx, dtype=np.int64) for idx in index]

    # A mismatch would either fail later with an opaque error or yield a plan
    # whose n_pts_with_derivs exceeds what the per-type blocks cover, leaving
    # part of an np.empty output unwritten.
    if len(index_arrays) != n_deriv_types:
        raise ValueError(
            "index must provide one list of point indices per derivative "
            f"type: got {len(index_arrays)} lists for {n_deriv_types} "
            "derivative types"
        )

    index_sizes = np.array([len(idx) for idx in index_arrays], dtype=np.int64)
    n_pts_with_derivs = int(index_sizes.sum())

    # Pack all index arrays into one flat array; np.concatenate rejects an
    # empty sequence, hence the explicit empty case.
    if index_arrays:
        idx_flat = np.concatenate(index_arrays)
    else:
        idx_flat = np.array([], dtype=np.int64)

    # Start of each derivative type's block = exclusive running sum of the
    # block sizes. The same numbers locate the type inside idx_flat and,
    # relative to the function block, inside K's rows and columns.
    block_starts = np.zeros(n_deriv_types, dtype=np.int64)
    if n_deriv_types > 1:
        block_starts[1:] = np.cumsum(index_sizes[:-1])
    idx_offsets = block_starts
    row_offsets = block_starts.copy()
    col_offsets = block_starts.copy()

    # Flat slab index of the mixed derivative for every (row, column) pair
    # of derivative types.
    dd_flat_indices = np.empty((n_deriv_types, n_deriv_types), dtype=np.int64)
    for i in range(n_deriv_types):
        imdir_row = der_ind_order[i]
        for j in range(n_deriv_types):
            imdir_col = der_ind_order[j]
            new_idx, new_ord = dh.mult_dir(
                imdir_col[0], imdir_col[1], imdir_row[0], imdir_row[1])
            dd_flat_indices[i, j] = der_map[new_ord][new_idx]

    # fd and df use the same slabs; kept as two separate arrays.
    fd_flat_indices = np.array(der_indices_tr, dtype=np.int64)
    df_flat_indices = np.array(der_indices_tr, dtype=np.int64)

    return {
        'der_indices_tr': der_indices_tr,
        'signs': signs,
        'index_arrays': index_arrays,
        'index_sizes': index_sizes,
        'n_pts_with_derivs': n_pts_with_derivs,
        'dd_flat_indices': dd_flat_indices,
        'n_deriv_types': n_deriv_types,
        # Fused kernel data
        'idx_flat': idx_flat,
        'idx_offsets': idx_offsets,
        'row_offsets': row_offsets,
        'col_offsets': col_offsets,
        'fd_flat_indices': fd_flat_indices,
        'df_flat_indices': df_flat_indices,
    }

692 

693 

def rbf_kernel_fast(phi_exp_3d, plan, out=None):
    """
    Assemble the kernel matrix from a precomputed plan with the fused numba
    kernel.

    Parameters
    ----------
    phi_exp_3d : ndarray of shape (n_derivs, n_rows_func, n_cols_func)
        Pre-reshaped expanded derivative array.
    plan : dict
        Plan from precompute_kernel_plan(). If it contains
        ``'row_offsets_abs'`` (and ``'col_offsets_abs'``), those absolute
        offsets are used as-is; otherwise the relative offsets are shifted by
        the function-block size.
    out : ndarray, optional
        Pre-allocated 2-D output array. It must be at least
        ``(n_rows_func + n_d, n_cols_func + n_d)`` with
        ``n_d = plan['n_pts_with_derivs']``. If None, a new array of exactly
        that shape is allocated.

    Returns
    -------
    K : ndarray
        Full kernel matrix (``out`` itself when it was supplied).

    Raises
    ------
    ValueError
        If ``out`` is not 2-D or is too small to hold the kernel matrix.
    """
    n_rows_func = phi_exp_3d.shape[1]
    n_cols_func = phi_exp_3d.shape[2]
    n_deriv_pts = plan['n_pts_with_derivs']

    # Rows and columns are sized independently, matching the separate
    # row/column offsets handed to the fused kernel.
    total_rows = n_rows_func + n_deriv_pts
    total_cols = n_cols_func + n_deriv_pts

    if out is None:
        K = np.empty((total_rows, total_cols))
    else:
        # The compiled kernel writes by index, so an undersized buffer must
        # be rejected before the call.
        if out.ndim != 2 or out.shape[0] < total_rows or out.shape[1] < total_cols:
            raise ValueError(
                f"out must be a 2-D array of at least shape "
                f"({total_rows}, {total_cols}); got {out.shape}"
            )
        K = out

    if 'row_offsets_abs' in plan:
        row_off = plan['row_offsets_abs']
        col_off = plan['col_offsets_abs']
    else:
        row_off = plan['row_offsets'] + n_rows_func
        col_off = plan['col_offsets'] + n_cols_func

    _assemble_kernel_numba(
        phi_exp_3d, K, n_rows_func, n_cols_func,
        plan['fd_flat_indices'], plan['df_flat_indices'], plan['dd_flat_indices'],
        plan['idx_flat'], plan['idx_offsets'], plan['index_sizes'],
        plan['signs'], plan['n_deriv_types'], row_off, col_off,
    )

    return K

735 

736 

def rbf_kernel_predictions(
    phi,
    phi_exp,
    n_order,
    n_bases,
    der_indices,
    powers,
    return_deriv,
    index=-1,
    common_derivs=None,
    calc_cov=False,
    powers_predict=None
):
    """
    Build the (generally non-square) kernel matrix used for predictions,
    including directional-derivative blocks.

    Block layout of the result::

        [ K_ff  K_fd ]
        [ K_df  K_dd ]

    - Derivative *row* blocks come from ``index`` / ``der_indices`` (one
      block per derivative type, at the listed point indices), or, when
      ``calc_cov`` is True, from ``common_derivs`` at every row.
    - Derivative *column* blocks exist only when ``return_deriv`` is True:
      one block per entry of ``common_derivs``, covering every column.

    NOTE(review): an earlier description of this function stated
    "rows = test points, columns = training points (with derivative structure
    from index)". In the code the ``index`` structure is applied to the rows;
    confirm the intended orientation against the caller.

    Parameters
    ----------
    phi : OTI array
        Base kernel matrix; ``phi.shape`` sets the function-block size and
        ``phi.real`` is returned directly when ``calc_cov`` is True and
        ``return_deriv`` is False.
    phi_exp : ndarray
        Expanded derivative array; ``phi_exp[k]`` is reshaped to
        ``phi.shape``.
    n_order : int
        Maximum derivative order. The flat-index map covers ``2 * n_order``
        when ``return_deriv`` is True, otherwise ``n_order``.
    n_bases : int
        Number of bases used for the flat-index map.
    der_indices : list
        Derivative specifications paired with ``index``.
    powers : list of int
        Exponents of (-1); ``(-1) ** powers[0]`` scales K_ff and K_df.
    return_deriv : bool
        If True, add derivative column blocks for every column.
    index : list of lists or int, optional (default=-1)
        Point indices for each derivative type. Ignored unless it is a
        non-empty list whose first element is a list or ndarray.
    common_derivs : list, optional
        Derivative specifications for the column blocks (and for the row
        blocks when ``calc_cov`` is True).
    calc_cov : bool, optional (default=False)
        If True, derivative row blocks use every row and ``common_derivs``.
    powers_predict : list of int, optional
        Exponents of (-1) for the column blocks; ``powers`` is used when
        this is None.

    Returns
    -------
    K : ndarray
        Prediction kernel matrix.
    """
    # --- 1. Initial setup ---
    if calc_cov and not return_deriv:
        return phi.real

    dh = coti.get_dHelp()

    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)
    if powers_predict is not None:
        signs_predict = np.array([(-1.0) ** p for p in powers_predict], dtype=np.float64)
    else:
        signs_predict = signs

    n_rows_func, n_cols_func = phi.shape
    base_shape = (n_rows_func, n_cols_func)
    n_deriv_types = len(der_indices)
    n_deriv_types_pred = len(common_derivs) if common_derivs else 0

    if isinstance(index, list) and len(index) > 0 and isinstance(index[0], (list, np.ndarray)):
        index_arrays = [np.asarray(idx, dtype=np.int64) for idx in index]
    else:
        index_arrays = []

    der_map = deriv_map(n_bases, 2 * n_order if return_deriv else n_order)
    der_indices_tr, der_ind_order = transform_der_indices(der_indices, der_map)
    if common_derivs:
        der_indices_tr_pred, der_ind_order_pred = transform_der_indices(common_derivs, der_map)
    else:
        der_indices_tr_pred, der_ind_order_pred = [], []

    # --- 2. Describe the derivative row blocks ---
    # Each entry: (row indices, slab for the df block, imdir pair for dd).
    if return_deriv and calc_cov:
        # Row indices address rows of the (n_rows_func, n_cols_func) slabs,
        # so they are built from the row count.
        all_rows = np.arange(n_rows_func, dtype=np.int64)
        row_blocks = [
            (all_rows, der_indices_tr_pred[i], der_ind_order_pred[i])
            for i in range(n_deriv_types_pred)
        ]
        n_deriv_rows = n_deriv_types_pred * n_rows_func
    else:
        if index_arrays:
            row_blocks = [
                (index_arrays[i], der_indices_tr[i], der_ind_order[i])
                for i in range(n_deriv_types)
            ]
        else:
            row_blocks = []
        n_deriv_rows = sum(len(idx) for idx in index_arrays)

    # Derivative column blocks span every column, and only exist when
    # derivatives are requested.
    n_col_pts = n_cols_func if return_deriv else 0
    n_deriv_cols = n_deriv_types_pred * n_col_pts

    K = np.zeros((n_rows_func + n_deriv_rows, n_cols_func + n_deriv_cols))

    # --- 3. Fill the matrix block by block ---

    # K_ff
    K[:n_rows_func, :n_cols_func] = phi_exp[0].reshape(base_shape) * signs[0]

    # K_df: first block-column
    row_offset = n_rows_func
    for row_indices, flat_idx, _ in row_blocks:
        content_full = phi_exp[flat_idx].reshape(base_shape)
        extract_rows_and_assign(content_full, row_indices, K,
                                row_offset, 0, n_cols_func, signs[0])
        row_offset += len(row_indices)

    if not return_deriv:
        return K

    col_indices = np.arange(n_cols_func, dtype=np.int64)

    # K_fd: first block-row
    col_offset = n_cols_func
    for j in range(n_deriv_types_pred):
        content_full = phi_exp[der_indices_tr_pred[j]].reshape(base_shape)
        extract_cols_and_assign(content_full, col_indices, K,
                                0, col_offset, n_rows_func, signs_predict[j + 1])
        col_offset += n_cols_func

    # K_dd: inner blocks
    row_offset = n_rows_func
    for row_indices, _, imdir_row in row_blocks:
        col_offset = n_cols_func
        for j in range(n_deriv_types_pred):
            imdir_col = der_ind_order_pred[j]
            # The product of the two directions identifies the slab holding
            # the mixed derivative.
            new_idx, new_ord = dh.mult_dir(imdir_col[0], imdir_col[1],
                                           imdir_row[0], imdir_row[1])
            content_full = phi_exp[der_map[new_ord][new_idx]].reshape(base_shape)

            extract_and_assign(content_full, row_indices, col_indices, K,
                               row_offset, col_offset, signs_predict[j + 1])
            col_offset += n_cols_func
        row_offset += len(row_indices)

    return K