Coverage for jetgp/full_degp/degp_utils.py: 67%
370 statements
« prev ^ index » next coverage.py v7.10.7, created at 2026-04-03 15:09 -0500
« prev ^ index » next coverage.py v7.10.7, created at 2026-04-03 15:09 -0500
1import numpy as np
2import pyoti.core as coti
3from line_profiler import profile
4import numba
7# =============================================================================
8# Numba-accelerated helper functions for efficient matrix slicing
9# =============================================================================
@numba.jit(nopython=True, cache=True)
def extract_rows(content_full, row_indices, n_cols):
    """
    Gather selected rows of a matrix into a new array.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix the rows are read from.
    row_indices : ndarray of int64
        Source row index for each output row, in output order.
    n_cols : int
        Number of leading columns copied from every selected row.

    Returns
    -------
    result : ndarray of shape (len(row_indices), n_cols)
        Newly allocated array (np.empty default dtype) with the copied rows.
    """
    result = np.empty((len(row_indices), n_cols))
    for out_row, src_row in enumerate(row_indices):
        for col in range(n_cols):
            result[out_row, col] = content_full[src_row, col]
    return result
@numba.jit(nopython=True, cache=True)
def extract_cols(content_full, col_indices, n_rows):
    """
    Gather selected columns of a matrix into a new array.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix the columns are read from.
    col_indices : ndarray of int64
        Source column index for each output column, in output order.
    n_rows : int
        Number of leading rows copied from every selected column.

    Returns
    -------
    result : ndarray of shape (n_rows, len(col_indices))
        Newly allocated array (np.empty default dtype) with the copied columns.
    """
    result = np.empty((n_rows, len(col_indices)))
    for row in range(n_rows):
        for out_col, src_col in enumerate(col_indices):
            result[row, out_col] = content_full[row, src_col]
    return result
@numba.jit(nopython=True, cache=True)
def extract_submatrix(content_full, row_indices, col_indices):
    """
    Gather the submatrix addressed by a set of rows and a set of columns.

    Explicit-loop counterpart of ``content_full[np.ix_(row_indices, col_indices)]``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix the entries are read from.
    row_indices : ndarray of int64
        Source row index for each output row.
    col_indices : ndarray of int64
        Source column index for each output column.

    Returns
    -------
    result : ndarray of shape (len(row_indices), len(col_indices))
        Newly allocated array (np.empty default dtype) holding the submatrix.
    """
    result = np.empty((len(row_indices), len(col_indices)))
    for out_row, src_row in enumerate(row_indices):
        for out_col, src_col in enumerate(col_indices):
            result[out_row, out_col] = content_full[src_row, src_col]
    return result
@numba.jit(nopython=True, cache=True, parallel=False)
def extract_and_assign(content_full, row_indices, col_indices, K,
                       row_start, col_start, sign):
    """
    Copy a signed submatrix of ``content_full`` straight into a block of ``K``.

    The gather and the write happen in the same pass, so no temporary
    submatrix is allocated.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix the entries are read from.
    row_indices : ndarray of int64
        Source row index for each row of the destination block.
    col_indices : ndarray of int64
        Source column index for each column of the destination block.
    K : ndarray
        Destination matrix, modified in place.
    row_start : int
        Row of ``K`` where the destination block begins.
    col_start : int
        Column of ``K`` where the destination block begins.
    sign : float
        Multiplier applied to every copied entry (+1.0 or -1.0).
    """
    for block_row, src_row in enumerate(row_indices):
        dest_row = row_start + block_row
        for block_col, src_col in enumerate(col_indices):
            K[dest_row, col_start + block_col] = content_full[src_row, src_col] * sign
@numba.jit(nopython=True, cache=True)
def extract_rows_and_assign(content_full, row_indices, K,
                            row_start, col_start, n_cols, sign):
    """
    Copy signed, selected rows of ``content_full`` straight into a block of ``K``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix the rows are read from.
    row_indices : ndarray of int64
        Source row index for each row of the destination block.
    K : ndarray
        Destination matrix, modified in place.
    row_start : int
        Row of ``K`` where the destination block begins.
    col_start : int
        Column of ``K`` where the destination block begins.
    n_cols : int
        Number of leading columns copied from every selected row.
    sign : float
        Multiplier applied to every copied entry (+1.0 or -1.0).
    """
    for block_row, src_row in enumerate(row_indices):
        dest_row = row_start + block_row
        for col in range(n_cols):
            K[dest_row, col_start + col] = content_full[src_row, col] * sign
