Coverage for jetgp/full_gddegp/gddegp_utils.py: 65%

433 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2026-04-03 15:09 -0500

1import numpy as np 

2import pyoti.core as coti 

3from line_profiler import profile 

4import numba 

5 

6 

7# ============================================================================= 

8# Numba-accelerated helper functions for efficient matrix slicing 

9# ============================================================================= 

10 

@numba.jit(nopython=True, cache=True)
def extract_rows(content_full, row_indices, n_cols):
    """
    Gather selected rows of a matrix into a newly allocated array.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix the rows are read from.
    row_indices : ndarray of int64
        Indices of the rows to copy, in output order.
    n_cols : int
        Number of leading columns copied from each selected row.

    Returns
    -------
    result : ndarray of shape (len(row_indices), n_cols)
        The selected rows, in the order given by ``row_indices``.
    """
    gathered = np.empty((len(row_indices), n_cols))
    for dst, src in enumerate(row_indices):
        for col in range(n_cols):
            gathered[dst, col] = content_full[src, col]
    return gathered

37 

38 

@numba.jit(nopython=True, cache=True)
def extract_cols(content_full, col_indices, n_rows):
    """
    Gather selected columns of a matrix into a newly allocated array.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix the columns are read from.
    col_indices : ndarray of int64
        Indices of the columns to copy, in output order.
    n_rows : int
        Number of leading rows copied from each selected column.

    Returns
    -------
    result : ndarray of shape (n_rows, len(col_indices))
        The selected columns, in the order given by ``col_indices``.
    """
    gathered = np.empty((n_rows, len(col_indices)))
    for row in range(n_rows):
        for dst, src in enumerate(col_indices):
            gathered[row, dst] = content_full[row, src]
    return gathered

64 

65 

@numba.jit(nopython=True, cache=True)
def extract_submatrix(content_full, row_indices, col_indices):
    """
    Gather the submatrix at the given row/column index sets.

    Loop-based stand-in for ``content_full[np.ix_(row_indices, col_indices)]``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix the entries are read from.
    row_indices : ndarray of int64
        Row indices to select.
    col_indices : ndarray of int64
        Column indices to select.

    Returns
    -------
    result : ndarray of shape (len(row_indices), len(col_indices))
        ``result[i, j] == content_full[row_indices[i], col_indices[j]]``.
    """
    block = np.empty((len(row_indices), len(col_indices)))
    for dst_r, src_r in enumerate(row_indices):
        for dst_c, src_c in enumerate(col_indices):
            block[dst_r, dst_c] = content_full[src_r, src_c]
    return block

94 

95 

@numba.jit(nopython=True, cache=True)
def extract_submatrix_transposed(content_full, row_indices, col_indices):
    """
    Gather a submatrix and return it transposed.

    Loop-based stand-in for
    ``content_full[np.ix_(row_indices, col_indices)].T``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix the entries are read from.
    row_indices : ndarray of int64
        Row indices to select.
    col_indices : ndarray of int64
        Column indices to select.

    Returns
    -------
    result : ndarray of shape (len(col_indices), len(row_indices))
        ``result[j, i] == content_full[row_indices[i], col_indices[j]]``.
    """
    block_t = np.empty((len(col_indices), len(row_indices)))
    for i_sel, src_r in enumerate(row_indices):
        for j_sel, src_c in enumerate(col_indices):
            block_t[j_sel, i_sel] = content_full[src_r, src_c]
    return block_t

124 

125 

@numba.jit(nopython=True, cache=True)
def extract_rows_transposed(content_full, row_indices, n_cols):
    """
    Gather selected rows and return them transposed.

    Loop-based stand-in for ``content_full[row_indices, :].T``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix the rows are read from.
    row_indices : ndarray of int64
        Row indices to select.
    n_cols : int
        Number of leading columns copied from each selected row.

    Returns
    -------
    result : ndarray of shape (n_cols, len(row_indices))
        ``result[j, i] == content_full[row_indices[i], j]``.
    """
    rows_t = np.empty((n_cols, len(row_indices)))
    for i_sel, src_r in enumerate(row_indices):
        for col in range(n_cols):
            rows_t[col, i_sel] = content_full[src_r, col]
    return rows_t

153 

154 

@numba.jit(nopython=True, cache=True)
def extract_cols_transposed(content_full, col_indices, n_rows):
    """
    Gather selected columns and return them transposed.

    Loop-based stand-in for ``content_full[:, col_indices].T``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix the columns are read from.
    col_indices : ndarray of int64
        Column indices to select.
    n_rows : int
        Number of leading rows copied from each selected column.

    Returns
    -------
    result : ndarray of shape (len(col_indices), n_rows)
        ``result[j, i] == content_full[i, col_indices[j]]``.
    """
    cols_t = np.empty((len(col_indices), n_rows))
    for row in range(n_rows):
        for j_sel, src_c in enumerate(col_indices):
            cols_t[j_sel, row] = content_full[row, src_c]
    return cols_t

181 

182 

