Coverage for jetgp/full_ddegp/ddegp_utils.py: 68%
376 statements
« prev ^ index » next coverage.py v7.10.7, created at 2026-04-03 15:09 -0500
« prev ^ index » next coverage.py v7.10.7, created at 2026-04-03 15:09 -0500
1import numpy as np
2import numba
3import pyoti.core as coti
4from line_profiler import profile
7# =============================================================================
8# Numba-accelerated helper functions for efficient matrix slicing
9# =============================================================================
@numba.jit(nopython=True, cache=True)
def extract_rows(content_full, row_indices, n_cols):
    """
    Gather the rows of ``content_full`` named by ``row_indices`` into a new array.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix to read from.
    row_indices : ndarray of int64
        Source row for each output row, in output order.
    n_cols : int
        Number of leading columns copied from each source row.

    Returns
    -------
    result : ndarray of shape (len(row_indices), n_cols)
        Freshly allocated float64 array holding the gathered rows.
    """
    out = np.empty((len(row_indices), n_cols))
    for dst, src in enumerate(row_indices):
        source_row = content_full[src]
        for col in range(n_cols):
            out[dst, col] = source_row[col]
    return out
@numba.jit(nopython=True, cache=True)
def extract_cols(content_full, col_indices, n_rows):
    """
    Gather the columns of ``content_full`` named by ``col_indices`` into a new array.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix to read from.
    col_indices : ndarray of int64
        Source column for each output column, in output order.
    n_rows : int
        Number of leading rows copied.

    Returns
    -------
    result : ndarray of shape (n_rows, len(col_indices))
        Freshly allocated float64 array holding the gathered columns.
    """
    out = np.empty((n_rows, len(col_indices)))
    for row in range(n_rows):
        source_row = content_full[row]
        for dst, src in enumerate(col_indices):
            out[row, dst] = source_row[src]
    return out
@numba.jit(nopython=True, cache=True)
def extract_submatrix(content_full, row_indices, col_indices):
    """
    Gather ``content_full[row_indices][:, col_indices]`` into a new array.

    This is a loop-based replacement for ``content_full[np.ix_(rows, cols)]``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix to read from.
    row_indices : ndarray of int64
        Source rows, in output order.
    col_indices : ndarray of int64
        Source columns, in output order.

    Returns
    -------
    result : ndarray of shape (len(row_indices), len(col_indices))
        Freshly allocated float64 array holding the selected entries.
    """
    out = np.empty((len(row_indices), len(col_indices)))
    for dst_r, src_r in enumerate(row_indices):
        source_row = content_full[src_r]
        for dst_c, src_c in enumerate(col_indices):
            out[dst_r, dst_c] = source_row[src_c]
    return out
@numba.jit(nopython=True, cache=True, parallel=False)
def extract_and_assign(content_full, row_indices, col_indices, K,
                       row_start, col_start, sign):
    """
    Write ``sign * content_full[row_indices][:, col_indices]`` into ``K`` in place.

    The selected submatrix lands at ``K[row_start:, col_start:]``. Gathering
    and scattering happen in one pass, so no temporary is allocated.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Matrix to read from.
    row_indices : ndarray of int64
        Source rows; output row ``row_start + k`` takes ``row_indices[k]``.
    col_indices : ndarray of int64
        Source columns; output column ``col_start + k`` takes ``col_indices[k]``.
    K : ndarray
        Destination matrix, modified in place.
    row_start : int
        First destination row in ``K``.
    col_start : int
        First destination column in ``K``.
    sign : float
        Factor applied to every copied value (+1.0 or -1.0).
    """
    for dst_r, src_r in enumerate(row_indices):
        target = K[row_start + dst_r]
        source = content_full[src_r]
        for dst_c, src_c in enumerate(col_indices):
            target[col_start + dst_c] = source[src_c] * sign
@numba.jit(nopython=True, cache=True)
def extract_rows_and_assign(content_full, row_indices, K,
                            row_start, col_start, n_cols, sign):
    """
    Copy selected rows of ``content_full``, scaled by ``sign``, into ``K`` in place.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Matrix to read from.
    row_indices : ndarray of int64
        Source rows; destination row ``row_start + k`` takes ``row_indices[k]``.
    K : ndarray
        Destination matrix, modified in place.
    row_start : int
        First destination row in ``K``.
    col_start : int
        First destination column in ``K``.
    n_cols : int
        Number of leading source columns copied per row.
    sign : float
        Factor applied to every copied value (+1.0 or -1.0).
    """
    for offset, src in enumerate(row_indices):
        target = K[row_start + offset]
        source = content_full[src]
        for col in range(n_cols):
            target[col_start + col] = source[col] * sign
