Coverage for jetgp/full_ddegp/wddegp_utils.py: 69%
401 statements
« prev ^ index » next coverage.py v7.10.7, created at 2026-04-03 15:09 -0500
« prev ^ index » next coverage.py v7.10.7, created at 2026-04-03 15:09 -0500
1import numpy as np
2import pyoti.sparse as oti
3import pyoti.core as coti
4from line_profiler import profile
5import numba
8# =============================================================================
9# Numba-accelerated helper functions for efficient matrix slicing
10# =============================================================================
@numba.jit(nopython=True, cache=True)
def extract_rows(content_full, row_indices, n_cols):
    """
    Extract rows from content_full at specified indices.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Source matrix.
    row_indices : ndarray of int64
        Row indices to extract.
    n_cols : int
        Number of columns.

    Returns
    -------
    result : ndarray of shape (len(row_indices), n_cols)
        Extracted rows, with the same dtype as content_full.
    """
    n_rows = len(row_indices)
    # Preserve the source dtype instead of silently promoting to float64.
    result = np.empty((n_rows, n_cols), content_full.dtype)
    for i in range(n_rows):
        ri = row_indices[i]
        for j in range(n_cols):
            result[i, j] = content_full[ri, j]
    return result
@numba.jit(nopython=True, cache=True)
def extract_cols(content_full, col_indices, n_rows):
    """
    Extract columns from content_full at specified indices.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Source matrix.
    col_indices : ndarray of int64
        Column indices to extract.
    n_rows : int
        Number of rows.

    Returns
    -------
    result : ndarray of shape (n_rows, len(col_indices))
        Extracted columns, with the same dtype as content_full.
    """
    n_cols = len(col_indices)
    # Preserve the source dtype instead of silently promoting to float64.
    result = np.empty((n_rows, n_cols), content_full.dtype)
    for i in range(n_rows):
        for j in range(n_cols):
            result[i, j] = content_full[i, col_indices[j]]
    return result
@numba.jit(nopython=True, cache=True)
def extract_submatrix(content_full, row_indices, col_indices):
    """
    Extract submatrix from content_full at specified row and column indices.
    Replaces the expensive np.ix_ operation.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Source matrix.
    row_indices : ndarray of int64
        Row indices to extract.
    col_indices : ndarray of int64
        Column indices to extract.

    Returns
    -------
    result : ndarray of shape (len(row_indices), len(col_indices))
        Extracted submatrix, with the same dtype as content_full.
    """
    n_rows = len(row_indices)
    n_cols = len(col_indices)
    # Preserve the source dtype instead of silently promoting to float64.
    result = np.empty((n_rows, n_cols), content_full.dtype)
    for i in range(n_rows):
        ri = row_indices[i]
        for j in range(n_cols):
            result[i, j] = content_full[ri, col_indices[j]]
    return result
@numba.jit(nopython=True, cache=True, parallel=False)
def extract_and_assign(content_full, row_indices, col_indices, K,
                       row_start, col_start, sign):
    """
    Copy the (row_indices x col_indices) submatrix of content_full into K,
    starting at (row_start, col_start) and scaled by sign.

    Extraction and assignment happen in a single pass, avoiding the
    temporary array an np.ix_-style gather would allocate.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Source matrix.
    row_indices : ndarray of int64
        Row indices to extract.
    col_indices : ndarray of int64
        Column indices to extract.
    K : ndarray
        Target matrix, written in place.
    row_start : int
        Starting row index in K.
    col_start : int
        Starting column index in K.
    sign : float
        Sign multiplier (+1.0 or -1.0).
    """
    for i, src_row in enumerate(row_indices):
        dst_row = row_start + i
        for j, src_col in enumerate(col_indices):
            K[dst_row, col_start + j] = sign * content_full[src_row, src_col]
@numba.jit(nopython=True, cache=True)
def extract_rows_and_assign(content_full, row_indices, K,
                            row_start, col_start, n_cols, sign):
    """
    Copy selected rows of content_full into K, scaled by sign.

    Row k of the destination block (starting at row_start, col_start)
    receives row row_indices[k] of the source, all in a single pass.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Source matrix.
    row_indices : ndarray of int64
        Row indices to extract.
    K : ndarray
        Target matrix, written in place.
    row_start : int
        Starting row index in K.
    col_start : int
        Starting column index in K.
    n_cols : int
        Number of columns to copy.
    sign : float
        Sign multiplier (+1.0 or -1.0).
    """
    for i, src_row in enumerate(row_indices):
        dst_row = row_start + i
        for j in range(n_cols):
            K[dst_row, col_start + j] = sign * content_full[src_row, j]