@numba.jit(nopython=True, cache=True, parallel=False)
def extract_and_assign(content_full, row_indices, col_indices, K,
                       row_start, col_start):
    """
    Copy a selected submatrix straight into a window of ``K``.

    No intermediate array is allocated; ``K`` is modified in place and
    nothing is returned.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix the entries are read from.
    row_indices : ndarray of int64
        Row indices to select from ``content_full``.
    col_indices : ndarray of int64
        Column indices to select from ``content_full``.
    K : ndarray
        Destination matrix, written in place.
    row_start : int
        Row of ``K`` where the first selected row lands.
    col_start : int
        Column of ``K`` where the first selected column lands.
    """
    for i_sel, src_r in enumerate(row_indices):
        dst_r = row_start + i_sel
        for j_sel, src_c in enumerate(col_indices):
            K[dst_r, col_start + j_sel] = content_full[src_r, src_c]

210 

211 

@numba.jit(nopython=True, cache=True, parallel=False)
def extract_and_assign_transposed(content_full, row_indices, col_indices, K,
                                  row_start, col_start):
    """
    Copy the transpose of a selected submatrix straight into ``K``.

    Loop-based stand-in for
    ``K[...] = content_full[np.ix_(row_indices, col_indices)].T``.
    ``K`` is modified in place and nothing is returned.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix the entries are read from.
    row_indices : ndarray of int64
        Row indices to select from ``content_full``.
    col_indices : ndarray of int64
        Column indices to select from ``content_full``.
    K : ndarray
        Destination matrix, written in place.
    row_start : int
        Row of ``K`` where the written window begins.
    col_start : int
        Column of ``K`` where the written window begins.
    """
    for i_sel, src_r in enumerate(row_indices):
        dst_c = col_start + i_sel
        for j_sel, src_c in enumerate(col_indices):
            # Source rows become destination columns and vice versa.
            K[row_start + j_sel, dst_c] = content_full[src_r, src_c]

241 

242 

@numba.jit(nopython=True, cache=True)
def extract_rows_and_assign(content_full, row_indices, K,
                            row_start, col_start, n_cols):
    """
    Copy selected rows straight into a window of ``K``.

    ``K`` is modified in place and nothing is returned.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix the rows are read from.
    row_indices : ndarray of int64
        Row indices to select from ``content_full``.
    K : ndarray
        Destination matrix, written in place.
    row_start : int
        Row of ``K`` where the first selected row lands.
    col_start : int
        Column of ``K`` where column 0 of ``content_full`` lands.
    n_cols : int
        Number of leading columns copied from each selected row.
    """
    for i_sel, src_r in enumerate(row_indices):
        dst_r = row_start + i_sel
        for col in range(n_cols):
            K[dst_r, col_start + col] = content_full[src_r, col]

269 

270 

@numba.jit(nopython=True, cache=True)
def extract_cols_and_assign(content_full, col_indices, K,
                            row_start, col_start, n_rows):
    """
    Copy selected columns straight into a window of ``K``.

    ``K`` is modified in place and nothing is returned.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix the columns are read from.
    col_indices : ndarray of int64
        Column indices to select from ``content_full``.
    K : ndarray
        Destination matrix, written in place.
    row_start : int
        Row of ``K`` where row 0 of ``content_full`` lands.
    col_start : int
        Column of ``K`` where the first selected column lands.
    n_rows : int
        Number of leading rows copied from each selected column.
    """
    for row in range(n_rows):
        dst_r = row_start + row
        for j_sel, src_c in enumerate(col_indices):
            K[dst_r, col_start + j_sel] = content_full[row, src_c]

296 

297 

@numba.jit(nopython=True, cache=True)
def extract_rows_and_assign_transposed(content_full, row_indices, K,
                                       row_start, col_start, n_cols):
    """
    Copy the transpose of selected rows straight into ``K``.

    Loop-based stand-in for ``K[...] = content_full[row_indices, :].T``.
    ``K`` is modified in place and nothing is returned.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix the rows are read from.
    row_indices : ndarray of int64
        Row indices to select from ``content_full``.
    K : ndarray
        Destination matrix, written in place.
    row_start : int
        Row of ``K`` where the written window begins.
    col_start : int
        Column of ``K`` where the written window begins.
    n_cols : int
        Number of leading columns of ``content_full`` that are copied;
        each becomes one destination row.
    """
    for i_sel, src_r in enumerate(row_indices):
        dst_c = col_start + i_sel
        for col in range(n_cols):
            K[row_start + col, dst_c] = content_full[src_r, col]

325 

326 

@numba.jit(nopython=True, cache=True)
def extract_cols_and_assign_transposed(content_full, col_indices, K,
                                       row_start, col_start, n_rows):
    """
    Copy the transpose of selected columns straight into ``K``.

    Loop-based stand-in for ``K[...] = content_full[:, col_indices].T``.
    ``K`` is modified in place and nothing is returned.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix the columns are read from.
    col_indices : ndarray of int64
        Column indices to select from ``content_full``.
    K : ndarray
        Destination matrix, written in place.
    row_start : int
        Row of ``K`` where the written window begins.
    col_start : int
        Column of ``K`` where the written window begins.
    n_rows : int
        Number of leading rows of ``content_full`` that are copied;
        each becomes one destination column.
    """
    for row in range(n_rows):
        dst_c = col_start + row
        for j_sel, src_c in enumerate(col_indices):
            K[row_start + j_sel, dst_c] = content_full[row, src_c]

