Coverage for jetgp/wdegp/wdegp_utils.py: 66%

390 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2026-04-03 15:09 -0500

1import pyoti.core as coti 

2import numpy as np 

3from line_profiler import profile 

4import numba 

5 

6 

7# ============================================================================= 

8# Numba-accelerated helper functions for efficient matrix slicing 

9# ============================================================================= 

10 

11@numba.jit(nopython=True, cache=True) 

def extract_rows(content_full, row_indices, n_cols):
    """
    Gather selected rows of a matrix into a freshly allocated array.

    Output row ``i`` holds the first ``n_cols`` entries of
    ``content_full[row_indices[i]]``. The copy is written with explicit
    scalar loops because this function is jit-compiled by numba in this file.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix to read from.
    row_indices : ndarray of int64
        Rows to copy, in output order.
    n_cols : int
        Number of leading columns copied from each selected row.

    Returns
    -------
    result : ndarray of shape (len(row_indices), n_cols)
        New array (NumPy default float dtype) containing the selected rows.
    """
    gathered = np.empty((len(row_indices), n_cols))
    for dst_row, src_row in enumerate(row_indices):
        for col in range(n_cols):
            gathered[dst_row, col] = content_full[src_row, col]
    return gathered

37 

38 

39@numba.jit(nopython=True, cache=True) 

def extract_cols(content_full, col_indices, n_rows):
    """
    Gather selected columns of a matrix into a freshly allocated array.

    Output column ``j`` holds the first ``n_rows`` entries of
    ``content_full[:, col_indices[j]]``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix to read from.
    col_indices : ndarray of int64
        Columns to copy, in output order.
    n_rows : int
        Number of leading rows copied from each selected column.

    Returns
    -------
    result : ndarray of shape (n_rows, len(col_indices))
        New array (NumPy default float dtype) containing the selected columns.
    """
    picked = np.empty((n_rows, len(col_indices)))
    for row in range(n_rows):
        for dst_col, src_col in enumerate(col_indices):
            picked[row, dst_col] = content_full[row, src_col]
    return picked

64 

65 

66@numba.jit(nopython=True, cache=True) 

def extract_submatrix(content_full, row_indices, col_indices):
    """
    Gather the cross product of selected rows and columns.

    Equivalent in result to ``content_full[np.ix_(row_indices, col_indices)]``
    but written as scalar loops so it can be jit-compiled by numba.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix to read from.
    row_indices : ndarray of int64
        Rows to select, in output order.
    col_indices : ndarray of int64
        Columns to select, in output order.

    Returns
    -------
    result : ndarray of shape (len(row_indices), len(col_indices))
        New array with ``result[i, j] == content_full[row_indices[i],
        col_indices[j]]``.
    """
    sub = np.empty((len(row_indices), len(col_indices)))
    for out_row, src_row in enumerate(row_indices):
        for out_col, src_col in enumerate(col_indices):
            sub[out_row, out_col] = content_full[src_row, src_col]
    return sub

94 

95 

96@numba.jit(nopython=True, cache=True, parallel=False) 

def extract_and_assign(content_full, row_indices, col_indices, K,
                       row_start, col_start, sign):
    """
    Copy a row/column-selected submatrix into ``K``, scaled by ``sign``.

    Writes ``K[row_start + i, col_start + j] =
    content_full[row_indices[i], col_indices[j]] * sign`` for every pair
    ``(i, j)``, without building an intermediate submatrix.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix to read from.
    row_indices : ndarray of int64
        Source rows, in destination order.
    col_indices : ndarray of int64
        Source columns, in destination order.
    K : ndarray
        Destination matrix; modified in place.
    row_start : int
        First destination row in ``K``.
    col_start : int
        First destination column in ``K``.
    sign : float
        Multiplier applied to every copied value (+1.0 or -1.0).
    """
    for row_step, src_row in enumerate(row_indices):
        dst_row = row_start + row_step
        for col_step, src_col in enumerate(col_indices):
            K[dst_row, col_start + col_step] = content_full[src_row, src_col] * sign

126 

127 

128@numba.jit(nopython=True, cache=True) 