@numba.jit(nopython=True, cache=True)
def extract_cols_and_assign(content_full, col_indices, K,
                            row_start, col_start, n_rows, sign):
    """
    Copy selected columns of ``content_full``, scaled by ``sign``, into ``K`` in place.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Matrix to read from.
    col_indices : ndarray of int64
        Source columns; destination column ``col_start + k`` takes
        ``col_indices[k]``.
    K : ndarray
        Destination matrix, modified in place.
    row_start : int
        First destination row in ``K``.
    col_start : int
        First destination column in ``K``.
    n_rows : int
        Number of leading source rows copied.
    sign : float
        Factor applied to every copied value (+1.0 or -1.0).
    """
    for row in range(n_rows):
        target = K[row_start + row]
        source = content_full[row]
        for dst, src in enumerate(col_indices):
            target[col_start + dst] = source[src] * sign
187# =============================================================================
188# Difference computation functions
189# =============================================================================
def differences_by_dim_func(X1, X2, rays, n_order, oti_module, return_deriv=True, index=-1):
    """
    Build per-dimension pairwise differences between two point sets.

    Each difference is an OTI matrix tagged with hypercomplex perturbations
    along ``rays``.

    Parameters
    ----------
    X1 : array-like of shape (n1, d)
        First point set; converted with ``oti_module.array``.
    X2 : array-like of shape (n2, d)
        Second point set; converted with ``oti_module.array``.
    rays : ndarray of shape (d, n_rays)
        Direction vectors. Column ``r`` is paired with the imaginary unit
        ``oti_module.e(r + 1, ...)``.
    n_order : int
        Base derivative order. ``0`` disables the perturbation entirely.
    oti_module : module
        PyOTI static module providing ``array``, ``zeros``, ``transpose``
        and ``e``.
    return_deriv : bool, optional (default=True)
        Selects the order of the imaginary units. When True the order is
        ``2 * n_order``, which leaves room for derivative-derivative
        products in the training kernel. When False the order is ``n_order``.
    index : int, optional
        Unused; kept for signature compatibility.

    Returns
    -------
    list of length d
        Entry ``k`` is an (n1, n2) OTI matrix whose ``(i, j)`` element is
        ``X1[i, k] - X2[j, k]``. When ``n_order > 0`` it additionally
        includes ``perts[k] = sum_r rays[k, r] * e(r + 1)``.

    Notes
    -----
    ``rays.shape[1]`` is read before the ``n_order == 0`` shortcut, so
    ``rays`` must be a 2-D array even when no perturbation is applied.
    In the ``return_deriv=False`` branch the perturbation is added to
    ``X1`` before subtracting. In the ``return_deriv=True`` branch it is
    added to the finished difference matrix. Both give the same values.
    """
    A = oti_module.array(X1)
    B = oti_module.array(X2)
    n1, d = A.shape
    n2, _ = B.shape
    ray_count = rays.shape[1]

    if n_order == 0:
        plain = []
        for dim in range(d):
            block = oti_module.zeros((n1, n2))
            for row in range(n1):
                block[row, :] = A[row, dim] - oti_module.transpose(B[:, dim])
            plain.append(block)
        return plain

    unit_order = 2 * n_order if return_deriv else n_order
    units = [oti_module.e(r + 1, order=unit_order) for r in range(ray_count)]
    perts = np.dot(rays, units)

    tagged = []
    for dim in range(d):
        if return_deriv:
            col_b = B[:, dim]
            block = oti_module.zeros((n1, n2))
            for row in range(n1):
                block[row, :] = A[row, dim] - col_b[:, 0].T
            # Apply the perturbation to the whole difference matrix at once.
            tagged.append(block + perts[dim])
        else:
            # Tag the X1 column first, then subtract row by row.
            col_a = A[:, dim] + perts[dim]
            col_b = B[:, dim]
            block = oti_module.zeros((n1, n2))
            for row in range(n1):
                block[row, :] = col_a[row, 0] - col_b[:, 0].T
            tagged.append(block)
    return tagged