353 

354 

355# ============================================================================= 

356# Derivative index transformation utilities 

357# ============================================================================= 

358 

def make_first_odd(der_indices):
    """
    Remap the base of every derivative pair onto the odd bases (1, 3, 5, ...).

    Parameters
    ----------
    der_indices : list of list of pairs
        Nested groups of ``[base, order]`` pairs.

    Returns
    -------
    list of list of list
        Same nesting as the input, where each pair ``[b, p]`` has become
        ``[2 * b - 1, p]``. The input is not modified.
    """
    return [
        [[2 * pair[0] - 1, pair[1]] for pair in group]
        for group in der_indices
    ]

369 

370 

def make_first_even(der_indices):
    """
    Remap the base of every derivative pair onto the even bases (2, 4, 6, ...).

    Parameters
    ----------
    der_indices : list of list of pairs
        Nested groups of ``[base, order]`` pairs.

    Returns
    -------
    list of list of list
        Same nesting as the input, where each pair ``[b, p]`` has become
        ``[2 * b, p]``. The input is not modified.
    """
    return [
        [[2 * pair[0], pair[1]] for pair in group]
        for group in der_indices
    ]

381 

382 

383# ============================================================================= 

384# Difference computation functions 

385# ============================================================================= 

def compute_dimension_differences(k, X1, X2, n1, n2, rays_X1, rays_X2,
                                  derivative_locations_X1, derivative_locations_X2,
                                  e_tags_1, e_tags_2, oti_module):
    """
    Build the pairwise difference matrix for coordinate ``k`` with OTI tags.

    Each point listed in a derivative-location group has its ``k``-th
    coordinate perturbed by ``e_tag * ray[k]`` for that direction; all other
    points receive a plain ``0.0`` perturbation. The tagged coordinates of
    ``X1`` and ``X2`` are then differenced row by row.

    Parameters
    ----------
    k : int
        Coordinate (dimension) index.
    X1, X2 : oti.array
        Point sets; column ``k`` of each is read.
    n1, n2 : int
        Number of points in ``X1`` and ``X2``.
    rays_X1, rays_X2 : list of ndarray or None
        Per-direction ray arrays, indexed as ``rays[dir][k, j]`` where ``j``
        is the position inside the matching location group. ``None`` means
        no perturbation is applied to that point set.
    derivative_locations_X1, derivative_locations_X2 : list of list
        ``locations[dir]`` lists the point indices carrying direction ``dir``.
    e_tags_1, e_tags_2 : list
        One tag per direction, multiplied with the ray component.
    oti_module : module
        PyOTI static module providing ``array``, ``zeros`` and ``transpose``.

    Returns
    -------
    diffs_k : oti.array
        Array created as ``oti_module.zeros((n1, n2))`` and filled one row
        at a time with the tagged differences.
    """
    def _tagged_perturbations(n_pts, ray_groups, location_groups, tags):
        # One slot per point; points without a direction keep 0.0.
        values = [0.0] * n_pts
        if ray_groups is None:
            return values
        for dir_idx, rays in enumerate(ray_groups):
            for j, pt_idx in enumerate(location_groups[dir_idx]):
                # A point may belong to several directions: accumulate.
                values[pt_idx] = values[pt_idx] + tags[dir_idx] * rays[k, j]
        return values

    shifts_1 = _tagged_perturbations(n1, rays_X1, derivative_locations_X1, e_tags_1)
    shifts_2 = _tagged_perturbations(n2, rays_X2, derivative_locations_X2, e_tags_2)

    # Promote the plain Python lists to OTI arrays before adding them.
    shift_col_1 = oti_module.array(shifts_1)
    shift_col_2 = oti_module.array(shifts_2)

    tagged_1 = X1[:, k] + shift_col_1
    tagged_2 = X2[:, k] + shift_col_2

    diffs_k = oti_module.zeros((n1, n2))
    for row in range(n1):
        diffs_k[row, :] = tagged_1[row, 0] - oti_module.transpose(tagged_2[:, 0])

    return diffs_k

451 

452 