@numba.jit(nopython=True, cache=True)
def extract_cols_and_assign(content_full, col_indices, K,
                            row_start, col_start, n_rows, sign):
    """
    Copy selected columns of content_full into K, scaled by sign.

    Column k of the destination block (starting at row_start, col_start)
    receives column col_indices[k] of the source, all in a single pass.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Source matrix.
    col_indices : ndarray of int64
        Column indices to extract.
    K : ndarray
        Target matrix, written in place.
    row_start : int
        Starting row index in K.
    col_start : int
        Starting column index in K.
    n_rows : int
        Number of rows to copy.
    sign : float
        Sign multiplier (+1.0 or -1.0).
    """
    n_sel = len(col_indices)
    for i in range(n_rows):
        dst_row = row_start + i
        for j in range(n_sel):
            K[dst_row, col_start + j] = sign * content_full[i, col_indices[j]]
188# =============================================================================
189# Difference computation functions
190# =============================================================================
def differences_by_dim_func(X1, X2, rays, n_order,oti_module, return_deriv=True):
    """
    Compute dimension-wise pairwise differences between X1 and X2,
    including hypercomplex perturbations in the directions specified by `rays`.

    This optimized version pre-calculates the perturbation and uses a single
    efficient loop for subtraction, avoiding broadcasting issues with OTI arrays.

    Parameters
    ----------
    X1 : ndarray of shape (n1, d)
        First set of input points with n1 samples in d dimensions.
    X2 : ndarray of shape (n2, d)
        Second set of input points with n2 samples in d dimensions.
    rays : ndarray of shape (d, n_rays)
        Directional vectors for derivative computation.
    n_order : int
        The base order used to construct hypercomplex units.
        When return_deriv=True, uses order 2*n_order.
        When return_deriv=False, uses order n_order.
        When n_order == 0, no perturbation is applied at all.
    oti_module : module
        OTI backend module providing ``array``, ``zeros`` and ``e``
        (e.g. ``pyoti.sparse``); all hypercomplex quantities are built
        through it.
    return_deriv : bool, optional (default=True)
        If True, use order 2*n_order for hypercomplex units (needed for
        derivative-derivative blocks in training kernel).
        If False, use order n_order (sufficient for prediction without
        derivative outputs).

    Returns
    -------
    differences_by_dim : list of length d
        A list where each element is an array of shape (n1, n2), containing
        the differences between corresponding dimensions of X1 and X2,
        augmented with directional hypercomplex perturbations.
    """
    X1 = oti_module.array(X1)
    X2 = oti_module.array(X2)
    n1, d = X1.shape
    n2, _ = X2.shape
    n_rays = rays.shape[1]

    differences_by_dim = []

    # Case 1: n_order == 0 (no hypercomplex perturbation)
    if n_order == 0:
        for k in range(d):
            diffs_k = oti_module.zeros((n1, n2))
            for i in range(n1):
                # Row i holds X1[i, k] minus every entry of X2's k-th column.
                diffs_k[i, :] = X1[i, k] - X2[:, k].T
            differences_by_dim.append(diffs_k)
        return differences_by_dim

    # Determine the order for hypercomplex units based on return_deriv
    if return_deriv:
        hc_order = 2 * n_order
    else:
        hc_order = n_order

    # Pre-calculate the perturbation vector using directional rays:
    # one hypercomplex unit e(i+1) per ray, combined linearly through `rays`
    # so perts[k] is the perturbation to apply along dimension k.
    e_bases = [oti_module.e(i + 1, order=hc_order) for i in range(n_rays)]
    perts = np.dot(rays, e_bases)

    # Case 2: return_deriv=False (prediction without derivative outputs)
    if not return_deriv:
        for k in range(d):
            # Add the pre-calculated perturbation for the current dimension to all points in X1
            X1_k_tagged = X1[:, k] + perts[k]
            X2_k = X2[:, k]

            # Pre-allocate the result matrix for this dimension
            diffs_k = oti_module.zeros((n1, n2))

            # Use an efficient single loop for subtraction
            # NOTE(review): the [i, 0] / [:, 0] indexing assumes OTI column
            # slices are (n, 1) column vectors — confirm against pyoti.
            for i in range(n1):
                diffs_k[i, :] = X1_k_tagged[i, 0] - X2_k[:, 0].T

            differences_by_dim.append(diffs_k)

    # Case 3: return_deriv=True (training kernel with derivative-derivative blocks)
    else:
        for k in range(d):
            X2_k = X2[:, k]

            # Pre-allocate the result matrix for this dimension
            diffs_k = oti_module.zeros((n1, n2))

            # Compute differences without perturbation first
            for i in range(n1):
                diffs_k[i, :] = X1[i, k] - X2_k[:, 0].T

            # Add perturbation to the entire matrix (more efficient)
            differences_by_dim.append(diffs_k + perts[k])

    return differences_by_dim