309# =============================================================================
310# Derivative mapping utilities
311# =============================================================================
def deriv_map(nbases, order):
    """
    Assign consecutive flat indices to every derivative direction, grouped by order.

    Orders are laid out in ascending order starting at 0. Within each order
    the directions keep the order given by ``coti.ndir_order``'s count.
    Flat index ``k`` is therefore the position of that derivative when all
    orders ``0..order`` are concatenated.

    Parameters
    ----------
    nbases : int
        Number of base dimensions.
    order : int
        Highest derivative order to include.

    Returns
    -------
    map_deriv : list of lists
        ``map_deriv[ordi][idx]`` is the flat index of direction ``idx``
        of order ``ordi``.
    """
    map_deriv = []
    start = 0
    for ordi in range(order + 1):
        count = coti.ndir_order(nbases, ordi)
        map_deriv.append(list(range(start, start + count)))
        start += count
    return map_deriv
def transform_der_indices(der_indices, der_map):
    """
    Translate user-facing derivative specs into internal indices.

    Parameters
    ----------
    der_indices : list
        Derivative specifications, each accepted by ``coti.imdir``.
    der_map : list of lists
        Map produced by ``deriv_map``.

    Returns
    -------
    deriv_ind_transf : list
        Flat index into the expanded derivative array for each spec.
    deriv_ind_order : list
        The ``(index, order)`` pair returned by ``coti.imdir`` for each spec.
    """
    flat, pairs = [], []
    for spec in der_indices:
        pair = coti.imdir(spec)
        dir_idx, dir_ord = pair
        flat.append(der_map[dir_ord][dir_idx])
        pairs.append(pair)
    return flat, pairs
371# =============================================================================
372# RBF Kernel Assembly Functions (Optimized with Numba)
373# =============================================================================
@profile
def rbf_kernel(
    phi,
    phi_exp,
    n_order,
    n_bases,
    der_indices,
    powers,
    index=-1
):
    """
    Assemble the full DD-GP training covariance matrix block by block.

    Derivative blocks may cover different, non-contiguous subsets of points
    (``index[i]`` selects the points carrying derivative type ``i``).
    Submatrix copies go through the numba helpers instead of ``np.ix_``.

    Layout (``n = phi.shape``; ``P = sum(len(index[i]))``)::

        [ K_ff (n_rows x n_cols) | K_fd (n_rows x P) ]
        [ K_df (P x n_cols)      | K_dd (P x P)      ]

    Sign convention: ``K_ff`` and every ``K_df`` row block use ``signs[0]``.
    ``K_fd`` column block ``j`` and the ``K_dd`` blocks in column ``j`` use
    ``signs[j + 1]``, where ``signs = (-1) ** powers``.

    Parameters
    ----------
    phi : OTI array
        Base kernel matrix; only its ``shape`` is used here.
    phi_exp : ndarray
        Expanded derivative array from ``phi.get_all_derivs()``. Entry
        ``phi_exp[k]`` is reshaped to ``phi.shape``.
    n_order : int
        Maximum derivative order of a single observation. The derivative
        map spans ``2 * n_order`` because ``K_dd`` needs products of two
        derivatives.
    n_bases : int
        Number of input dimensions (rays).
    der_indices : list of lists
        Derivative specifications, one per derivative type.
    powers : list of int
        Exponents of ``-1``. Needs at least ``len(der_indices) + 1`` entries
        when derivative blocks are built.
    index : list of array-like or int, optional (default=-1)
        ``index[i]`` lists the point indices that carry derivative type
        ``der_indices[i]``. If ``index`` is not a non-empty list (e.g. the
        default ``-1`` or ``[]``), no derivative blocks are added and only
        ``K_ff`` is returned.

    Returns
    -------
    K : ndarray of shape (n_rows + P, n_cols + P)
        Full kernel matrix with function values and derivative blocks.

    Raises
    ------
    ValueError
        If ``index`` is a non-empty list whose length differs from
        ``len(der_indices)``.
    """
    dh = coti.get_dHelp()

    # Map derivative specifications to flat indices into phi_exp.
    der_map = deriv_map(n_bases, 2 * n_order)
    der_indices_tr, der_ind_order = transform_der_indices(der_indices, der_map)

    n_rows_func, n_cols_func = phi.shape
    base_shape = (n_rows_func, n_cols_func)
    n_deriv_types = len(der_indices)

    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)

    if isinstance(index, list) and len(index) > 0:
        index_arrays = [np.asarray(idx, dtype=np.int64) for idx in index]
    else:
        index_arrays = []

    # One index list per derivative type. A mismatch would otherwise leave
    # all-zero rows/columns in K or fail with an opaque IndexError.
    if index_arrays and len(index_arrays) != n_deriv_types:
        raise ValueError(
            f"index has {len(index_arrays)} entries but der_indices has "
            f"{n_deriv_types}; they must match one-to-one."
        )

    n_pts_with_derivs = sum(len(idx) for idx in index_arrays)
    K = np.zeros((n_rows_func + n_pts_with_derivs,
                  n_cols_func + n_pts_with_derivs))

    # Block (0,0): function-function.
    K[:n_rows_func, :n_cols_func] = phi_exp[0].reshape(base_shape) * signs[0]

    if not index_arrays:
        return K

    # Start of derivative block i, relative to the end of the function
    # block. Rows and columns share the same layout.
    block_starts = []
    running = 0
    for idx in index_arrays:
        block_starts.append(running)
        running += len(idx)

    # First block column (K_df) and first block row (K_fd). Both read the
    # same phi_exp slab for derivative type i and write disjoint regions.
    for i in range(n_deriv_types):
        pts = index_arrays[i]
        content_full = phi_exp[der_indices_tr[i]].reshape(base_shape)
        extract_rows_and_assign(content_full, pts, K,
                                n_rows_func + block_starts[i], 0,
                                n_cols_func, signs[0])
        extract_cols_and_assign(content_full, pts, K,
                                0, n_cols_func + block_starts[i],
                                n_rows_func, signs[i + 1])

    # Inner blocks (K_dd): the slab is the product of column direction j
    # and row direction i, in the argument order mult_dir(j, i).
    for i in range(n_deriv_types):
        row_dir, row_ord = der_ind_order[i]
        row_pts = index_arrays[i]
        for j in range(n_deriv_types):
            col_dir, col_ord = der_ind_order[j]
            new_idx, new_ord = dh.mult_dir(col_dir, col_ord, row_dir, row_ord)
            content_full = phi_exp[der_map[new_ord][new_idx]].reshape(base_shape)
            extract_and_assign(content_full, row_pts, index_arrays[j], K,
                               n_rows_func + block_starts[i],
                               n_cols_func + block_starts[j],
                               signs[j + 1])

    return K