def differences_by_dim_func(X1, X2, rays_X1, rays_X2, derivative_locations_X1, derivative_locations_X2,
                            n_order, oti_module, return_deriv=True):
    """
    Compute per-dimension difference matrices with dual OTI tagging.

    X1 points are tagged with the odd bases (e_1, e_3, ...) and X2 points
    with the even bases (e_2, e_4, ...), so direction ``i`` uses bases
    ``2*i + 1`` and ``2*i + 2``. Keeping the two point sets on separate
    bases means the coefficient of a basis in ``X1 - X2`` carries only the
    ray of the point that was tagged with it, which the cross-derivative
    blocks of the kernel matrix rely on.

    Parameters
    ----------
    X1 : ndarray of shape (n1, d)
        First set of input points.
    X2 : ndarray of shape (n2, d)
        Second set of input points.
    rays_X1, rays_X2 : list of ndarray or None
        Per-direction ray arrays for each point set; ``None`` counts as
        zero directions.
    derivative_locations_X1, derivative_locations_X2 : list of list
        Point indices carrying each derivative direction.
    n_order : int
        Derivative order. With ``n_order == 0`` every tag is the integer 0.
    oti_module : module
        PyOTI static module (provides ``array`` and ``e``).
    return_deriv : bool, optional
        If True the tags are created at order ``2 * n_order``, otherwise at
        order ``n_order``.

    Returns
    -------
    differences_by_dim : list of oti.array
        One ``(n1, n2)`` difference array per input dimension.
    """
    X1 = oti_module.array(X1)
    X2 = oti_module.array(X2)
    n1, d = X1.shape
    n2, _ = X2.shape

    # Both point sets get a tag list long enough for the larger direction count.
    n_dirs_1 = 0 if rays_X1 is None else len(rays_X1)
    n_dirs_2 = 0 if rays_X2 is None else len(rays_X2)
    n_dirs = max(n_dirs_1, n_dirs_2)

    if n_order == 0:
        e_tags_1 = [0] * n_dirs
        e_tags_2 = [0] * n_dirs
    else:
        tag_order = 2 * n_order if return_deriv else n_order
        tag_pairs = [
            (oti_module.e((2 * i + 1), order=tag_order),
             oti_module.e((2 * i + 2), order=tag_order))
            for i in range(n_dirs)
        ]
        e_tags_1 = [odd_tag for odd_tag, _ in tag_pairs]
        e_tags_2 = [even_tag for _, even_tag in tag_pairs]

    return [
        compute_dimension_differences(
            k, X1, X2, n1, n2, rays_X1, rays_X2,
            derivative_locations_X1, derivative_locations_X2,
            e_tags_1, e_tags_2, oti_module
        )
        for k in range(d)
    ]

535 

536 

537# ============================================================================= 

538# Derivative mapping utilities 

539# ============================================================================= 

540 

def deriv_map(nbases, order):
    """
    Build the lookup from ``(order, index)`` to a running flat index.

    Parameters
    ----------
    nbases : int
        Number of OTI bases, forwarded to ``coti.ndir_order``.
    order : int
        Highest order to include (inclusive).

    Returns
    -------
    list of list of int
        ``result[o][i]`` is the flat position of direction ``i`` of order
        ``o``; flat positions count up consecutively from 0 across orders.
    """
    table = []
    next_flat = 0
    for current_order in range(order + 1):
        n_dirs = coti.ndir_order(nbases, current_order)
        table.append(list(range(next_flat, next_flat + n_dirs)))
        next_flat += n_dirs
    return table

553 

554 

def transform_der_indices(der_indices, der_map):
    """
    Translate derivative specifications into flat indices.

    Parameters
    ----------
    der_indices : list
        Derivative specifications, each passed to ``coti.imdir``.
    der_map : list of list of int
        Lookup table indexed as ``der_map[order][idx]``.

    Returns
    -------
    deriv_ind_transf : list of int
        Flat index for each specification.
    deriv_ind_order : list
        The ``(idx, order)`` value returned by ``coti.imdir`` for each
        specification.
    """
    flat_positions = []
    directions = []
    for spec in der_indices:
        direction = coti.imdir(spec)
        dir_idx, dir_order = direction
        flat_positions.append(der_map[dir_order][dir_idx])
        directions.append(direction)
    return flat_positions, directions

565 

566 

567# ============================================================================= 

568# RBF Kernel Assembly Functions (Optimized with Numba) 

569# ============================================================================= 

570 