286# =============================================================================
287# Derivative mapping utilities
288# =============================================================================
def deriv_map(nbases, order):
    """
    Build a nested map from (order, index_within_order) to a single
    flattened index covering every derivative component up to `order`.
    """
    map_deriv = []
    flat = 0
    for ordi in range(order + 1):
        # Number of derivative directions at this order.
        ndir = coti.ndir_order(nbases, ordi)
        map_deriv.append(list(range(flat, flat + ndir)))
        flat += ndir
    return map_deriv
def transform_der_indices(der_indices, der_map):
    """
    Translate user-facing derivative specifications into the internal
    (index, order) pairs and the corresponding flattened indices.
    """
    flat_indices = []
    order_pairs = []
    for spec in der_indices:
        pair = coti.imdir(spec)
        idx, order = pair
        flat_indices.append(der_map[order][idx])
        order_pairs.append(pair)
    return flat_indices, order_pairs
322# =============================================================================
323# RBF Kernel Assembly Functions (Optimized with Numba)
324# =============================================================================
def rbf_kernel(
    phi,
    phi_exp,
    n_order,
    n_bases,
    der_indices,
    powers,
    index=-1
):
    """
    Assembles the full DD-GP covariance matrix using an efficient, pre-computed
    derivative array and block-wise matrix filling.

    This version uses Numba-accelerated functions for efficient matrix slicing,
    replacing expensive np.ix_ operations.

    The matrix is laid out in blocks:
    - (0, 0): function-function values,
    - first block-column: derivative-function rows,
    - first block-row: function-derivative columns,
    - inner blocks: derivative-derivative submatrices.

    Parameters
    ----------
    phi : OTI array
        Base kernel matrix from kernel_func(differences, length_scales).
    phi_exp : ndarray
        Expanded derivative array from phi.get_all_derivs().
    n_order : int
        Maximum derivative order considered.
    n_bases : int
        Total number of bases (function value + derivative terms).
    der_indices : list of lists
        Multi-index derivative structures for each derivative component.
    powers : list of int
        Powers of (-1) applied to each term (for symmetry or sign conventions).
    index : list of lists
        Specifies which training point indices have each derivative type.
        NOTE(review): the default -1 is not iterable, so a list must in
        practice always be supplied — confirm with callers.

    Returns
    -------
    K : ndarray
        Full kernel matrix with function values and derivative blocks.
    """
    dh = coti.get_dHelp()

    # Create maps to translate derivative specifications to flat indices
    der_map = deriv_map(n_bases, 2 * n_order)
    der_indices_tr, der_ind_order = transform_der_indices(der_indices, der_map)

    # Determine Block Sizes and Pre-allocate Matrix
    n_rows_func, n_cols_func = phi.shape
    n_deriv_types = len(der_indices)
    # One extra row/column per (derivative type, training point) pair.
    n_pts_with_derivs_cols = sum(len(order_indices) for order_indices in index)
    n_pts_with_derivs_rows = sum(len(order_indices) for order_indices in index)
    total_rows = n_rows_func + n_pts_with_derivs_rows
    total_cols = n_cols_func + n_pts_with_derivs_cols

    K = np.zeros((total_rows, total_cols))
    base_shape = (n_rows_func, n_cols_func)

    # Pre-compute signs (avoid repeated exponentiation)
    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)

    # Convert index lists to numpy arrays for numba
    index_arrays = [np.asarray(idx, dtype=np.int64) for idx in index]

    # Block (0,0): Function-Function (K_ff)
    content_full = phi_exp[0].reshape(base_shape)
    K[:n_rows_func, :n_cols_func] = content_full * signs[0]

    # First Block-Column: Derivative-Function (K_df)
    row_offset = n_rows_func
    for i in range(n_deriv_types):
        flat_idx = der_indices_tr[i]
        content_full = phi_exp[flat_idx].reshape(base_shape)

        current_indices = index_arrays[i]
        n_pts_this_order = len(current_indices)

        # Use numba for efficient row extraction and assignment
        extract_rows_and_assign(content_full, current_indices, K,
                                row_offset, 0, n_cols_func, signs[0])
        row_offset += n_pts_this_order

    # First Block-Row: Function-Derivative (K_fd)
    col_offset = n_cols_func
    for j in range(n_deriv_types):
        flat_idx = der_indices_tr[j]
        content_full = phi_exp[flat_idx].reshape(base_shape)

        current_indices = index_arrays[j]
        n_pts_this_order = len(current_indices)

        # Use numba for efficient column extraction and assignment
        extract_cols_and_assign(content_full, current_indices, K,
                                0, col_offset, n_rows_func, signs[j + 1])
        col_offset += n_pts_this_order

    # Inner Blocks: Derivative-Derivative (K_dd)
    row_offset = n_rows_func
    for i in range(n_deriv_types):
        col_offset = n_cols_func

        row_indices = index_arrays[i]
        n_pts_row = len(row_indices)

        for j in range(n_deriv_types):
            col_indices = index_arrays[j]
            n_pts_col = len(col_indices)

            # Multiply derivative indices to find correct flat index
            # (product of the two derivative directions in the OTI algebra).
            imdir1 = der_ind_order[j]
            imdir2 = der_ind_order[i]
            new_idx, new_ord = dh.mult_dir(imdir1[0], imdir1[1], imdir2[0], imdir2[1])
            flat_idx = der_map[new_ord][new_idx]
            content_full = phi_exp[flat_idx].reshape(base_shape)

            # Use numba for efficient submatrix extraction and assignment (replaces np.ix_)
            extract_and_assign(content_full, row_indices, col_indices, K,
                               row_offset, col_offset, signs[j + 1])

            col_offset += n_pts_col

        row_offset += n_pts_row

    return K