@numba.jit(nopython=True, cache=True)
def _assemble_kernel_numba(phi_exp_3d, K, n_rows_func, n_cols_func,
                           fd_flat_indices, df_flat_indices, dd_flat_indices,
                           idx_flat, idx_offsets, idx_sizes,
                           signs, n_deriv_types, row_offsets, col_offsets):
    """
    Fused numba kernel that assembles the entire K matrix in a single call.

    It fills the ff, fd, df and dd blocks in that order, writing into ``K``
    in place. The same layout and signs are used by ``rbf_kernel``.

    Parameters
    ----------
    phi_exp_3d : ndarray, indexed as [flat_deriv, row, col]
        Expanded derivative slabs.
    K : ndarray
        Output matrix, overwritten in the regions listed below.
    n_rows_func, n_cols_func : int
        Size of the function-function block ``K[:n_rows_func, :n_cols_func]``.
    fd_flat_indices, df_flat_indices : ndarray of int64, length n_deriv_types
        Slab used for first-block-row block ``j`` / first-block-column block ``i``.
    dd_flat_indices : ndarray of int64, shape (n_deriv_types, n_deriv_types)
        Slab used for inner block ``(i, j)``.
    idx_flat, idx_offsets, idx_sizes : ndarray of int64
        Point indices of derivative type ``t`` are
        ``idx_flat[idx_offsets[t] : idx_offsets[t] + idx_sizes[t]]``.
    signs : ndarray of float64
        ``signs[0]`` scales ff and df. ``signs[j + 1]`` scales fd and dd
        blocks in column ``j``.
    n_deriv_types : int
        Number of derivative types.
    row_offsets, col_offsets : ndarray of int64
        Row/column in ``K`` where derivative block ``i``/``j`` starts.

    Notes
    -----
    No index is validated here. Nopython-mode indexing is not
    bounds-checked by default, so callers must supply consistent sizes
    and ``K`` large enough for every offset.
    """
    # Block (0,0): Function-Function
    s0 = signs[0]
    for r in range(n_rows_func):
        for c in range(n_cols_func):
            K[r, c] = phi_exp_3d[0, r, c] * s0

    # First Block-Row: Function-Derivative (fd)
    # Every function row against the points of derivative type j.
    for j in range(n_deriv_types):
        fi = fd_flat_indices[j]
        sj = signs[j + 1]
        co = col_offsets[j]
        off_j = idx_offsets[j]
        sz_j = idx_sizes[j]
        for r in range(n_rows_func):
            for k in range(sz_j):
                ci = idx_flat[off_j + k]
                K[r, co + k] = phi_exp_3d[fi, r, ci] * sj

    # First Block-Column: Derivative-Function (df)
    # Rows are the points of derivative type i; the sign is always signs[0].
    for i in range(n_deriv_types):
        fi = df_flat_indices[i]
        ro = row_offsets[i]
        off_i = idx_offsets[i]
        sz_i = idx_sizes[i]
        for k in range(sz_i):
            ri = idx_flat[off_i + k]
            for c in range(n_cols_func):
                K[ro + k, c] = phi_exp_3d[fi, ri, c] * s0

    # Inner Blocks: Derivative-Derivative (dd)
    # Points of type i (rows) against points of type j (columns), using the
    # column sign signs[j + 1].
    for i in range(n_deriv_types):
        ro = row_offsets[i]
        off_i = idx_offsets[i]
        sz_i = idx_sizes[i]
        for j in range(n_deriv_types):
            fi = dd_flat_indices[i, j]
            sj = signs[j + 1]
            co = col_offsets[j]
            off_j = idx_offsets[j]
            sz_j = idx_sizes[j]
            for ki in range(sz_i):
                ri = idx_flat[off_i + ki]
                for kj in range(sz_j):
                    ci = idx_flat[off_j + kj]
                    K[ro + ki, co + kj] = phi_exp_3d[fi, ri, ci] * sj