def extract_rows_and_assign(content_full, row_indices, K,
                            row_start, col_start, n_cols, sign):
    """
    Copy selected rows into ``K``, scaled by ``sign``.

    Writes ``K[row_start + i, col_start + j] =
    content_full[row_indices[i], j] * sign`` for ``j`` in ``range(n_cols)``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix to read from.
    row_indices : ndarray of int64
        Source rows, in destination order.
    K : ndarray
        Destination matrix; modified in place.
    row_start : int
        First destination row in ``K``.
    col_start : int
        First destination column in ``K``.
    n_cols : int
        Number of leading columns copied from each selected row.
    sign : float
        Multiplier applied to every copied value (+1.0 or -1.0).
    """
    for row_step, src_row in enumerate(row_indices):
        dst_row = row_start + row_step
        for col in range(n_cols):
            K[dst_row, col_start + col] = content_full[src_row, col] * sign

156 

157 

158@numba.jit(nopython=True, cache=True) 

def extract_cols_and_assign(content_full, col_indices, K,
                            row_start, col_start, n_rows, sign):
    """
    Copy selected columns into ``K``, scaled by ``sign``.

    Writes ``K[row_start + i, col_start + j] =
    content_full[i, col_indices[j]] * sign`` for ``i`` in ``range(n_rows)``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix to read from.
    col_indices : ndarray of int64
        Source columns, in destination order.
    K : ndarray
        Destination matrix; modified in place.
    row_start : int
        First destination row in ``K``.
    col_start : int
        First destination column in ``K``.
    n_rows : int
        Number of leading rows copied from each selected column.
    sign : float
        Multiplier applied to every copied value (+1.0 or -1.0).
    """
    for row in range(n_rows):
        dst_row = row_start + row
        for col_step, src_col in enumerate(col_indices):
            K[dst_row, col_start + col_step] = content_full[row, src_col] * sign

185 

def differences_by_dim_func(X1, X2, n_order, oti_module, return_deriv=True):
    """
    Build per-dimension pairwise difference matrices, optionally perturbed
    with hypercomplex (OTI) units for automatic differentiation.

    For every input dimension ``k`` the result contains a matrix whose
    ``(i, j)`` entry is ``X1[i, k] - X2[j, k]`` plus a perturbation that
    depends on the arguments:

    * ``n_order == 0``: no perturbation (plain differences).
    * otherwise, ``return_deriv`` False: ``e_{k+1}`` of order ``n_order``.
    * otherwise, ``return_deriv`` True: ``e_{k+1}`` of order ``2 * n_order``.

    Parameters
    ----------
    X1 : array_like of shape (n1, d)
        First set of input points.
    X2 : array_like of shape (n2, d)
        Second set of input points.
    n_order : int
        Base derivative order; ``0`` disables the hypercomplex perturbation.
    oti_module : module
        Object providing ``array``, ``zeros``, ``transpose`` and
        ``e(k, order=...)`` (e.g. a PyOTI static module).
    return_deriv : bool, optional
        Selects between order ``2 * n_order`` (True) and ``n_order`` (False)
        for the perturbation when ``n_order != 0``.

    Returns
    -------
    differences_by_dim : list of length d
        One ``(n1, n2)`` array per input dimension.

    Raises
    ------
    ValueError
        If ``X1`` and ``X2`` do not have the same number of columns.
    """
    X1 = oti_module.array(X1)
    X2 = oti_module.array(X2)
    n1, d1 = X1.shape
    n2, d2 = X2.shape
    # Previously the second unpack silently replaced d, so a mismatch either
    # dropped trailing columns of X1 or failed later with an opaque index error.
    if d1 != d2:
        raise ValueError(
            f"X1 and X2 must have the same number of columns, got {d1} and {d2}"
        )

    differences_by_dim = []
    for k in range(d1):
        diffs_k = oti_module.zeros((n1, n2))
        if n_order == 0:
            col = oti_module.transpose(X2[:, k])
            for i in range(n1):
                diffs_k[i, :] = X1[i, k] - col
        elif not return_deriv:
            # The unit and the X2 column do not depend on i; build them once.
            unit = oti_module.e(k + 1, order=n_order)
            col = X2[:, k].T
            for i in range(n1):
                diffs_k[i, :] = X1[i, k] + unit - col
        else:
            col = X2[:, k].T
            for i in range(n1):
                diffs_k[i, :] = X1[i, k] - col
            # Here the unit is added to the whole matrix after it is filled.
            diffs_k = diffs_k + oti_module.e(k + 1, order=2 * n_order)
        differences_by_dim.append(diffs_k)

    return differences_by_dim