@numba.jit(nopython=True, cache=True)
def _assemble_kernel_numba(phi_exp_3d, K, n_rows_func, n_cols_func,
                           fd_flat_indices, df_flat_indices, dd_flat_indices,
                           idx_flat, idx_offsets, idx_sizes,
                           signs, n_deriv_types, row_offsets, col_offsets):
    """Fused numba kernel for entire K matrix assembly.

    Fills K block by block from the flattened derivative stack phi_exp_3d:
    the function-function block, the function-derivative block-row, the
    derivative-function block-column, and all derivative-derivative blocks.

    Parameters
    ----------
    phi_exp_3d : ndarray, shape (n_derivs, n_rows_func, n_cols_func)
        Stack of the base kernel (slice 0) and its derivative components.
    K : ndarray
        Pre-allocated output matrix, overwritten in place.
    n_rows_func, n_cols_func : int
        Size of the function-function block.
    fd_flat_indices, df_flat_indices : ndarray of int64
        Flat derivative index per type for the function-derivative and
        derivative-function blocks respectively.
    dd_flat_indices : ndarray of int64, shape (n_deriv_types, n_deriv_types)
        Flat derivative index of each derivative-derivative block.
    idx_flat : ndarray of int64
        Concatenated training-point indices of all derivative types.
    idx_offsets, idx_sizes : ndarray of int64
        Start offset and length of each type's slice of idx_flat.
    signs : ndarray of float64
        signs[0] scales function rows; signs[j + 1] scales derivative
        columns of type j.
    n_deriv_types : int
        Number of derivative types.
    row_offsets, col_offsets : ndarray of int64
        Absolute starting row/column in K of each derivative block.
    """
    s0 = signs[0]
    # Block (0, 0): Function-Function (K_ff).
    for r in range(n_rows_func):
        for c in range(n_cols_func):
            K[r, c] = phi_exp_3d[0, r, c] * s0
    # First block-row: Function-Derivative (K_fd), columns gathered per type.
    for j in range(n_deriv_types):
        fi = fd_flat_indices[j]
        sj = signs[j + 1]
        co = col_offsets[j]
        off_j = idx_offsets[j]
        sz_j = idx_sizes[j]
        for r in range(n_rows_func):
            for k in range(sz_j):
                ci = idx_flat[off_j + k]
                K[r, co + k] = phi_exp_3d[fi, r, ci] * sj
    # First block-column: Derivative-Function (K_df), rows gathered per type.
    for i in range(n_deriv_types):
        fi = df_flat_indices[i]
        ro = row_offsets[i]
        off_i = idx_offsets[i]
        sz_i = idx_sizes[i]
        for k in range(sz_i):
            ri = idx_flat[off_i + k]
            for c in range(n_cols_func):
                K[ro + k, c] = phi_exp_3d[fi, ri, c] * s0
    # Inner blocks: Derivative-Derivative (K_dd) submatrices.
    for i in range(n_deriv_types):
        ro = row_offsets[i]
        off_i = idx_offsets[i]
        sz_i = idx_sizes[i]
        for j in range(n_deriv_types):
            fi = dd_flat_indices[i, j]
            sj = signs[j + 1]
            co = col_offsets[j]
            off_j = idx_offsets[j]
            sz_j = idx_sizes[j]
            for ki in range(sz_i):
                ri = idx_flat[off_i + ki]
                for kj in range(sz_j):
                    ci = idx_flat[off_j + kj]
                    K[ro + ki, co + kj] = phi_exp_3d[fi, ri, ci] * sj