@numba.jit(nopython=True, cache=True)
def extract_cols_and_assign(content_full, col_indices, K,
                            row_start, col_start, n_rows, sign):
    """
    Copy signed, selected columns of ``content_full`` straight into a block of ``K``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix the columns are read from.
    col_indices : ndarray of int64
        Source column index for each column of the destination block.
    K : ndarray
        Destination matrix, modified in place.
    row_start : int
        Row of ``K`` where the destination block begins.
    col_start : int
        Column of ``K`` where the destination block begins.
    n_rows : int
        Number of leading rows copied from every selected column.
    sign : float
        Multiplier applied to every copied entry (+1.0 or -1.0).
    """
    for row in range(n_rows):
        dest_row = row_start + row
        for block_col, src_col in enumerate(col_indices):
            K[dest_row, col_start + block_col] = content_full[row, src_col] * sign
187# =============================================================================
188# Difference computation functions
189# =============================================================================
def differences_by_dim_func(X1, X2, n_order, oti_module, return_deriv=True):
    """
    Compute per-dimension pairwise differences between X1 and X2, optionally
    perturbed by a hypercomplex unit for automatic differentiation.

    For each dimension k the returned matrix holds:
        diff_k[i, j] = X1[i, k] - X2[j, k]                 if n_order == 0
        diff_k[i, j] = X1[i, k] + e_{k+1} - X2[j, k]       otherwise
    where e_{k+1} = oti_module.e(k + 1, order=...) with order ``2 * n_order``
    when ``return_deriv`` is True and ``n_order`` when it is False.

    Parameters
    ----------
    X1 : array_like of shape (n1, d)
        First set of input points with n1 samples in d dimensions.
    X2 : array_like of shape (n2, d)
        Second set of input points with n2 samples in d dimensions.
    n_order : int
        Base derivative order. When 0 no hypercomplex unit is added.
    oti_module : module
        The PyOTI static module (e.g., pyoti.static.onumm4n2). It must
        provide ``array``, ``zeros``, ``transpose`` and ``e``.
    return_deriv : bool, optional
        If True (default) the perturbation uses order ``2 * n_order``;
        if False it uses order ``n_order``. Ignored when ``n_order == 0``.

    Returns
    -------
    differences_by_dim : list of length d
        One (n1, n2) matrix per input dimension.

    Raises
    ------
    ValueError
        If X1 and X2 do not have the same number of columns.
    """
    X1 = oti_module.array(X1)
    X2 = oti_module.array(X2)
    n1, d1 = X1.shape
    n2, d2 = X2.shape
    if d1 != d2:
        # Previously the column count of X2 silently replaced that of X1,
        # so extra X1 columns were ignored without any warning.
        raise ValueError(
            "X1 and X2 must have the same number of columns "
            f"(got {d1} and {d2})"
        )
    d = d1

    differences_by_dim = []

    if n_order == 0:
        # Plain differences: no hypercomplex perturbation.
        for k in range(d):
            x2_k = oti_module.transpose(X2[:, k])  # invariant over rows
            diffs_k = oti_module.zeros((n1, n2))
            for i in range(n1):
                diffs_k[i, :] = X1[i, k] - (x2_k)
            differences_by_dim.append(diffs_k)
    elif not return_deriv:
        # Perturbation of order n_order, added row by row.
        for k in range(d):
            x2_k = X2[:, k].T  # invariant over rows
            unit_k = oti_module.e(k + 1, order=n_order)
            diffs_k = oti_module.zeros((n1, n2))
            for i in range(n1):
                diffs_k[i, :] = (
                    X1[i, k]
                    + unit_k
                    - (x2_k)
                )
            differences_by_dim.append(diffs_k)
    else:
        # Perturbation of order 2 * n_order, added to the whole matrix at once.
        for k in range(d):
            x2_k = X2[:, k].T  # invariant over rows
            diffs_k = oti_module.zeros((n1, n2))
            for i in range(n1):
                diffs_k[i, :] = X1[i, k] - (x2_k)
            differences_by_dim.append(
                diffs_k + oti_module.e(k + 1, order=2 * n_order))

    return differences_by_dim