@numba.jit(nopython=True, cache=True)
def _project_W_to_phi_space(W, W_proj, n_rows_func, n_cols_func,
                            fd_flat_indices, df_flat_indices, dd_flat_indices,
                            idx_flat, idx_offsets, idx_sizes,
                            signs, n_deriv_types, row_offsets, col_offsets):
    """
    Reverse (adjoint) of _assemble_kernel_numba.

    Projects W from K-space back into phi_exp-space so that
    vdot(W, assemble(dphi_exp)) == vdot(W_proj, dphi_exp).

    Takes the same structural arguments as ``_assemble_kernel_numba``.
    Each loop below mirrors one assembly loop. Instead of ``K[...] =
    phi[...] * s`` it performs ``W_proj[...] += s * W[...]``, so a slab
    entry referenced by several K entries (e.g. a point index repeated in
    ``idx_flat``, or a slab shared between blocks) accumulates all their
    contributions.

    Parameters
    ----------
    W : ndarray, indexed like K
        Weights in K-space (read only).
    W_proj : ndarray, indexed as [flat_deriv, row, col]
        Output, zeroed and then accumulated in place.
    """
    # Reset the whole output, including slabs no block touches.
    for d in range(W_proj.shape[0]):
        for r in range(W_proj.shape[1]):
            for c in range(W_proj.shape[2]):
                W_proj[d, r, c] = 0.0
    # Adjoint of the function-function block.
    s0 = signs[0]
    for r in range(n_rows_func):
        for c in range(n_cols_func):
            W_proj[0, r, c] += s0 * W[r, c]
    # Adjoint of the first block row (fd).
    for j in range(n_deriv_types):
        fi = fd_flat_indices[j]
        sj = signs[j + 1]
        co = col_offsets[j]
        off_j = idx_offsets[j]
        sz_j = idx_sizes[j]
        for r in range(n_rows_func):
            for k in range(sz_j):
                ci = idx_flat[off_j + k]
                W_proj[fi, r, ci] += sj * W[r, co + k]
    # Adjoint of the first block column (df).
    for i in range(n_deriv_types):
        fi = df_flat_indices[i]
        ro = row_offsets[i]
        off_i = idx_offsets[i]
        sz_i = idx_sizes[i]
        for k in range(sz_i):
            ri = idx_flat[off_i + k]
            for c in range(n_cols_func):
                W_proj[fi, ri, c] += s0 * W[ro + k, c]
    # Adjoint of the inner derivative-derivative blocks (dd).
    for i in range(n_deriv_types):
        ro = row_offsets[i]
        off_i = idx_offsets[i]
        sz_i = idx_sizes[i]
        for j in range(n_deriv_types):
            fi = dd_flat_indices[i, j]
            sj = signs[j + 1]
            co = col_offsets[j]
            off_j = idx_offsets[j]
            sz_j = idx_sizes[j]
            for ki in range(sz_i):
                ri = idx_flat[off_i + ki]
                for kj in range(sz_j):
                    ci = idx_flat[off_j + kj]
                    W_proj[fi, ri, ci] += sj * W[ro + ki, co + kj]