@numba.jit(nopython=True, cache=True)
def _project_W_to_phi_space(W, W_proj, n_rows_func, n_cols_func,
                            fd_flat_indices, df_flat_indices, dd_flat_indices,
                            idx_flat, idx_offsets, idx_sizes,
                            signs, n_deriv_types, row_offsets, col_offsets):
    """
    Reverse of _assemble_kernel_numba: project W from K-space back into
    phi_exp-space so that vdot(W, assemble(dphi_exp)) == vdot(W_proj, dphi_exp).

    Each write of the assembly kernel becomes an accumulation here (+=)
    because several K-blocks can map to the same phi_exp slice; W_proj is
    therefore zeroed first. All structural parameters have the same
    meaning as in _assemble_kernel_numba; W has K's shape and W_proj has
    phi_exp_3d's shape. Both are modified/read in place.
    """
    # Zero the accumulator explicitly (numba-friendly triple loop).
    for d in range(W_proj.shape[0]):
        for r in range(W_proj.shape[1]):
            for c in range(W_proj.shape[2]):
                W_proj[d, r, c] = 0.0
    s0 = signs[0]
    # Adjoint of block (0, 0): Function-Function.
    for r in range(n_rows_func):
        for c in range(n_cols_func):
            W_proj[0, r, c] += s0 * W[r, c]
    # Adjoint of the Function-Derivative block-row.
    for j in range(n_deriv_types):
        fi = fd_flat_indices[j]
        sj = signs[j + 1]
        co = col_offsets[j]
        off_j = idx_offsets[j]
        sz_j = idx_sizes[j]
        for r in range(n_rows_func):
            for k in range(sz_j):
                ci = idx_flat[off_j + k]
                W_proj[fi, r, ci] += sj * W[r, co + k]
    # Adjoint of the Derivative-Function block-column.
    for i in range(n_deriv_types):
        fi = df_flat_indices[i]
        ro = row_offsets[i]
        off_i = idx_offsets[i]
        sz_i = idx_sizes[i]
        for k in range(sz_i):
            ri = idx_flat[off_i + k]
            for c in range(n_cols_func):
                W_proj[fi, ri, c] += s0 * W[ro + k, c]
    # Adjoint of the Derivative-Derivative inner blocks.
    for i in range(n_deriv_types):
        ro = row_offsets[i]
        off_i = idx_offsets[i]
        sz_i = idx_sizes[i]
        for j in range(n_deriv_types):
            fi = dd_flat_indices[i, j]
            sj = signs[j + 1]
            co = col_offsets[j]
            off_j = idx_offsets[j]
            sz_j = idx_sizes[j]
            for ki in range(sz_i):
                ri = idx_flat[off_i + ki]
                for kj in range(sz_j):
                    ci = idx_flat[off_j + kj]
                    W_proj[fi, ri, ci] += sj * W[ro + ki, co + kj]