254# =============================================================================
255# Derivative mapping utilities
256# =============================================================================
def deriv_map(nbases, order):
    """
    Build the lookup table from (order, direction index) to a flat index.

    Flat indices are handed out consecutively, starting at 0, walking the
    orders 0..``order`` and, within each order, the
    ``coti.ndir_order(nbases, ordi)`` directions in increasing index.

    Parameters
    ----------
    nbases : int
        Number of base dimensions.
    order : int
        Maximum derivative order (inclusive).

    Returns
    -------
    map_deriv : list of lists
        ``map_deriv[ordi][idx]`` is the flat index of direction ``idx`` of
        order ``ordi``.
    """
    map_deriv = []
    next_flat = 0
    for ordi in range(order + 1):
        n_directions = coti.ndir_order(nbases, ordi)
        per_order = [0] * n_directions
        for pos in range(n_directions):
            per_order[pos] = next_flat + pos
        next_flat += len(per_order)
        map_deriv.append(per_order)
    return map_deriv
def transform_der_indices(der_indices, der_map):
    """
    Translate derivative specifications into flat indices.

    Each entry of ``der_indices`` is converted with ``coti.imdir`` into an
    (index, order) pair, which is then looked up in ``der_map``.

    Parameters
    ----------
    der_indices : list
        User-facing derivative specifications.
    der_map : list of lists
        Lookup table produced by deriv_map().

    Returns
    -------
    deriv_ind_transf : list
        Flat index of each derivative, in input order.
    deriv_ind_order : list
        The (index, order) pair returned by ``coti.imdir`` for each derivative.
    """
    deriv_ind_order = [coti.imdir(spec) for spec in der_indices]
    deriv_ind_transf = [der_map[ord_][idx] for idx, ord_ in deriv_ind_order]
    return deriv_ind_transf, deriv_ind_order