@profile
def rbf_kernel(
    phi,
    phi_exp,
    n_order,
    n_bases,
    der_indices,
    index=None
):
    """
    Assemble the full GDDEGP covariance matrix with selective derivative coverage.

    The output is laid out as one function block followed by one block per
    derivative type, along both axes. Column-side derivatives read the
    even-tagged components of ``phi_exp``; row-side derivatives read the
    odd-tagged ones; derivative-derivative blocks read the component given
    by the product of the two directions.

    Parameters
    ----------
    phi : OTI array
        Base kernel matrix; only its ``shape`` is used here.
    phi_exp : ndarray
        Expanded derivative array, indexed by flat derivative index.
    n_order : int
        Maximum derivative order; the derivative map is built to
        ``2 * n_order``.
    n_bases : int
        Number of OTI bases (must be even).
    der_indices : list
        Derivative index specifications, one per derivative type.
    index : list of list
        ``index[i]`` lists the point indices that carry derivative type
        ``i``. It is iterated unconditionally, so it must be supplied even
        though the signature defaults it to ``None``.

    Returns
    -------
    K : ndarray
        Kernel matrix of shape
        ``(phi.shape[0] + n_pts, phi.shape[1] + n_pts)`` where ``n_pts`` is
        the total number of entries across ``index``.
    """
    dh = coti.get_dHelp()

    assert n_bases % 2 == 0, "n_bases must be an even number."
    n_func_rows, n_func_cols = phi.shape
    n_types = len(der_indices)

    # The same location lists extend both the row and the column axis.
    n_deriv_pts = sum(len(locs) for locs in index)

    der_map = deriv_map(n_bases, 2 * n_order)

    even_specs = make_first_even(der_indices)
    odd_specs = make_first_odd(der_indices)
    flat_even, imdir_even = transform_der_indices(even_specs, der_map)
    flat_odd, imdir_odd = transform_der_indices(odd_specs, der_map)

    # The jitted helpers need int64 arrays rather than Python lists.
    locs_by_type = [np.asarray(locs, dtype=np.int64) for locs in index]

    # Start position of each derivative block along rows and columns.
    row_starts = []
    col_starts = []
    next_row = n_func_rows
    next_col = n_func_cols
    for t in range(n_types):
        row_starts.append(next_row)
        col_starts.append(next_col)
        block_size = len(index[t])
        next_row += block_size
        next_col += block_size

    K = np.zeros((n_func_rows + n_deriv_pts, n_func_cols + n_deriv_pts))

    # K_ff: function-function block.
    K[0:n_func_rows, 0:n_func_cols] = phi_exp[0]

    for t in range(n_types):
        # K_fd: all function rows, columns of derivative type t.
        extract_cols_and_assign(phi_exp[flat_even[t]], locs_by_type[t], K,
                                0, col_starts[t], n_func_rows)
        # K_df: rows of derivative type t, all function columns.
        extract_rows_and_assign(phi_exp[flat_odd[t]], locs_by_type[t], K,
                                row_starts[t], 0, n_func_cols)

    # K_dd: rows of derivative type i against columns of derivative type j.
    for i in range(n_types):
        row_dir = imdir_odd[i]
        for j in range(n_types):
            col_dir = imdir_even[j]
            new_idx, new_ord = dh.mult_dir(
                col_dir[0], col_dir[1], row_dir[0], row_dir[1])
            extract_and_assign(phi_exp[der_map[new_ord][new_idx]],
                               locs_by_type[i], locs_by_type[j], K,
                               row_starts[i], col_starts[j])

    return K

687 

688 

@numba.jit(nopython=True, cache=True)
def _assemble_kernel_numba(phi_exp_3d, K, n_rows_func, n_cols_func,
                           fd_flat_indices, df_flat_indices, dd_flat_indices,
                           idx_flat, idx_offsets, idx_sizes,
                           n_deriv_types, row_offsets, col_offsets):
    """
    Fused numba kernel for GDDEGP K matrix assembly (no signs, even/odd bases).

    Fills ``K`` in place from planes of ``phi_exp_3d``:

    * ff: plane 0 into the top-left ``n_rows_func x n_cols_func`` corner;
    * fd: plane ``fd_flat_indices[j]``, selected columns, all function rows;
    * df: plane ``df_flat_indices[i]``, selected rows, all function columns;
    * dd: plane ``dd_flat_indices[i, j]``, selected rows and columns.

    The point selection of derivative type ``t`` is the window
    ``idx_flat[idx_offsets[t] : idx_offsets[t] + idx_sizes[t]]``.
    ``row_offsets[t]`` / ``col_offsets[t]`` are used directly as positions
    in ``K``, so they must already include the size of the function block.
    """
    # ff block
    for r in range(n_rows_func):
        for c in range(n_cols_func):
            K[r, c] = phi_exp_3d[0, r, c]

    # fd block (even indices)
    for t in range(n_deriv_types):
        plane = fd_flat_indices[t]
        first_col = col_offsets[t]
        begin = idx_offsets[t]
        picked_cols = idx_flat[begin:begin + idx_sizes[t]]
        for r in range(n_rows_func):
            for k, src_c in enumerate(picked_cols):
                K[r, first_col + k] = phi_exp_3d[plane, r, src_c]

    # df block (odd indices)
    for t in range(n_deriv_types):
        plane = df_flat_indices[t]
        first_row = row_offsets[t]
        begin = idx_offsets[t]
        picked_rows = idx_flat[begin:begin + idx_sizes[t]]
        for k, src_r in enumerate(picked_rows):
            for c in range(n_cols_func):
                K[first_row + k, c] = phi_exp_3d[plane, src_r, c]

    # dd block (even x odd)
    for ti in range(n_deriv_types):
        first_row = row_offsets[ti]
        begin_i = idx_offsets[ti]
        picked_rows = idx_flat[begin_i:begin_i + idx_sizes[ti]]
        for tj in range(n_deriv_types):
            plane = dd_flat_indices[ti, tj]
            first_col = col_offsets[tj]
            begin_j = idx_offsets[tj]
            picked_cols = idx_flat[begin_j:begin_j + idx_sizes[tj]]
            for ki, src_r in enumerate(picked_rows):
                for kj, src_c in enumerate(picked_cols):
                    K[first_row + ki, first_col + kj] = phi_exp_3d[plane, src_r, src_c]

734 

735 