def precompute_kernel_plan(n_order, n_bases, der_indices, powers, index):
    """Precompute the structural bookkeeping consumed by rbf_kernel_fast."""
    dh = coti.get_dHelp()
    der_map = deriv_map(n_bases, 2 * n_order)
    der_indices_tr, der_ind_order = transform_der_indices(der_indices, der_map)

    n_deriv_types = len(der_indices)
    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)
    index_arrays = [np.asarray(idx, dtype=np.int64) for idx in index]

    index_sizes = np.array([arr.shape[0] for arr in index_arrays], dtype=np.int64)
    n_pts_with_derivs = int(index_sizes.sum())

    if n_deriv_types > 0:
        idx_flat = np.concatenate(index_arrays)
    else:
        idx_flat = np.array([], dtype=np.int64)

    # idx_offsets, row_offsets and col_offsets are all the same exclusive
    # prefix sums of the per-type sizes, so build them in a single pass.
    idx_offsets = np.zeros(n_deriv_types, dtype=np.int64)
    row_offsets = np.zeros(n_deriv_types, dtype=np.int64)
    col_offsets = np.zeros(n_deriv_types, dtype=np.int64)
    running = 0
    for t in range(n_deriv_types):
        idx_offsets[t] = running
        row_offsets[t] = running
        col_offsets[t] = running
        running += int(index_sizes[t])

    # Flat derivative index of each derivative-derivative block (a, b):
    # the product of the two derivative directions in the OTI algebra.
    dd_flat_indices = np.empty((n_deriv_types, n_deriv_types), dtype=np.int64)
    for a in range(n_deriv_types):
        for b in range(n_deriv_types):
            dir_col = der_ind_order[b]
            dir_row = der_ind_order[a]
            new_idx, new_ord = dh.mult_dir(dir_col[0], dir_col[1],
                                           dir_row[0], dir_row[1])
            dd_flat_indices[a, b] = der_map[new_ord][new_idx]

    fd_flat_indices = np.array(der_indices_tr, dtype=np.int64)
    df_flat_indices = np.array(der_indices_tr, dtype=np.int64)

    return {
        'der_indices_tr': der_indices_tr,
        'signs': signs,
        'index_arrays': index_arrays,
        'index_sizes': index_sizes,
        'n_pts_with_derivs': n_pts_with_derivs,
        'dd_flat_indices': dd_flat_indices,
        'n_deriv_types': n_deriv_types,
        'idx_flat': idx_flat,
        'idx_offsets': idx_offsets,
        'row_offsets': row_offsets,
        'col_offsets': col_offsets,
        'fd_flat_indices': fd_flat_indices,
        'df_flat_indices': df_flat_indices,
    }
def rbf_kernel_fast(phi_exp_3d, plan, out=None):
    """Assemble the kernel matrix from a precomputed plan via the fused
    numba kernel. When `out` is given it is used as the destination buffer."""
    n_rows_func = phi_exp_3d.shape[1]
    n_cols_func = phi_exp_3d.shape[2]
    size = n_rows_func + plan['n_pts_with_derivs']
    K = out if out is not None else np.empty((size, size))

    # Prefer absolute block offsets when the plan carries them; otherwise
    # shift the relative offsets past the function-function block.
    if 'row_offsets_abs' in plan:
        row_off = plan['row_offsets_abs']
        col_off = plan['col_offsets_abs']
    else:
        row_off = plan['row_offsets'] + n_rows_func
        col_off = plan['col_offsets'] + n_cols_func

    _assemble_kernel_numba(
        phi_exp_3d, K, n_rows_func, n_cols_func,
        plan['fd_flat_indices'], plan['df_flat_indices'], plan['dd_flat_indices'],
        plan['idx_flat'], plan['idx_offsets'], plan['index_sizes'],
        plan['signs'], plan['n_deriv_types'], row_off, col_off,
    )
    return K