245 

246 

247# ============================================================================= 

248# Derivative mapping utilities 

249# ============================================================================= 

250 

def deriv_map(nbases, order):
    """
    Number every derivative direction up to ``order`` consecutively.

    For each order ``0..order`` the number of directions is obtained from
    ``coti.ndir_order(nbases, ordi)``; the directions are then assigned
    consecutive flat indices, continuing from one order to the next.

    Returns
    -------
    map_deriv : list of list of int
        ``map_deriv[ordi][idx]`` is the flat index of direction ``idx`` of
        order ``ordi``.
    """
    mapping = []
    next_flat = 0
    for current_order in range(order + 1):
        n_dirs = coti.ndir_order(nbases, current_order)
        flat_block = list(range(next_flat, next_flat + n_dirs))
        mapping.append(flat_block)
        next_flat += len(flat_block)
    return mapping

266 

267 

def transform_der_indices(der_indices, der_map):
    """
    Translate derivative specifications into flat indices.

    Each entry of ``der_indices`` is passed to ``coti.imdir``, whose result
    is unpacked as ``(idx, order)`` and looked up as ``der_map[order][idx]``.

    Returns
    -------
    deriv_ind_transf : list
        Flat index for each specification, in input order.
    deriv_ind_order : list
        The corresponding ``coti.imdir`` results, in input order.
    """
    imdirs = [coti.imdir(spec) for spec in der_indices]
    flat_indices = [der_map[order][idx] for idx, order in imdirs]
    return flat_indices, imdirs

281 

282 

283# ============================================================================= 

284# RBF Kernel Assembly Functions (Optimized with Numba) 

285# ============================================================================= 

286 

287@profile 

def rbf_kernel(
    phi,
    phi_exp,
    n_order,
    n_bases,
    der_indices,
    powers,
    index=-1,
):
    """
    Assemble the kernel matrix with function and derivative blocks.

    The result has the block layout::

        [ K_ff  K_fd ]
        [ K_df  K_dd ]

    where every block is cut out of one layer of ``phi_exp`` (reshaped to
    ``phi.shape``) using the point indices in ``index``. Column blocks of
    derivative type ``j`` are scaled by ``(-1) ** powers[j + 1]``; the
    function column uses ``(-1) ** powers[0]``.

    Parameters
    ----------
    phi : OTI array
        Base kernel matrix; only its ``shape`` is read here.
    phi_exp : ndarray
        Expanded derivative layers, indexed by flat derivative index.
    n_order : int
        Maximum derivative order; the flat index map covers ``2 * n_order``.
    n_bases : int
        Number of OTI bases.
    der_indices : list
        Derivative specifications, one per derivative type.
    powers : list of int
        Sign exponents; entry 0 is for function values, entry ``j + 1`` for
        derivative type ``j``.
    index : list of lists
        Point indices for each derivative type. The default ``-1`` is not
        iterable, so callers must pass a real list.

    Returns
    -------
    K : ndarray
        Dense kernel matrix of shape
        ``(phi.shape[0] + n_pts, phi.shape[1] + n_pts)`` where ``n_pts`` is
        the total number of entries in ``index``.
    """
    dir_helper = coti.get_dHelp()

    # Flat-index lookup for every direction up to twice the base order.
    flat_lookup = deriv_map(n_bases, 2 * n_order)
    flat_per_type, imdir_per_type = transform_der_indices(der_indices, flat_lookup)

    n_func_rows, n_func_cols = phi.shape
    block_shape = (n_func_rows, n_func_cols)
    n_types = len(der_indices)

    n_deriv_pts = sum(len(pts) for pts in index)
    K = np.zeros((n_func_rows + n_deriv_pts, n_func_cols + n_deriv_pts))

    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)
    # int64 arrays are what the numba helpers expect.
    pts_per_type = [np.asarray(pts, dtype=np.int64) for pts in index]

    def layer(flat_idx):
        return phi_exp[flat_idx].reshape(block_shape)

    # K_ff
    K[:n_func_rows, :n_func_cols] = layer(0) * signs[0]

    # K_fd: one column block per derivative type.
    next_col = n_func_cols
    for t in range(n_types):
        source = layer(flat_per_type[t])
        pts = pts_per_type[t]
        extract_cols_and_assign(source, pts, K,
                                0, next_col, n_func_rows, signs[t + 1])
        next_col += len(pts)

    # K_df: one row block per derivative type.
    next_row = n_func_rows
    for t in range(n_types):
        source = layer(flat_per_type[t])
        pts = pts_per_type[t]
        extract_rows_and_assign(source, pts, K,
                                next_row, 0, n_func_cols, signs[0])
        next_row += len(pts)

    # K_dd: the layer is the product of the column and row directions.
    next_row = n_func_rows
    for row_t in range(n_types):
        row_pts = pts_per_type[row_t]
        row_imdir = imdir_per_type[row_t]
        next_col = n_func_cols
        for col_t in range(n_types):
            col_pts = pts_per_type[col_t]
            col_imdir = imdir_per_type[col_t]
            prod_idx, prod_order = dir_helper.mult_dir(
                col_imdir[0], col_imdir[1], row_imdir[0], row_imdir[1])
            source = layer(flat_lookup[prod_order][prod_idx])
            extract_and_assign(source, row_pts, col_pts, K,
                               next_row, next_col, signs[col_t + 1])
            next_col += len(col_pts)
        next_row += len(row_pts)

    return K