def precompute_kernel_plan(n_order, n_bases, der_indices, powers, index):
    """
    Precompute the structural data ``rbf_kernel_fast`` needs.

    The result depends only on the derivative structure, so it can be
    reused across calls with different ``phi_exp`` values.

    Parameters
    ----------
    n_order : int
        Maximum derivative order of a single observation. The map spans
        ``2 * n_order`` for the dd products.
    n_bases : int
        Number of input dimensions (rays).
    der_indices : list
        Derivative specifications, one per derivative type.
    powers : list of int
        Exponents of ``-1``. Needs at least ``n_deriv_types + 1`` entries.
    index : sequence of array-like, or int / None
        ``index[i]`` lists the points carrying derivative type ``i``. If
        ``index`` is ``None``, a scalar (such as ``rbf_kernel``'s default
        ``-1``) or empty, the plan describes a function-only kernel
        (``n_deriv_types == 0``). This matches ``rbf_kernel``.

    Returns
    -------
    dict
        Flat indices, signs, index arrays, packed ``idx_flat`` with
        per-type offsets/sizes, relative row/col offsets, and the
        ``mult_dir`` results for the dd blocks.

    Raises
    ------
    ValueError
        If ``index`` is non-empty and its length differs from
        ``len(der_indices)``, or if ``powers`` is too short. The fused numba
        kernel does not bounds-check, so these mismatches would otherwise
        produce uninitialised or garbage entries in K.
    """
    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)

    if index is None or np.isscalar(index) or len(index) == 0:
        index_arrays = []
    else:
        index_arrays = [np.asarray(idx, dtype=np.int64) for idx in index]

    if index_arrays:
        n_deriv_types = len(der_indices)
        if len(index_arrays) != n_deriv_types:
            raise ValueError(
                f"index has {len(index_arrays)} entries but der_indices has "
                f"{n_deriv_types}; they must match one-to-one."
            )
    else:
        # No derivative points: only the function-function block is built.
        n_deriv_types = 0

    if len(signs) < n_deriv_types + 1:
        raise ValueError(
            f"powers has {len(signs)} entries; at least {n_deriv_types + 1} "
            f"are required."
        )

    dh = coti.get_dHelp()
    der_map = deriv_map(n_bases, 2 * n_order)
    der_indices_tr, der_ind_order = transform_der_indices(der_indices, der_map)

    index_sizes = np.array([len(idx) for idx in index_arrays], dtype=np.int64)
    n_pts_with_derivs = int(index_sizes.sum())

    # Pack all index arrays into one flat array. Type t occupies
    # idx_flat[idx_offsets[t] : idx_offsets[t] + index_sizes[t]].
    if index_arrays:
        idx_flat = np.concatenate(index_arrays)
    else:
        idx_flat = np.array([], dtype=np.int64)
    idx_offsets = np.zeros(n_deriv_types, dtype=np.int64)
    if n_deriv_types > 0:
        idx_offsets[1:] = np.cumsum(index_sizes)[:-1]

    # Derivative blocks are laid out in K in the same order as in idx_flat,
    # so the K offsets equal the packing offsets. They are relative to
    # n_rows_func / n_cols_func, which are added at call time.
    row_offsets = idx_offsets.copy()
    col_offsets = idx_offsets.copy()

    # mult_dir results for the dd blocks: column direction j times row
    # direction i.
    dd_flat_indices = np.empty((n_deriv_types, n_deriv_types), dtype=np.int64)
    for i in range(n_deriv_types):
        for j in range(n_deriv_types):
            imdir1 = der_ind_order[j]
            imdir2 = der_ind_order[i]
            new_idx, new_ord = dh.mult_dir(
                imdir1[0], imdir1[1], imdir2[0], imdir2[1])
            dd_flat_indices[i, j] = der_map[new_ord][new_idx]

    # The fd and df blocks read the same slab per derivative type.
    fd_flat_indices = np.array(der_indices_tr[:n_deriv_types], dtype=np.int64)
    df_flat_indices = fd_flat_indices.copy()

    return {
        'der_indices_tr': der_indices_tr,
        'signs': signs,
        'index_arrays': index_arrays,
        'index_sizes': index_sizes,
        'n_pts_with_derivs': n_pts_with_derivs,
        'dd_flat_indices': dd_flat_indices,
        'n_deriv_types': n_deriv_types,
        # Fused kernel data
        'idx_flat': idx_flat,
        'idx_offsets': idx_offsets,
        'row_offsets': row_offsets,
        'col_offsets': col_offsets,
        'fd_flat_indices': fd_flat_indices,
        'df_flat_indices': df_flat_indices,
    }