def rbf_kernel_predictions(
    phi,
    phi_exp,
    n_order,
    n_bases,
    der_indices,
    powers,
    return_deriv,
    index=-1,
    common_derivs=None,
    calc_cov=False,
    powers_predict=None
):
    """
    Constructs the RBF kernel matrix for predictions with directional derivative entries.

    This version uses Numba-accelerated functions for efficient matrix slicing.

    Parameters
    ----------
    phi : OTI array
        Base kernel matrix between test and training points.
    phi_exp : ndarray
        Expanded derivative array from phi.get_all_derivs().
    n_order : int
        Maximum derivative order.
    n_bases : int
        Number of input dimensions.
    der_indices : list
        Derivative specifications.
    powers : list of int
        Sign powers for each derivative type.
    return_deriv : bool
        If True, predict derivatives at ALL test points.
    index : list of lists
        Training point indices for each derivative type.
    common_derivs : list
        Common derivative indices to predict.
    calc_cov : bool
        If True, computing covariance (use all indices for rows).
    powers_predict : list of int, optional
        Sign powers for prediction derivatives.

    Returns
    -------
    K : ndarray
        Prediction kernel matrix.
    """
    # Plain covariance without derivative outputs: the real part of phi
    # is already the full answer.
    if calc_cov and not return_deriv:
        return phi.real

    dh = coti.get_dHelp()

    n_rows_func, n_cols_func = phi.shape
    n_deriv_types = len(der_indices)
    n_deriv_types_pred = len(common_derivs) if common_derivs else 0

    # Pre-compute signs
    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)
    if powers_predict is not None:
        signs_predict = np.array([(-1.0) ** p for p in powers_predict], dtype=np.float64)
    else:
        signs_predict = signs

    if return_deriv:
        der_map = deriv_map(n_bases, 2 * n_order)
        # index_2 selects ALL columns (every test point gets derivatives).
        index_2 = np.arange(phi_exp.shape[-1], dtype=np.int64)
        if calc_cov:
            index_cov = np.arange(phi_exp.shape[-1], dtype=np.int64)
            # In covariance mode rows also use the predicted-derivative set.
            n_deriv_types = n_deriv_types_pred
            # NOTE(review): this equals n_deriv_types * min(n_cols_func,
            # len(index_2)) — presumably one derivative row per test point
            # and type; confirm the intent of the comprehension.
            n_pts_with_derivs_rows = n_deriv_types * len([i for i in range(n_cols_func) if i < len(index_2)])
        else:
            n_pts_with_derivs_rows = sum(len(order_indices) for order_indices in index)
    else:
        der_map = deriv_map(n_bases, n_order)
        index_2 = np.array([], dtype=np.int64)
        n_pts_with_derivs_rows = sum(len(order_indices) for order_indices in index)

    der_indices_tr, der_ind_order = transform_der_indices(der_indices, der_map)
    der_indices_tr_pred, der_ind_order_pred = transform_der_indices(common_derivs, der_map) if common_derivs else ([], [])
    # NOTE(review): same min(n_cols_func, len(index_2)) pattern as above.
    n_pts_with_derivs_cols = n_deriv_types_pred * len([i for i in range(n_cols_func) if i < len(index_2)])

    total_rows = n_rows_func + n_pts_with_derivs_rows
    total_cols = n_cols_func + n_pts_with_derivs_cols

    K = np.zeros((total_rows, total_cols))
    base_shape = (n_rows_func, n_cols_func)

    # Convert index lists to numpy arrays for numba
    if index != -1 and isinstance(index, list) and len(index) > 0:
        index_arrays = [np.asarray(idx, dtype=np.int64) for idx in index]
    else:
        index_arrays = []

    # Block (0,0): Function-Function (K_ff)
    content_full = phi_exp[0].reshape(base_shape)
    K[:n_rows_func, :n_cols_func] = content_full * signs[0]

    if not return_deriv:
        # First Block-Column: Derivative-Function (K_df)
        row_offset = n_rows_func
        for i in range(n_deriv_types):
            if calc_cov:
                row_indices = index_cov
            else:
                if not index_arrays:
                    break
                row_indices = index_arrays[i]
            n_pts_row = len(row_indices)

            flat_idx = der_indices_tr[i]
            content_full = phi_exp[flat_idx].reshape(base_shape)

            # Use numba for efficient row extraction
            extract_rows_and_assign(content_full, row_indices, K,
                                    row_offset, 0, n_cols_func, signs[0])
            row_offset += n_pts_row
        return K

    # --- return_deriv=True case ---

    # First Block-Row: Function-Derivative (K_fd)
    col_offset = n_cols_func
    for j in range(n_deriv_types_pred):
        col_indices = index_2
        n_pts_col = len(col_indices)

        flat_idx = der_indices_tr_pred[j]
        content_full = phi_exp[flat_idx].reshape(base_shape)

        # Use numba for efficient column extraction
        extract_cols_and_assign(content_full, col_indices, K,
                                0, col_offset, n_rows_func, signs_predict[j + 1])
        col_offset += n_pts_col

    # First Block-Column: Derivative-Function (K_df)
    row_offset = n_rows_func
    for i in range(n_deriv_types):
        if calc_cov:
            row_indices = index_cov
            flat_idx = der_indices_tr_pred[i]
        else:
            if not index_arrays:
                break
            row_indices = index_arrays[i]
            flat_idx = der_indices_tr[i]
        n_pts_row = len(row_indices)

        content_full = phi_exp[flat_idx].reshape(base_shape)

        # Use numba for efficient row extraction
        extract_rows_and_assign(content_full, row_indices, K,
                                row_offset, 0, n_cols_func, signs[0])
        row_offset += n_pts_row

    # Inner Blocks: Derivative-Derivative (K_dd)
    row_offset = n_rows_func
    for i in range(n_deriv_types):
        if calc_cov:
            row_indices = index_cov
        else:
            if not index_arrays:
                break
            row_indices = index_arrays[i]
        n_pts_row = len(row_indices)

        col_offset = n_cols_func
        for j in range(n_deriv_types_pred):
            col_indices = index_2
            n_pts_col = len(col_indices)

            # Multiply the derivative indices to find the correct flat index
            imdir1 = der_ind_order_pred[j]
            imdir2 = der_ind_order_pred[i] if calc_cov else der_ind_order[i]
            new_idx, new_ord = dh.mult_dir(
                imdir1[0], imdir1[1], imdir2[0], imdir2[1])
            flat_idx = der_map[new_ord][new_idx]

            content_full = phi_exp[flat_idx].reshape(base_shape)

            # Use numba for efficient submatrix extraction and assignment (replaces np.ix_)
            extract_and_assign(content_full, row_indices, col_indices, K,
                               row_offset, col_offset, signs_predict[j + 1])
            col_offset += n_pts_col
        row_offset += n_pts_row

    return K