407 

408 

409@numba.jit(nopython=True, cache=True) 

410def _assemble_kernel_numba(phi_exp_3d, K, n_rows_func, n_cols_func, 

411 fd_flat_indices, df_flat_indices, dd_flat_indices, 

412 idx_flat, idx_offsets, idx_sizes, 

413 signs, n_deriv_types, row_offsets, col_offsets): 

414 """ 

415 Fused numba kernel that assembles the entire K matrix in a single call. 

416 Handles ff, fd, df, and dd blocks without Python-level loop overhead. 

417 """ 

418 # Block (0,0): Function-Function 

419 s0 = signs[0] 

420 for r in range(n_rows_func): 

421 for c in range(n_cols_func): 

422 K[r, c] = phi_exp_3d[0, r, c] * s0 

423 

424 # First Block-Row: Function-Derivative (fd) 

425 for j in range(n_deriv_types): 

426 fi = fd_flat_indices[j] 

427 sj = signs[j + 1] 

428 co = col_offsets[j] 

429 off_j = idx_offsets[j] 

430 sz_j = idx_sizes[j] 

431 for r in range(n_rows_func): 

432 for k in range(sz_j): 

433 ci = idx_flat[off_j + k] 

434 K[r, co + k] = phi_exp_3d[fi, r, ci] * sj 

435 

436 # First Block-Column: Derivative-Function (df) 

437 for i in range(n_deriv_types): 

438 fi = df_flat_indices[i] 

439 ro = row_offsets[i] 

440 off_i = idx_offsets[i] 

441 sz_i = idx_sizes[i] 

442 for k in range(sz_i): 

443 ri = idx_flat[off_i + k] 

444 for c in range(n_cols_func): 

445 K[ro + k, c] = phi_exp_3d[fi, ri, c] * s0 

446 

447 # Inner Blocks: Derivative-Derivative (dd) 

448 for i in range(n_deriv_types): 

449 ro = row_offsets[i] 

450 off_i = idx_offsets[i] 

451 sz_i = idx_sizes[i] 

452 for j in range(n_deriv_types): 

453 fi = dd_flat_indices[i, j] 

454 sj = signs[j + 1] 

455 co = col_offsets[j] 

456 off_j = idx_offsets[j] 

457 sz_j = idx_sizes[j] 

458 for ki in range(sz_i): 

459 ri = idx_flat[off_i + ki] 

460 for kj in range(sz_j): 

461 ci = idx_flat[off_j + kj] 

462 K[ro + ki, co + kj] = phi_exp_3d[fi, ri, ci] * sj 

463 

464 

465@numba.jit(nopython=True, cache=True) 

466def _project_W_to_phi_space(W, W_proj, n_rows_func, n_cols_func, 

467 fd_flat_indices, df_flat_indices, dd_flat_indices, 

468 idx_flat, idx_offsets, idx_sizes, 

469 signs, n_deriv_types, row_offsets, col_offsets): 