def rbf_kernel_fast(phi_exp_3d, plan, out=None):
    """
    Fast kernel assembly using a precomputed plan and a fused numba kernel.

    Parameters
    ----------
    phi_exp_3d : ndarray of shape (n_derivs, n_rows_func, n_cols_func)
        Pre-reshaped expanded derivative array.
    plan : dict
        Precomputed plan from ``precompute_kernel_plan()``. If it contains
        ``'row_offsets_abs'`` / ``'col_offsets_abs'``, those absolute
        offsets are used directly. Otherwise the relative offsets are
        shifted by the function-block size.
    out : ndarray, optional
        Pre-allocated output of shape
        ``(n_rows_func + P, n_cols_func + P)`` with
        ``P = plan['n_pts_with_derivs']``. If None, a new array is allocated.

    Returns
    -------
    K : ndarray of shape (n_rows_func + P, n_cols_func + P)
        Full kernel matrix (``out`` itself when provided).

    Raises
    ------
    ValueError
        If ``out`` has the wrong shape. The numba kernel does not
        bounds-check, so a mismatched buffer would otherwise be written out
        of bounds or left partly uninitialised.
    """
    n_rows_func = phi_exp_3d.shape[1]
    n_cols_func = phi_exp_3d.shape[2]
    n_pts = int(plan['n_pts_with_derivs'])
    # Rows and columns are sized independently so a non-square function
    # block cannot overflow the allocation.
    shape = (n_rows_func + n_pts, n_cols_func + n_pts)

    if out is None:
        K = np.empty(shape)
    elif out.shape != shape:
        raise ValueError(f"out has shape {out.shape}, expected {shape}")
    else:
        K = out

    if 'row_offsets_abs' in plan:
        row_off = plan['row_offsets_abs']
        col_off = plan['col_offsets_abs']
    else:
        row_off = plan['row_offsets'] + n_rows_func
        col_off = plan['col_offsets'] + n_cols_func

    _assemble_kernel_numba(
        phi_exp_3d, K, n_rows_func, n_cols_func,
        plan['fd_flat_indices'], plan['df_flat_indices'], plan['dd_flat_indices'],
        plan['idx_flat'], plan['idx_offsets'], plan['index_sizes'],
        plan['signs'], plan['n_deriv_types'], row_off, col_off,
    )

    return K