314# =============================================================================
315# RBF Kernel Assembly Functions (Optimized with Numba)
316# =============================================================================
@profile
def rbf_kernel(
    phi,
    phi_exp,
    n_order,
    n_bases,
    der_indices,
    powers,
    index=None
):
    """
    Compute the derivative-enhanced RBF kernel matrix.

    The matrix is laid out in blocks: the function-function block first,
    followed by one block-row and one block-column per derivative type.
    Block (i, j) is read from the slice of ``phi_exp`` that corresponds to the
    product of derivative directions i and j, and multiplied by the sign of
    column type j. Slicing uses the Numba helpers instead of ``np.ix_``.

    Parameters
    ----------
    phi : OTI array
        Base kernel matrix from kernel_func(differences, length_scales).
        Only its ``shape`` is used here.
    phi_exp : ndarray
        Expanded derivative array from phi.get_all_derivs(). Each
        ``phi_exp[k]`` must be reshapeable to ``phi.shape``.
    n_order : int
        Maximum derivative order considered.
    n_bases : int
        Number of input dimensions.
    der_indices : list of lists
        Multi-index derivative structures for each derivative component.
    powers : list of int
        Powers of (-1); ``powers[0]`` applies to the function column block
        and ``powers[j + 1]`` to the column block of derivative type j.
    index : list of lists or None, optional
        ``index[i]`` lists the point indices that carry derivative type i.
        If None or empty, every point carries every derivative type
        (uniform blocks).

    Returns
    -------
    K : ndarray
        Kernel matrix including function values and derivative terms.
        Shape is ``(n_rows + R, n_cols + C)`` where R and C are the total
        number of derivative rows and columns.
    """
    dh = coti.get_dHelp()

    n_rows_func, n_cols_func = phi.shape
    n_deriv_types = len(der_indices)
    base_shape = (n_rows_func, n_cols_func)

    der_map = deriv_map(n_bases, 2 * n_order)
    der_indices_tr, der_ind_order = transform_der_indices(der_indices, der_map)

    # Pre-compute signs (avoid repeated exponentiation)
    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)

    if index is None or len(index) == 0:
        # Uniform blocks: every row/column point has every derivative type.
        # The previous implementation only filled the (0, 0) block here and
        # left all derivative blocks at zero.
        all_rows = np.arange(n_rows_func, dtype=np.int64)
        all_cols = np.arange(n_cols_func, dtype=np.int64)
        row_index_arrays = [all_rows] * n_deriv_types
        col_index_arrays = [all_cols] * n_deriv_types
    else:
        # Non-contiguous indices: the same point subsets are used for rows
        # and columns.
        index_arrays = [np.asarray(idx, dtype=np.int64) for idx in index]
        row_index_arrays = index_arrays
        col_index_arrays = index_arrays

    total_rows = n_rows_func + sum(len(a) for a in row_index_arrays)
    total_cols = n_cols_func + sum(len(a) for a in col_index_arrays)
    K = np.zeros((total_rows, total_cols))

    # Block (0,0): Function-Function (K_ff)
    content_full = phi_exp[0].reshape(base_shape)
    K[:n_rows_func, :n_cols_func] = content_full * signs[0]

    # First Block-Column: Derivative-Function (K_df)
    row_offset = n_rows_func
    for i in range(n_deriv_types):
        content_full = phi_exp[der_indices_tr[i]].reshape(base_shape)
        row_indices = row_index_arrays[i]
        extract_rows_and_assign(content_full, row_indices, K,
                                row_offset, 0, n_cols_func, signs[0])
        row_offset += len(row_indices)

    # First Block-Row: Function-Derivative (K_fd)
    col_offset = n_cols_func
    for j in range(n_deriv_types):
        content_full = phi_exp[der_indices_tr[j]].reshape(base_shape)
        col_indices = col_index_arrays[j]
        extract_cols_and_assign(content_full, col_indices, K,
                                0, col_offset, n_rows_func, signs[j + 1])
        col_offset += len(col_indices)

    # Inner Blocks: Derivative-Derivative (K_dd)
    row_offset = n_rows_func
    for i in range(n_deriv_types):
        col_offset = n_cols_func
        row_indices = row_index_arrays[i]

        for j in range(n_deriv_types):
            col_indices = col_index_arrays[j]

            # Direction of the mixed derivative = product of both directions.
            imdir1 = der_ind_order[j]
            imdir2 = der_ind_order[i]
            new_idx, new_ord = dh.mult_dir(
                imdir1[0], imdir1[1], imdir2[0], imdir2[1])
            flat_idx = der_map[new_ord][new_idx]
            content_full = phi_exp[flat_idx].reshape(base_shape)

            extract_and_assign(content_full, row_indices, col_indices, K,
                               row_offset, col_offset, signs[j + 1])

            col_offset += len(col_indices)
        row_offset += len(row_indices)

    return K
@numba.jit(nopython=True, cache=True)
def _assemble_kernel_numba(phi_exp_3d, K, n_rows_func, n_cols_func,
                           fd_flat_indices, df_flat_indices, dd_flat_indices,
                           idx_flat, idx_offsets, idx_sizes,
                           signs, n_deriv_types, row_offsets, col_offsets):
    """
    Fill the whole kernel matrix ``K`` in place from ``phi_exp_3d``.

    One compiled call writes the function-function (ff), function-derivative
    (fd), derivative-function (df) and derivative-derivative (dd) blocks.
    The point indices of derivative type t are
    ``idx_flat[idx_offsets[t] : idx_offsets[t] + idx_sizes[t]]``;
    ``row_offsets[t]`` / ``col_offsets[t]`` give the first row / column of
    ``K`` belonging to that type. Column blocks of type t are scaled by
    ``signs[t + 1]``; the function column block by ``signs[0]``.
    """
    base_sign = signs[0]

    # ff block
    for row in range(n_rows_func):
        for col in range(n_cols_func):
            K[row, col] = phi_exp_3d[0, row, col] * base_sign

    # fd blocks (first block-row)
    for t in range(n_deriv_types):
        plane = fd_flat_indices[t]
        col_sign = signs[t + 1]
        k_col0 = col_offsets[t]
        start = idx_offsets[t]
        count = idx_sizes[t]
        for row in range(n_rows_func):
            for pos in range(count):
                src_col = idx_flat[start + pos]
                K[row, k_col0 + pos] = phi_exp_3d[plane, row, src_col] * col_sign

    # df blocks (first block-column)
    for t in range(n_deriv_types):
        plane = df_flat_indices[t]
        k_row0 = row_offsets[t]
        start = idx_offsets[t]
        count = idx_sizes[t]
        for pos in range(count):
            src_row = idx_flat[start + pos]
            for col in range(n_cols_func):
                K[k_row0 + pos, col] = phi_exp_3d[plane, src_row, col] * base_sign

    # dd blocks
    for tr in range(n_deriv_types):
        k_row0 = row_offsets[tr]
        r_start = idx_offsets[tr]
        r_count = idx_sizes[tr]
        for tc in range(n_deriv_types):
            plane = dd_flat_indices[tr, tc]
            col_sign = signs[tc + 1]
            k_col0 = col_offsets[tc]
            c_start = idx_offsets[tc]
            c_count = idx_sizes[tc]
            for rpos in range(r_count):
                src_row = idx_flat[r_start + rpos]
                for cpos in range(c_count):
                    src_col = idx_flat[c_start + cpos]
                    K[k_row0 + rpos, k_col0 + cpos] = (
                        phi_exp_3d[plane, src_row, src_col] * col_sign
                    )