470 """ 

471 Reverse of _assemble_kernel_numba: project W from K-space back into 

472 phi_exp-space so that vdot(W, assemble(dphi_exp)) == vdot(W_proj, dphi_exp). 

473 """ 

474 for d in range(W_proj.shape[0]): 

475 for r in range(W_proj.shape[1]): 

476 for c in range(W_proj.shape[2]): 

477 W_proj[d, r, c] = 0.0 

478 s0 = signs[0] 

479 for r in range(n_rows_func): 

480 for c in range(n_cols_func): 

481 W_proj[0, r, c] += s0 * W[r, c] 

482 for j in range(n_deriv_types): 

483 fi = fd_flat_indices[j] 

484 sj = signs[j + 1] 

485 co = col_offsets[j] 

486 off_j = idx_offsets[j] 

487 sz_j = idx_sizes[j] 

488 for r in range(n_rows_func): 

489 for k in range(sz_j): 

490 ci = idx_flat[off_j + k] 

491 W_proj[fi, r, ci] += sj * W[r, co + k] 

492 for i in range(n_deriv_types): 

493 fi = df_flat_indices[i] 

494 ro = row_offsets[i] 

495 off_i = idx_offsets[i] 

496 sz_i = idx_sizes[i] 

497 for k in range(sz_i): 

498 ri = idx_flat[off_i + k] 

499 for c in range(n_cols_func): 

500 W_proj[fi, ri, c] += s0 * W[ro + k, c] 

501 for i in range(n_deriv_types): 

502 ro = row_offsets[i] 

503 off_i = idx_offsets[i] 

504 sz_i = idx_sizes[i] 

505 for j in range(n_deriv_types): 

506 fi = dd_flat_indices[i, j] 

507 sj = signs[j + 1] 

508 co = col_offsets[j] 

509 off_j = idx_offsets[j] 

510 sz_j = idx_sizes[j] 

511 for ki in range(sz_i): 

512 ri = idx_flat[off_i + ki] 

513 for kj in range(sz_j): 

514 ci = idx_flat[off_j + kj] 

515 W_proj[fi, ri, ci] += sj * W[ro + ki, co + kj] 

516 

517 

def precompute_kernel_plan(n_order, n_bases, der_indices, powers, index):
    """
    Precompute everything about the kernel layout that does not depend on
    the kernel values, so repeated assemblies can reuse it.

    Parameters
    ----------
    n_order : int
        Maximum derivative order; the flat index map covers ``2 * n_order``.
    n_bases : int
        Number of OTI bases.
    der_indices : list
        Derivative specifications, one per derivative type.
    powers : list of int
        Sign exponents (entry 0 for function values, then one per type).
    index : list of lists
        Point indices for each derivative type.

    Returns
    -------
    plan : dict
        Keys: ``der_indices_tr``, ``signs``, ``index_arrays``,
        ``index_sizes``, ``n_pts_with_derivs``, ``dd_flat_indices``,
        ``n_deriv_types``, ``idx_flat``, ``idx_offsets``, ``row_offsets``,
        ``col_offsets``, ``fd_flat_indices``, ``df_flat_indices``.
        ``row_offsets`` / ``col_offsets`` are relative to the end of the
        function block; the function-block size is added at call time.
    """
    dir_helper = coti.get_dHelp()
    flat_lookup = deriv_map(n_bases, 2 * n_order)
    flat_per_type, imdir_per_type = transform_der_indices(der_indices, flat_lookup)

    n_types = len(der_indices)
    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)
    pts_per_type = [np.asarray(pts, dtype=np.int64) for pts in index]

    sizes = np.array([len(pts) for pts in pts_per_type], dtype=np.int64)
    total_deriv_pts = int(sizes.sum())

    # All point indices back to back; `starts` marks where each type begins.
    if n_types > 0:
        packed = np.concatenate(pts_per_type)
    else:
        packed = np.array([], dtype=np.int64)

    starts = np.zeros(n_types, dtype=np.int64)
    running = 0
    for t in range(n_types):
        starts[t] = running
        running += sizes[t]

    # Block offsets in K coincide with the packed starts; they are kept as
    # separate arrays so each plan entry owns its own buffer.
    row_starts = starts.copy()
    col_starts = starts.copy()

    # Flat index of the product direction for every (row type, col type) pair.
    dd_flat = np.empty((n_types, n_types), dtype=np.int64)
    for row_t, row_imdir in enumerate(imdir_per_type):
        for col_t, col_imdir in enumerate(imdir_per_type):
            prod_idx, prod_order = dir_helper.mult_dir(
                col_imdir[0], col_imdir[1], row_imdir[0], row_imdir[1])
            dd_flat[row_t, col_t] = flat_lookup[prod_order][prod_idx]

    fd_flat = np.array(flat_per_type, dtype=np.int64)
    df_flat = np.array(flat_per_type, dtype=np.int64)

    return {
        'der_indices_tr': flat_per_type,
        'signs': signs,
        'index_arrays': pts_per_type,
        'index_sizes': sizes,
        'n_pts_with_derivs': total_deriv_pts,
        'dd_flat_indices': dd_flat,
        'n_deriv_types': n_types,
        # Fused kernel data
        'idx_flat': packed,
        'idx_offsets': starts,
        'row_offsets': row_starts,
        'col_offsets': col_starts,
        'fd_flat_indices': fd_flat,
        'df_flat_indices': df_flat,
    }