def rbf_kernel_predictions(
    phi,
    phi_exp,
    n_order,
    n_bases,
    der_indices,
    powers,
    return_deriv,
    index=-1,
    common_derivs=None,
    calc_cov=False,
    powers_predict=None
):
    """
    Assemble the kernel matrix used at prediction time.

    The matrix is built from slabs of ``phi_exp`` (``phi_exp[k]`` reshaped
    to ``phi.shape``), using the numba helpers for submatrix copies:

    - ``K[:n_rows, :n_cols]``: function-function block ``phi_exp[0] * signs[0]``.
    - Extra row blocks, one per derivative type, each using ``signs[0]``.
      Normally block ``i`` holds the rows of ``phi`` listed in ``index[i]``
      for derivative ``der_indices[i]``. With ``calc_cov=True``, each entry
      of ``common_derivs`` gets a block covering every row of ``phi``.
    - Extra column blocks, only when ``return_deriv`` is True. There is one
      per entry of ``common_derivs``, each covering every column of ``phi``
      and using ``signs_predict[j + 1]``.

    Because ``index`` selects rows, the rows of ``phi`` are the points that
    ``index`` refers to and the columns are the points where derivatives
    are predicted.
    NOTE(review): presumably rows = training points and columns = test
    points; confirm against the caller that builds ``phi``.

    Parameters
    ----------
    phi : OTI array
        Base kernel matrix; its ``shape`` fixes the function block. Its
        ``real`` part is returned directly when ``calc_cov`` is True and
        ``return_deriv`` is False.
    phi_exp : ndarray
        Expanded derivative array from ``phi.get_all_derivs()``.
    n_order : int
        Maximum derivative order.
    n_bases : int
        Number of input dimensions (rays).
    der_indices : list
        Derivative specifications attached to the rows via ``index``.
    powers : list of int
        Sign powers for the row derivative types.
    return_deriv : bool
        If True, add derivative column blocks for every column of ``phi``.
    index : list of lists or int, optional (default=-1)
        Row indices for each derivative type. Ignored unless it is a
        non-empty list whose first element is a list or ndarray.
    common_derivs : list, optional
        Derivatives to predict (column blocks, and row blocks when ``calc_cov``).
    calc_cov : bool, optional (default=False)
        If True, build the covariance layout (every row gets each
        predicted derivative).
    powers_predict : list of int, optional
        Sign powers for predicted derivatives. Defaults to ``powers``.

    Returns
    -------
    K : ndarray
        Prediction kernel matrix, or ``phi.real`` when ``calc_cov`` is True
        and ``return_deriv`` is False.
    """
    if calc_cov and not return_deriv:
        return phi.real

    dh = coti.get_dHelp()

    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)
    if powers_predict is not None:
        signs_predict = np.array([(-1.0) ** p for p in powers_predict], dtype=np.float64)
    else:
        signs_predict = signs

    n_rows_func, n_cols_func = phi.shape
    base_shape = (n_rows_func, n_cols_func)
    n_deriv_types = len(der_indices)
    n_deriv_types_pred = len(common_derivs) if common_derivs else 0

    if (isinstance(index, list) and len(index) > 0
            and isinstance(index[0], (list, np.ndarray))):
        index_arrays = [np.asarray(idx, dtype=np.int64) for idx in index]
    else:
        index_arrays = []
    n_pts_index = sum(len(idx) for idx in index_arrays)

    if return_deriv:
        der_map = deriv_map(n_bases, 2 * n_order)
        index_2 = np.arange(n_cols_func, dtype=np.int64)
        if calc_cov:
            # Every row of phi gets each predicted derivative. The indices
            # range over the ROWS of phi so a non-square phi is never read
            # out of bounds.
            index_cov = np.arange(n_rows_func, dtype=np.int64)
            n_deriv_types = n_deriv_types_pred
            n_pts_with_derivs_rows = n_deriv_types * n_rows_func
        else:
            n_pts_with_derivs_rows = n_pts_index
    else:
        der_map = deriv_map(n_bases, n_order)
        index_2 = np.array([], dtype=np.int64)
        n_pts_with_derivs_rows = n_pts_index

    der_indices_tr, der_ind_order = transform_der_indices(der_indices, der_map)
    if common_derivs:
        der_indices_tr_pred, der_ind_order_pred = transform_der_indices(common_derivs, der_map)
    else:
        der_indices_tr_pred, der_ind_order_pred = [], []

    # One (row indices, flat slab index, (dir, order)) triple per row block.
    # calc_cov is only reachable here together with return_deriv.
    if calc_cov:
        row_specs = [(index_cov, der_indices_tr_pred[i], der_ind_order_pred[i])
                     for i in range(n_deriv_types)]
    elif index_arrays:
        row_specs = [(index_arrays[i], der_indices_tr[i], der_ind_order[i])
                     for i in range(n_deriv_types)]
    else:
        row_specs = []

    n_pts_with_derivs_cols = n_deriv_types_pred * len(index_2)
    K = np.zeros((n_rows_func + n_pts_with_derivs_rows,
                  n_cols_func + n_pts_with_derivs_cols))

    # Block (0,0): function-function.
    K[:n_rows_func, :n_cols_func] = phi_exp[0].reshape(base_shape) * signs[0]

    # First block column (K_df): derivative rows against function columns.
    row_offset = n_rows_func
    for row_indices, flat_idx, _ in row_specs:
        content_full = phi_exp[flat_idx].reshape(base_shape)
        extract_rows_and_assign(content_full, row_indices, K,
                                row_offset, 0, n_cols_func, signs[0])
        row_offset += len(row_indices)

    if not return_deriv:
        return K

    # First block row (K_fd): function rows against predicted derivative columns.
    col_offset = n_cols_func
    for j in range(n_deriv_types_pred):
        content_full = phi_exp[der_indices_tr_pred[j]].reshape(base_shape)
        extract_cols_and_assign(content_full, index_2, K,
                                0, col_offset, n_rows_func, signs_predict[j + 1])
        col_offset += len(index_2)

    # Inner blocks (K_dd): the slab is the product of predicted direction j
    # (columns) and row direction i, in the argument order mult_dir(j, i).
    row_offset = n_rows_func
    for row_indices, _, (row_dir, row_ord) in row_specs:
        col_offset = n_cols_func
        for j in range(n_deriv_types_pred):
            col_dir, col_ord = der_ind_order_pred[j]
            new_idx, new_ord = dh.mult_dir(col_dir, col_ord, row_dir, row_ord)
            content_full = phi_exp[der_map[new_ord][new_idx]].reshape(base_shape)
            extract_and_assign(content_full, row_indices, index_2, K,
                               row_offset, col_offset, signs_predict[j + 1])
            col_offset += len(index_2)
        row_offset += len(row_indices)

    return K