@numba.jit(nopython=True, cache=True)
def _project_W_to_phi_space(W, W_proj, n_rows_func, n_cols_func,
                            fd_flat_indices, df_flat_indices, dd_flat_indices,
                            idx_flat, idx_offsets, idx_sizes,
                            signs, n_deriv_types, row_offsets, col_offsets):
    """
    Adjoint of _assemble_kernel_numba: scatter ``W`` (K-space) into
    ``W_proj`` (phi_exp-space) so that
    ``vdot(W, assemble(dphi_exp)) == vdot(W_proj, dphi_exp)``.

    ``W_proj`` is overwritten in place: it is zeroed first, then every entry
    of ``W`` is accumulated, with its block sign, into the ``phi_exp`` entry
    that the assembly routine would have read for that position. This lets
    gradient contributions be evaluated without building a full dK matrix
    per hyperparameter.
    """
    # Start from a clean accumulator.
    W_proj[:, :, :] = 0.0

    base_sign = signs[0]

    # ff block: K[r, c] = phi_exp[0, r, c] * s0
    for row in range(n_rows_func):
        for col in range(n_cols_func):
            W_proj[0, row, col] += base_sign * W[row, col]

    # fd blocks: K[r, co + k] = phi_exp[fi, r, idx[k]] * sj
    for t in range(n_deriv_types):
        plane = fd_flat_indices[t]
        col_sign = signs[t + 1]
        k_col0 = col_offsets[t]
        start = idx_offsets[t]
        count = idx_sizes[t]
        for row in range(n_rows_func):
            for pos in range(count):
                src_col = idx_flat[start + pos]
                W_proj[plane, row, src_col] += col_sign * W[row, k_col0 + pos]

    # df blocks: K[ro + k, c] = phi_exp[fi, idx[k], c] * s0
    for t in range(n_deriv_types):
        plane = df_flat_indices[t]
        k_row0 = row_offsets[t]
        start = idx_offsets[t]
        count = idx_sizes[t]
        for pos in range(count):
            src_row = idx_flat[start + pos]
            for col in range(n_cols_func):
                W_proj[plane, src_row, col] += base_sign * W[k_row0 + pos, col]

    # dd blocks: K[ro + ki, co + kj] = phi_exp[dd_fi[i, j], idx_i[ki], idx_j[kj]] * sj
    for tr in range(n_deriv_types):
        k_row0 = row_offsets[tr]
        r_start = idx_offsets[tr]
        r_count = idx_sizes[tr]
        for tc in range(n_deriv_types):
            plane = dd_flat_indices[tr, tc]
            col_sign = signs[tc + 1]
            k_col0 = col_offsets[tc]
            c_start = idx_offsets[tc]
            c_count = idx_sizes[tc]
            for rpos in range(r_count):
                src_row = idx_flat[r_start + rpos]
                for cpos in range(c_count):
                    src_col = idx_flat[c_start + cpos]
                    W_proj[plane, src_row, src_col] += (
                        col_sign * W[k_row0 + rpos, k_col0 + cpos]
                    )