585 

586 

def rbf_kernel_fast(phi_exp_3d, plan, out=None):
    """
    Assemble the kernel matrix from a precomputed plan with the fused kernel.

    Parameters
    ----------
    phi_exp_3d : ndarray of shape (n_derivs, n_rows_func, n_cols_func)
        Pre-reshaped expanded derivative array.
    plan : dict
        Precomputed plan from precompute_kernel_plan(). If it contains
        ``'row_offsets_abs'`` (and ``'col_offsets_abs'``), those absolute
        offsets are used; otherwise the relative ``'row_offsets'`` /
        ``'col_offsets'`` are shifted by ``n_rows_func`` / ``n_cols_func``.
    out : ndarray, optional
        Pre-allocated output array, filled in place and returned. It is not
        validated; the caller must supply a shape of at least
        ``(n_rows_func + n_pts, n_cols_func + n_pts)``.

    Returns
    -------
    K : ndarray of shape (n_rows_func + n_pts, n_cols_func + n_pts)
        Full kernel matrix, where ``n_pts = plan['n_pts_with_derivs']``.
    """
    n_rows_func = phi_exp_3d.shape[1]
    n_cols_func = phi_exp_3d.shape[2]
    n_deriv_pts = plan['n_pts_with_derivs']

    if out is not None:
        K = out
    else:
        # Size rows and columns independently. The previous square
        # (total, total) allocation used n_rows_func for both axes, which is
        # too narrow when n_cols_func > n_rows_func and would let the
        # jit-compiled kernel write past the end of the buffer.
        K = np.empty((n_rows_func + n_deriv_pts, n_cols_func + n_deriv_pts))

    if 'row_offsets_abs' in plan:
        row_off = plan['row_offsets_abs']
        col_off = plan['col_offsets_abs']
    else:
        row_off = plan['row_offsets'] + n_rows_func
        col_off = plan['col_offsets'] + n_cols_func

    _assemble_kernel_numba(
        phi_exp_3d, K, n_rows_func, n_cols_func,
        plan['fd_flat_indices'], plan['df_flat_indices'], plan['dd_flat_indices'],
        plan['idx_flat'], plan['idx_offsets'], plan['index_sizes'],
        plan['signs'], plan['n_deriv_types'], row_off, col_off,
    )

    return K

628 

629 

630@profile 