817# =============================================================================
818# Utility functions
819# =============================================================================
def determine_weights(diffs_by_dim, diffs_test, length_scales, kernel_func, sigma_n):
    """
    Compute interpolation weights for a batch of test points in one solve.

    Solves the equality-constrained system (weights summing to one, via a
    Lagrange multiplier) for every test point simultaneously.

    Parameters
    ----------
    diffs_by_dim : list of ndarray
        Pairwise differences between training points (by dimension).
    diffs_test : list of ndarray
        Pairwise differences between test and training points (by dimension).
    length_scales : array-like
        Kernel hyperparameters.
    kernel_func : callable
        Kernel function; its ``.real`` output is used.
    sigma_n : float
        Noise parameter (currently unused in the solve).

    Returns
    -------
    weights_matrix : ndarray of shape (n_test, n_train)
        Interpolation weights for each test point.
    """
    # Training covariance (shared by every test point) and test-train
    # covariances, taken as real parts of the kernel evaluations.
    K_train = kernel_func(diffs_by_dim, length_scales).real
    n_train = K_train.shape[0]
    r_all = kernel_func(diffs_test, length_scales).real
    n_test = r_all.shape[0]

    # Augmented system: [[K, 1], [1^T, 0]] enforces sum(weights) == 1.
    M = np.zeros((n_train + 1, n_train + 1))
    M[:n_train, :n_train] = K_train
    M[:n_train, n_train] = 1
    M[n_train, :n_train] = 1

    # One augmented right-hand side per test point (last entry is the
    # constraint value 1).
    rhs = np.ones((n_train + 1, n_test))
    rhs[:n_train, :] = r_all.T

    # Single factorization serves all right-hand sides at once.
    solution = np.linalg.solve(M, rhs)

    # Drop the Lagrange-multiplier row before returning.
    return solution[:n_train, :].T
def to_list(x):
    """Recursively convert tuples (including nested ones) into lists."""
    return [to_list(item) for item in x] if isinstance(x, tuple) else x
def to_tuple(item):
    """Recursively convert lists (including nested ones) into tuples."""
    if not isinstance(item, list):
        return item
    return tuple(to_tuple(member) for member in item)
def find_common_derivatives(all_indices):
    """Find derivative indices common to all submodels."""
    # Hash as tuples so (possibly nested) list specs become set members.
    tuple_sets = [{to_tuple(elem) for elem in idx_list} for idx_list in all_indices]
    common = tuple_sets[0]
    for other in tuple_sets[1:]:
        common = common & other
    return common