@numba.jit(nopython=True, cache=True)
def _project_W_to_phi_space(W, W_proj, n_rows_func, n_cols_func,
                            fd_flat_indices, df_flat_indices, dd_flat_indices,
                            idx_flat, idx_offsets, idx_sizes,
                            n_deriv_types, row_offsets, col_offsets):
    """
    Reverse of _assemble_kernel_numba: project W from K-space back into
    phi_exp-space so that vdot(W, assemble(dphi_exp)) == vdot(W_proj, dphi_exp).
    No-signs variant for GDDEGP even/odd bases.

    ``W_proj`` is zeroed and then filled in place. Every entry of ``W`` is
    added to the ``W_proj`` entry that the assembly step would have read
    for that position. Entries are accumulated with ``+=`` because several
    positions of ``W`` can map to the same ``W_proj`` entry.
    """
    n_planes, n_proj_rows, n_proj_cols = W_proj.shape
    for p in range(n_planes):
        for r in range(n_proj_rows):
            for c in range(n_proj_cols):
                W_proj[p, r, c] = 0.0

    # ff block
    for r in range(n_rows_func):
        for c in range(n_cols_func):
            W_proj[0, r, c] += W[r, c]

    # fd block
    for t in range(n_deriv_types):
        plane = fd_flat_indices[t]
        first_col = col_offsets[t]
        begin = idx_offsets[t]
        picked_cols = idx_flat[begin:begin + idx_sizes[t]]
        for r in range(n_rows_func):
            for k, dst_c in enumerate(picked_cols):
                W_proj[plane, r, dst_c] += W[r, first_col + k]

    # df block
    for t in range(n_deriv_types):
        plane = df_flat_indices[t]
        first_row = row_offsets[t]
        begin = idx_offsets[t]
        picked_rows = idx_flat[begin:begin + idx_sizes[t]]
        for k, dst_r in enumerate(picked_rows):
            for c in range(n_cols_func):
                W_proj[plane, dst_r, c] += W[first_row + k, c]

    # dd block
    for ti in range(n_deriv_types):
        first_row = row_offsets[ti]
        begin_i = idx_offsets[ti]
        picked_rows = idx_flat[begin_i:begin_i + idx_sizes[ti]]
        for tj in range(n_deriv_types):
            plane = dd_flat_indices[ti, tj]
            first_col = col_offsets[tj]
            begin_j = idx_offsets[tj]
            picked_cols = idx_flat[begin_j:begin_j + idx_sizes[tj]]
            for ki, dst_r in enumerate(picked_rows):
                for kj, dst_c in enumerate(picked_cols):
                    W_proj[plane, dst_r, dst_c] += W[first_row + ki, first_col + kj]

785 

786 

def precompute_kernel_plan(n_order, n_bases, der_indices, powers, index):
    """
    Precompute structural info for rbf_kernel_fast (GDDEGP even/odd variant).

    Parameters
    ----------
    n_order : int
        Maximum derivative order; the derivative map is built to
        ``2 * n_order``.
    n_bases : int
        Number of OTI bases (must be even).
    der_indices : list
        Derivative index specifications, one per derivative type.
    powers : object
        Not referenced by this function.
    index : list of list
        ``index[i]`` lists the point indices carrying derivative type ``i``.

    Returns
    -------
    dict
        Plan with the flattened point indices (``idx_flat``,
        ``idx_offsets``, ``index_sizes``), block offsets relative to the
        start of the derivative section (``row_offsets``, ``col_offsets``),
        the flat ``phi_exp`` plane indices for each block type
        (``fd_flat_indices``, ``df_flat_indices``, ``dd_flat_indices``),
        plus ``index_arrays``, ``n_pts_with_derivs``, ``n_deriv_types`` and
        an all-ones ``signs`` array.
    """
    dh = coti.get_dHelp()
    assert n_bases % 2 == 0, "n_bases must be an even number."
    der_map = deriv_map(n_bases, 2 * n_order)

    n_deriv_types = len(der_indices)
    index_arrays = [np.asarray(locs, dtype=np.int64) for locs in index]

    index_sizes = np.array([len(arr) for arr in index_arrays], dtype=np.int64)
    n_pts_with_derivs = int(index_sizes.sum())

    if n_deriv_types > 0:
        idx_flat = np.concatenate(index_arrays)
    else:
        idx_flat = np.array([], dtype=np.int64)

    # The start of each type inside idx_flat is also its block offset
    # within the derivative section of K, along both axes.
    starts = np.zeros(n_deriv_types, dtype=np.int64)
    running = 0
    for t in range(n_deriv_types):
        starts[t] = running
        running += index_sizes[t]
    idx_offsets = starts
    row_offsets = starts.copy()
    col_offsets = starts.copy()

    # Even/odd derivative transforms
    even_specs = make_first_even(der_indices)
    odd_specs = make_first_odd(der_indices)
    flat_even, imdir_even = transform_der_indices(even_specs, der_map)
    flat_odd, imdir_odd = transform_der_indices(odd_specs, der_map)

    fd_flat_indices = np.array(flat_even, dtype=np.int64)
    df_flat_indices = np.array(flat_odd, dtype=np.int64)

    dd_flat_indices = np.empty((n_deriv_types, n_deriv_types), dtype=np.int64)
    for i, row_dir in enumerate(imdir_odd):
        for j, col_dir in enumerate(imdir_even):
            new_idx, new_ord = dh.mult_dir(col_dir[0], col_dir[1], row_dir[0], row_dir[1])
            dd_flat_indices[i, j] = der_map[new_ord][new_idx]

    plan = {}
    plan['signs'] = np.ones(n_deriv_types + 1, dtype=np.float64)  # unused, kept for API
    plan['index_arrays'] = index_arrays
    plan['index_sizes'] = index_sizes
    plan['n_pts_with_derivs'] = n_pts_with_derivs
    plan['dd_flat_indices'] = dd_flat_indices
    plan['n_deriv_types'] = n_deriv_types
    plan['idx_flat'] = idx_flat
    plan['idx_offsets'] = idx_offsets
    plan['row_offsets'] = row_offsets
    plan['col_offsets'] = col_offsets
    plan['fd_flat_indices'] = fd_flat_indices
    plan['df_flat_indices'] = df_flat_indices
    return plan