def rbf_kernel_predictions(
    phi,
    phi_exp,
    n_order,
    n_bases,
    der_indices,
    powers,
    return_deriv,
    index=-1,
    common_derivs=None,
    calc_cov=False,
    powers_predict=None
):
    """
    Assemble the kernel matrix used for predictions.

    Three modes are supported:

    * ``calc_cov`` and not ``return_deriv``: ``phi.real`` is returned as is.
    * not ``return_deriv``: function block plus one derivative-function row
      block per derivative type (rows taken from ``index``).
    * ``return_deriv``: additionally one column block per entry of
      ``common_derivs`` spanning all ``phi_exp.shape[-1]`` columns, and the
      matching derivative-derivative blocks. With ``calc_cov`` the row
      blocks also use ``common_derivs`` and all ``phi_exp.shape[-1]`` rows
      instead of ``der_indices`` / ``index``.

    Parameters
    ----------
    phi : OTI array
        Base kernel matrix; its ``shape`` (and ``real`` in the first mode)
        is read here.
    phi_exp : ndarray
        Expanded derivative layers, indexed by flat derivative index.
    n_order : int
        Maximum derivative order. The flat index map covers ``2 * n_order``
        when ``return_deriv`` is True and ``n_order`` otherwise.
    n_bases : int
        Number of OTI bases.
    der_indices : list
        Derivative specifications for the row blocks.
    powers : list of int
        Sign exponents; entry 0 scales the function column.
    return_deriv : bool
        If True, add derivative column blocks.
    index : list of lists
        Row point indices per derivative type. Anything that is not a
        non-empty list results in no index-based row blocks being filled.
    common_derivs : list
        Derivative specifications for the column blocks.
    calc_cov : bool
        Selects the covariance variants described above.
    powers_predict : list of int, optional
        Sign exponents for the derivative column blocks; defaults to the
        signs derived from ``powers``.

    Returns
    -------
    K : ndarray
        Prediction kernel matrix.
    """
    if calc_cov and not return_deriv:
        return phi.real

    dir_helper = coti.get_dHelp()

    n_func_rows, n_func_cols = phi.shape
    block_shape = (n_func_rows, n_func_cols)
    n_row_types = len(der_indices)
    n_pred_types = len(common_derivs) if common_derivs else 0

    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)
    if powers_predict is None:
        signs_predict = signs
    else:
        signs_predict = np.array([(-1.0) ** p for p in powers_predict], dtype=np.float64)

    index_is_list = isinstance(index, list)

    def indexed_point_count():
        if not index_is_list:
            return 0
        return sum(len(pts) for pts in index)

    if return_deriv:
        flat_lookup = deriv_map(n_bases, 2 * n_order)
        all_cols = np.arange(phi_exp.shape[-1], dtype=np.int64)
        if calc_cov:
            all_rows = np.arange(phi_exp.shape[-1], dtype=np.int64)
            # Covariance mode: row blocks follow the predicted derivatives.
            n_row_types = n_pred_types
            n_extra_rows = n_row_types * min(n_func_cols, len(all_cols))
        else:
            n_extra_rows = indexed_point_count()
    else:
        flat_lookup = deriv_map(n_bases, n_order)
        all_cols = np.array([], dtype=np.int64)
        n_extra_rows = indexed_point_count()

    flat_train, imdir_train = transform_der_indices(der_indices, flat_lookup)
    if common_derivs:
        flat_pred, imdir_pred = transform_der_indices(common_derivs, flat_lookup)
    else:
        flat_pred, imdir_pred = [], []
    n_extra_cols = n_pred_types * min(n_func_cols, len(all_cols))

    K = np.zeros((n_func_rows + n_extra_rows, n_func_cols + n_extra_cols))

    # int64 arrays are what the numba helpers expect.
    if index_is_list and len(index) > 0:
        pts_per_type = [np.asarray(pts, dtype=np.int64) for pts in index]
    else:
        pts_per_type = []

    # K_ff
    K[:n_func_rows, :n_func_cols] = phi_exp[0].reshape(block_shape) * signs[0]

    if not return_deriv:
        # K_df only; skipped entirely when no usable index list was given.
        if pts_per_type:
            next_row = n_func_rows
            for t in range(n_row_types):
                row_pts = pts_per_type[t]
                source = phi_exp[flat_train[t]].reshape(block_shape)
                extract_rows_and_assign(source, row_pts, K,
                                        next_row, 0, n_func_cols, signs[0])
                next_row += len(row_pts)
        return K

    # --- return_deriv=True case ---

    # K_fd: one full-width column block per predicted derivative.
    cols_per_block = len(all_cols)
    next_col = n_func_cols
    for t in range(n_pred_types):
        source = phi_exp[flat_pred[t]].reshape(block_shape)
        extract_cols_and_assign(source, all_cols, K,
                                0, next_col, n_func_rows, signs_predict[t + 1])
        next_col += cols_per_block

    # Without calc_cov the row blocks need index arrays; none means no rows.
    active_row_types = n_row_types if (calc_cov or pts_per_type) else 0

    # K_df
    next_row = n_func_rows
    for t in range(active_row_types):
        if calc_cov:
            row_pts = all_rows
            flat_idx = flat_pred[t]
        else:
            row_pts = pts_per_type[t]
            flat_idx = flat_train[t]
        source = phi_exp[flat_idx].reshape(block_shape)
        extract_rows_and_assign(source, row_pts, K,
                                next_row, 0, n_func_cols, signs[0])
        next_row += len(row_pts)

    # K_dd: the layer is the product of the column and row directions.
    next_row = n_func_rows
    for row_t in range(active_row_types):
        if calc_cov:
            row_pts = all_rows
            row_imdir = imdir_pred[row_t]
        else:
            row_pts = pts_per_type[row_t]
            row_imdir = imdir_train[row_t]

        next_col = n_func_cols
        for col_t in range(n_pred_types):
            col_imdir = imdir_pred[col_t]
            prod_idx, prod_order = dir_helper.mult_dir(
                col_imdir[0], col_imdir[1], row_imdir[0], row_imdir[1])
            source = phi_exp[flat_lookup[prod_order][prod_idx]].reshape(block_shape)
            extract_and_assign(source, row_pts, all_cols, K,
                               next_row, next_col, signs_predict[col_t + 1])
            next_col += cols_per_block
        next_row += len(row_pts)

    return K