def precompute_kernel_plan(n_order, n_bases, der_indices, powers, index):
    """
    Precompute the structural data needed to assemble the kernel matrix so it
    can be reused across repeated calls with different phi_exp values.

    Parameters
    ----------
    n_order : int
        Maximum derivative order; the derivative map is built up to
        ``2 * n_order``.
    n_bases : int
        Number of input dimensions.
    der_indices : list
        Derivative specifications, one per derivative type.
    powers : list of int
        Powers of (-1); needs at least ``len(der_indices) + 1`` entries
        (entry 0 for the function block, entry t + 1 for derivative type t).
    index : list of lists
        ``index[t]`` lists the point indices carrying derivative type t.
        Must contain exactly one entry per derivative type.

    Returns
    -------
    plan : dict
        Keys: 'der_indices_tr', 'signs', 'index_arrays', 'index_sizes',
        'n_pts_with_derivs', 'dd_flat_indices', 'n_deriv_types', 'idx_flat',
        'idx_offsets', 'row_offsets', 'col_offsets', 'fd_flat_indices',
        'df_flat_indices'. 'row_offsets' and 'col_offsets' are relative to
        the end of the function block; the function-block size is added at
        call time.

    Raises
    ------
    ValueError
        If ``index`` does not have one entry per derivative type, or
        ``powers`` has fewer than ``len(der_indices) + 1`` entries.
    """
    dh = coti.get_dHelp()
    der_map = deriv_map(n_bases, 2 * n_order)
    der_indices_tr, der_ind_order = transform_der_indices(der_indices, der_map)

    n_deriv_types = len(der_indices)
    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)
    index_arrays = [np.asarray(idx, dtype=np.int64) for idx in index]

    # The fused numba kernel indexes these arrays without bounds checks, so
    # inconsistent lengths must be rejected here rather than read past the end.
    if len(index_arrays) != n_deriv_types:
        raise ValueError(
            "index must contain one entry per derivative type "
            f"(got {len(index_arrays)} entries for {n_deriv_types} types)"
        )
    if signs.size < n_deriv_types + 1:
        raise ValueError(
            f"powers must contain at least {n_deriv_types + 1} entries "
            f"(got {signs.size})"
        )

    # Number of points carrying each derivative type
    index_sizes = np.array([len(idx) for idx in index_arrays], dtype=np.int64)
    n_pts_with_derivs = int(index_sizes.sum())

    # Pack all index arrays into a single flat array addressed via offsets
    idx_flat = (np.concatenate(index_arrays) if n_deriv_types > 0
                else np.array([], dtype=np.int64))

    # Exclusive cumulative sum of the sizes. The same values serve as the
    # offsets into idx_flat and as the row/column offsets of each derivative
    # block relative to the end of the function block.
    idx_offsets = np.zeros(n_deriv_types, dtype=np.int64)
    if n_deriv_types > 1:
        idx_offsets[1:] = np.cumsum(index_sizes[:-1])
    row_offsets = idx_offsets.copy()
    col_offsets = idx_offsets.copy()

    # Flat phi_exp index of every derivative-derivative block
    dd_flat_indices = np.empty((n_deriv_types, n_deriv_types), dtype=np.int64)
    for i in range(n_deriv_types):
        for j in range(n_deriv_types):
            imdir1 = der_ind_order[j]
            imdir2 = der_ind_order[i]
            new_idx, new_ord = dh.mult_dir(
                imdir1[0], imdir1[1], imdir2[0], imdir2[1])
            dd_flat_indices[i, j] = der_map[new_ord][new_idx]

    # fd and df flat indices as arrays
    fd_flat_indices = np.array(der_indices_tr, dtype=np.int64)
    df_flat_indices = np.array(der_indices_tr, dtype=np.int64)

    return {
        'der_indices_tr': der_indices_tr,
        'signs': signs,
        'index_arrays': index_arrays,
        'index_sizes': index_sizes,
        'n_pts_with_derivs': n_pts_with_derivs,
        'dd_flat_indices': dd_flat_indices,
        'n_deriv_types': n_deriv_types,
        # Fused kernel data
        'idx_flat': idx_flat,
        'idx_offsets': idx_offsets,
        'row_offsets': row_offsets,
        'col_offsets': col_offsets,
        'fd_flat_indices': fd_flat_indices,
        'df_flat_indices': df_flat_indices,
    }