843 

844 

def rbf_kernel_fast(phi_exp_3d, plan, out=None):
    """
    Fast kernel assembly using precomputed plan and fused numba kernel.

    Parameters
    ----------
    phi_exp_3d : ndarray of shape (n_planes, n_rows_func, n_cols_func)
        Expanded derivative planes; plane 0 is the function-function block.
    plan : dict
        Plan as produced by ``precompute_kernel_plan``. If it contains
        ``'row_offsets_abs'`` (and ``'col_offsets_abs'``) those absolute
        offsets are used as-is; otherwise the relative ``'row_offsets'`` /
        ``'col_offsets'`` are shifted by the function-block size.
    out : ndarray, optional
        Preallocated destination. When given it is filled in place and
        returned; its shape is not checked here.

    Returns
    -------
    K : ndarray
        Assembled kernel matrix of shape
        ``(n_rows_func + n_pts, n_cols_func + n_pts)`` with
        ``n_pts = plan['n_pts_with_derivs']``.
    """
    n_rows_func = phi_exp_3d.shape[1]
    n_cols_func = phi_exp_3d.shape[2]
    n_deriv_pts = plan['n_pts_with_derivs']
    if out is not None:
        K = out
    else:
        # Size each axis from its own function-block extent. The assembler
        # writes columns starting at n_cols_func + offset, so a buffer sized
        # from the row count alone is only correct for square inputs.
        K = np.empty((n_rows_func + n_deriv_pts, n_cols_func + n_deriv_pts))

    if 'row_offsets_abs' in plan:
        row_off = plan['row_offsets_abs']
        col_off = plan['col_offsets_abs']
    else:
        row_off = plan['row_offsets'] + n_rows_func
        col_off = plan['col_offsets'] + n_cols_func

    _assemble_kernel_numba(
        phi_exp_3d, K, n_rows_func, n_cols_func,
        plan['fd_flat_indices'], plan['df_flat_indices'], plan['dd_flat_indices'],
        plan['idx_flat'], plan['idx_offsets'], plan['index_sizes'],
        plan['n_deriv_types'], row_off, col_off,
    )
    return K

869 

870 