815 

816 

817# ============================================================================= 

818# Utility functions 

819# ============================================================================= 

820 

def determine_weights(diffs_by_dim, diffs_test, length_scales, kernel_func, sigma_n):
    """
    Solve for interpolation weights of many test points in one linear solve.

    The training covariance ``K`` is bordered with a row and column of ones
    (and a zero corner), and each right-hand side is the test-train
    covariance vector extended by a trailing one. The last solution
    component (the multiplier of the added constraint row) is discarded.

    Parameters
    ----------
    diffs_by_dim : list of ndarray
        Pairwise differences between training points, passed to
        ``kernel_func``.
    diffs_test : list of ndarray
        Pairwise differences between test and training points, passed to
        ``kernel_func``; the result must have shape (n_test, n_train).
    length_scales : array-like
        Kernel hyperparameters, passed through to ``kernel_func``.
    kernel_func : callable
        ``kernel_func(diffs, length_scales)``; only ``.real`` of its result
        is used.
    sigma_n : float
        Noise parameter. Accepted but not used by this function.

    Returns
    -------
    weights_matrix : ndarray of shape (n_test, n_train)
        One row of weights per test point.
    """
    train_cov = kernel_func(diffs_by_dim, length_scales).real
    n_train = train_cov.shape[0]

    test_cov = kernel_func(diffs_test, length_scales).real
    n_test = test_cov.shape[0]

    # Bordered system: ones on the border, zero in the corner.
    system = np.ones((n_train + 1, n_train + 1))
    system[:n_train, :n_train] = train_cov
    system[n_train, n_train] = 0

    # One right-hand side per test point; the trailing column stays at one.
    rhs = np.ones((n_test, n_train + 1))
    rhs[:, :n_train] = test_cov

    solved = np.linalg.solve(system, rhs.T)

    # Drop the last row (the multiplier) and return one row per test point.
    return solved[:n_train, :].T

871 

872 

def to_tuple(item):
    """Return ``item`` with every list (at any nesting depth reachable
    through lists) replaced by a tuple; non-list values are returned as is."""
    if not isinstance(item, list):
        return item
    return tuple(map(to_tuple, item))

878 

879 

def to_list(x):
    """Return ``x`` with every tuple (at any nesting depth reachable through
    tuples) replaced by a list; non-tuple values are returned as is."""
    if not isinstance(x, tuple):
        return x
    return list(map(to_list, x))

885 

886 

def find_common_derivatives(all_indices):
    """
    Find derivative indices common to all submodels.

    Parameters
    ----------
    all_indices : iterable of iterables
        One collection of derivative specifications per submodel. Each
        specification is converted with ``to_tuple`` so it can be hashed.

    Returns
    -------
    common : set
        Specifications (as tuples) present in every submodel. An empty
        ``all_indices`` yields an empty set instead of raising IndexError.
    """
    sets = [{to_tuple(elem) for elem in idx_list} for idx_list in all_indices]
    if not sets:
        # No submodels: nothing can be reported as shared.
        return set()
    return set.intersection(*sets)