def rbf_kernel_fast(phi_exp_3d, plan, out=None):
    """
    Fast kernel assembly using a precomputed plan and fused numba kernel.

    Parameters
    ----------
    phi_exp_3d : ndarray of shape (n_derivs, n_rows_func, n_cols_func)
        Pre-reshaped expanded derivative array.
    plan : dict
        Precomputed plan from precompute_kernel_plan(). If it contains
        'row_offsets_abs' and 'col_offsets_abs' those absolute offsets are
        used as-is; otherwise the relative offsets are shifted by the
        function-block size.
    out : ndarray, optional
        Pre-allocated 2-D output array with at least
        ``n_rows_func + plan['n_pts_with_derivs']`` rows and
        ``n_cols_func + plan['n_pts_with_derivs']`` columns. If None, a new
        array of exactly that shape is allocated. Reusing a buffer avoids
        repeated allocation of large matrices during optimization loops.

    Returns
    -------
    K : ndarray
        Full kernel matrix (``out`` itself when it was supplied).

    Raises
    ------
    ValueError
        If ``out`` is not 2-D or is too small to hold the kernel matrix.
    """
    n_rows_func = phi_exp_3d.shape[1]
    n_cols_func = phi_exp_3d.shape[2]
    n_deriv_pts = plan['n_pts_with_derivs']
    # Rows and columns are sized independently: the fused kernel writes up to
    # n_cols_func + n_deriv_pts columns, which exceeds a square allocation
    # based on the row count whenever n_cols_func > n_rows_func.
    total_rows = n_rows_func + n_deriv_pts
    total_cols = n_cols_func + n_deriv_pts

    if out is None:
        K = np.empty((total_rows, total_cols))
    else:
        # The nopython kernel does not bounds-check its writes, so an
        # undersized buffer must be rejected up front.
        if (out.ndim != 2 or out.shape[0] < total_rows
                or out.shape[1] < total_cols):
            raise ValueError(
                f"out must be a 2-D array of shape at least "
                f"({total_rows}, {total_cols}); got {out.shape}"
            )
        K = out

    # Use cached absolute offsets if available, otherwise compute them
    if 'row_offsets_abs' in plan:
        row_off = plan['row_offsets_abs']
        col_off = plan['col_offsets_abs']
    else:
        row_off = plan['row_offsets'] + n_rows_func
        col_off = plan['col_offsets'] + n_cols_func

    _assemble_kernel_numba(
        phi_exp_3d, K, n_rows_func, n_cols_func,
        plan['fd_flat_indices'], plan['df_flat_indices'], plan['dd_flat_indices'],
        plan['idx_flat'], plan['idx_offsets'], plan['index_sizes'],
        plan['signs'], plan['n_deriv_types'], row_off, col_off,
    )

    return K