@profile
def rbf_kernel_predictions(
    phi,
    phi_exp,
    n_order,
    n_bases,
    der_indices,
    return_deriv,
    index=None,
    common_derivs=None,
    calc_cov=False,
):
    """
    Constructs the RBF kernel matrix for predictions with selective derivative coverage.

    ``phi`` has shape ``(n_train, n_test)``; the result is transposed so
    that rows correspond to test quantities and columns to training
    quantities. Numba helpers are used for the transposed block copies.

    Parameters
    ----------
    phi : OTI array
        Base kernel matrix between training and test points.
    phi_exp : ndarray
        Expanded derivative array; each plane is reshaped to
        ``(n_train, n_test)`` before use.
    n_order : int
        Maximum derivative order. With ``n_order == 0`` the function
        returns ``phi.real.T`` directly.
    n_bases : int
        Number of OTI bases.
    der_indices : list
        Derivative specifications for training data.
    return_deriv : bool
        If True, derivative rows at the test points are appended and the
        derivative map is built to ``2 * n_order``.
    index : list of list
        Training point indices for each derivative type. It is iterated
        whenever the early-return paths are not taken, so it must be
        supplied in that case despite the ``None`` default.
    common_derivs : list
        Derivative specifications to predict. ``None`` or an empty list
        means no derivative rows are produced, including when
        ``return_deriv`` is True.
    calc_cov : bool
        If True, derivative columns use every test point and the
        ``common_derivs`` directions instead of the training ``index`` /
        ``der_indices``. With ``return_deriv`` False this returns
        ``phi.real.T`` directly.

    Returns
    -------
    K : ndarray
        Prediction kernel matrix.
    """
    # Early return for covariance-only case
    if calc_cov and not return_deriv:
        return phi.real.T

    dh = coti.get_dHelp()

    n_train, n_test = phi.shape
    n_deriv_types = len(der_indices)
    n_deriv_types_pred = len(common_derivs) if common_derivs else 0

    # Handle n_order = 0 case
    if n_order == 0:
        return phi.real.T

    # After the early return above, calc_cov can only be True together
    # with return_deriv.
    cov_with_derivs = return_deriv and calc_cov

    # Convert index lists to numpy arrays for numba
    index_arrays = [np.asarray(locs, dtype=np.int64) for locs in index]

    # Determine derivative map
    if return_deriv:
        der_map = deriv_map(n_bases, 2 * n_order)
        # Every predicted derivative type is evaluated at all test points.
        derivative_locations_test = [np.arange(n_test, dtype=np.int64)] * n_deriv_types_pred
    else:
        der_map = deriv_map(n_bases, n_order)

    # Create derivative index transformations
    der_indices_even = make_first_even(der_indices)
    der_indices_odd = make_first_odd(der_indices)
    der_indices_tr_odd, der_ind_order_odd = transform_der_indices(der_indices_odd, der_map)
    der_indices_odd_pred = make_first_odd(common_derivs) if common_derivs else []
    if common_derivs:
        der_indices_tr_odd_pred, der_ind_order_odd_pred = transform_der_indices(
            der_indices_odd_pred, der_map)
    else:
        der_indices_tr_odd_pred, der_ind_order_odd_pred = [], []

    # Compute matrix dimensions
    n_rows_func = n_test
    if return_deriv:
        n_rows_derivs = sum(len(locs) for locs in derivative_locations_test)
    else:
        n_rows_derivs = 0
    total_rows = n_rows_func + n_rows_derivs

    if cov_with_derivs:
        n_deriv_types = n_deriv_types_pred
        n_cols_derivs = sum(len(locs) for locs in derivative_locations_test)
    else:
        n_cols_derivs = sum(len(locs) for locs in index)
    total_cols = n_train + n_cols_derivs

    # Compute block offsets
    row_offsets = [0, n_test]
    if return_deriv:
        for i in range(n_deriv_types_pred):
            row_offsets.append(row_offsets[-1] + len(derivative_locations_test[i]))

    col_offsets = [0, n_train]
    for i in range(n_deriv_types):
        if cov_with_derivs:
            col_offsets.append(col_offsets[-1] + len(derivative_locations_test[i]))
        else:
            col_offsets.append(col_offsets[-1] + len(index[i]))

    # Allocate output matrix
    K = np.zeros((total_rows, total_cols))
    base_shape = (n_train, n_test)

    # Block (0,0): Function-Function (K_ff)
    content_full = phi_exp[0].reshape(base_shape)
    K[:n_test, :n_train] = content_full.T

    # First Block-Row: Function-Derivative (K_fd)
    for j in range(n_deriv_types):
        col_locs = derivative_locations_test[j] if cov_with_derivs else index_arrays[j]
        col_start = col_offsets[j + 1]

        flat_idx = der_indices_tr_odd_pred[j] if cov_with_derivs else der_indices_tr_odd[j]
        content_full = phi_exp[flat_idx].reshape(base_shape)

        # Use numba for efficient row extraction with transpose
        extract_rows_and_assign_transposed(content_full, col_locs, K,
                                           0, col_start, n_test)

    if not return_deriv:
        return K

    # First Block-Column: Derivative-Function (K_df)
    der_indices_tr_even, der_ind_order_even = transform_der_indices(der_indices_even, der_map)
    # Guarded like the odd-tagged counterpart above: without it,
    # common_derivs=None raised TypeError here even though the rest of the
    # function treats it as "no derivative rows".
    der_indices_even_pred = make_first_even(common_derivs) if common_derivs else []
    der_indices_tr_even_pred, der_ind_order_even_pred = transform_der_indices(
        der_indices_even_pred, der_map)

    for i in range(n_deriv_types_pred):
        test_locs = derivative_locations_test[i]
        row_start = row_offsets[i + 1]

        flat_idx = der_indices_tr_even_pred[i]
        content_full = phi_exp[flat_idx].reshape(base_shape)

        # Use numba for efficient column extraction with transpose
        extract_cols_and_assign_transposed(content_full, test_locs, K,
                                           row_start, 0, n_train)

    # Inner Blocks: Derivative-Derivative (K_dd)
    for i in range(n_deriv_types_pred):
        test_locs = derivative_locations_test[i]
        row_start = row_offsets[i + 1]

        for j in range(n_deriv_types):
            col_locs = derivative_locations_test[j] if cov_with_derivs else index_arrays[j]
            col_start = col_offsets[j + 1]

            imdir_train = der_ind_order_odd_pred[j] if cov_with_derivs else der_ind_order_odd[j]
            imdir_test = der_ind_order_even_pred[i]
            new_idx, new_ord = dh.mult_dir(
                imdir_train[0], imdir_train[1],
                imdir_test[0], imdir_test[1]
            )
            flat_idx = der_map[new_ord][new_idx]

            content_full = phi_exp[flat_idx].reshape(base_shape)

            # Use numba for efficient submatrix extraction with transpose (replaces np.ix_ + .T)
            extract_and_assign_transposed(content_full, col_locs, test_locs, K,
                                          row_start, col_start)

    return K