def rbf_kernel_predictions(
    phi,
    phi_exp,
    n_order,
    n_bases,
    der_indices,
    powers,
    return_deriv,
    index=None,
    common_derivs=None,
    calc_cov=False,
    powers_predict=None
):
    """
    Build the RBF kernel matrix used for predictions, including derivative
    blocks. Slicing is delegated to the Numba helpers.

    Parameters
    ----------
    phi : OTI array
        Base kernel matrix between test and training points.
    phi_exp : ndarray
        Expanded derivative array from phi.get_all_derivs().
    n_order : int
        Maximum derivative order.
    n_bases : int
        Number of input dimensions.
    der_indices : list
        Derivative specifications for training data.
    powers : list of int
        Sign powers for each derivative type.
    return_deriv : bool
        If True, derivative column blocks are added for every column point.
    index : list of lists or None
        Training point indices for each derivative type.
    common_derivs : list
        Derivative specifications to predict.
    calc_cov : bool
        If True, derivative rows use every column-point index and the
        prediction derivative types instead of ``index`` / ``der_indices``.
        With ``return_deriv`` False, ``phi.real`` is returned directly.
    powers_predict : list of int, optional
        Sign powers for prediction derivatives; defaults to ``powers``.

    Returns
    -------
    K : ndarray
        Prediction kernel matrix.
    """
    # Covariance of function values only: nothing to assemble.
    if calc_cov and not return_deriv:
        return phi.real

    dh = coti.get_dHelp()

    n_rows_func, n_cols_func = phi.shape
    block_shape = (n_rows_func, n_cols_func)
    n_deriv_types = len(der_indices)
    n_deriv_types_pred = len(common_derivs) if common_derivs else 0

    # Signs for training-side and prediction-side derivative blocks
    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)
    signs_predict = signs
    if powers_predict is not None:
        signs_predict = np.array(
            [(-1.0) ** p for p in powers_predict], dtype=np.float64)

    def _count_indexed_points():
        # Total number of (point, derivative type) rows listed in `index`.
        return sum(len(order_indices) for order_indices in index) if index else 0

    # Derivative map, prediction column indices and derivative row count
    cov_rows = None
    if not return_deriv:
        der_map = deriv_map(n_bases, n_order)
        test_cols = np.array([], dtype=np.int64)
        n_extra_rows = _count_indexed_points()
    else:
        der_map = deriv_map(n_bases, 2 * n_order)
        test_cols = np.arange(n_cols_func, dtype=np.int64)
        if calc_cov:
            cov_rows = np.arange(n_cols_func, dtype=np.int64)
            n_deriv_types = n_deriv_types_pred
            n_extra_rows = n_deriv_types * n_cols_func
        else:
            n_extra_rows = _count_indexed_points()

    der_indices_tr, der_ind_order = transform_der_indices(der_indices, der_map)
    if common_derivs:
        der_indices_tr_pred, der_ind_order_pred = transform_der_indices(
            common_derivs, der_map)
    else:
        der_indices_tr_pred, der_ind_order_pred = [], []

    n_test_cols = len(test_cols)
    n_extra_cols = n_deriv_types_pred * n_test_cols

    K = np.zeros((n_rows_func + n_extra_rows, n_cols_func + n_extra_cols))

    # Convert index lists to numpy arrays for numba
    has_index_lists = (
        index is not None
        and len(index) > 0
        and isinstance(index[0], (list, np.ndarray))
    )
    if has_index_lists:
        index_arrays = [np.asarray(idx, dtype=np.int64) for idx in index]
    else:
        index_arrays = []

    # Block (0,0): Function-Function (K_ff)
    K[:n_rows_func, :n_cols_func] = phi_exp[0].reshape(block_shape) * signs[0]

    if not return_deriv:
        # First Block-Column: Derivative-Function (K_df)
        row_offset = n_rows_func
        if index_arrays:
            for i in range(n_deriv_types):
                rows = index_arrays[i]
                block = phi_exp[der_indices_tr[i]].reshape(block_shape)
                extract_rows_and_assign(block, rows, K,
                                        row_offset, 0, n_cols_func, signs[0])
                row_offset += len(rows)
        return K

    # --- return_deriv=True case ---

    # First Block-Row: Function-Derivative (K_fd)
    col_offset = n_cols_func
    for j in range(n_deriv_types_pred):
        block = phi_exp[der_indices_tr_pred[j]].reshape(block_shape)
        extract_cols_and_assign(block, test_cols, K,
                                0, col_offset, n_rows_func, signs_predict[j + 1])
        col_offset += n_test_cols

    # Derivative rows exist when computing a covariance or when usable
    # training index lists were supplied; otherwise no row block is written.
    n_row_types = n_deriv_types if (calc_cov or index_arrays) else 0

    def _row_spec(t):
        # (row indices, flat phi_exp index, (idx, order) pair) of row type t.
        if calc_cov:
            return cov_rows, der_indices_tr_pred[t], der_ind_order_pred[t]
        return index_arrays[t], der_indices_tr[t], der_ind_order[t]

    # First Block-Column: Derivative-Function (K_df)
    row_offset = n_rows_func
    for i in range(n_row_types):
        rows, flat_idx, _ = _row_spec(i)
        block = phi_exp[flat_idx].reshape(block_shape)
        extract_rows_and_assign(block, rows, K,
                                row_offset, 0, n_cols_func, signs[0])
        row_offset += len(rows)

    # Inner Blocks: Derivative-Derivative (K_dd)
    row_offset = n_rows_func
    for i in range(n_row_types):
        rows, _, row_imdir = _row_spec(i)

        col_offset = n_cols_func
        for j in range(n_deriv_types_pred):
            col_imdir = der_ind_order_pred[j]
            new_idx, new_ord = dh.mult_dir(
                col_imdir[0], col_imdir[1], row_imdir[0], row_imdir[1])
            block = phi_exp[der_map[new_ord][new_idx]].reshape(block_shape)

            extract_and_assign(block, rows, test_cols, K,
                               row_offset, col_offset, signs_predict[j + 1])
            col_offset += n_test_cols
        row_offset += len(rows)

    return K