Coverage for jetgp/wdegp/wdegp_utils.py: 66%
390 statements
« prev ^ index » next coverage.py v7.10.7, created at 2026-04-03 15:09 -0500
« prev ^ index » next coverage.py v7.10.7, created at 2026-04-03 15:09 -0500
1import pyoti.core as coti
2import numpy as np
3from line_profiler import profile
4import numba
7# =============================================================================
8# Numba-accelerated helper functions for efficient matrix slicing
9# =============================================================================
@numba.jit(nopython=True, cache=True)
def extract_rows(content_full, row_indices, n_cols):
    """
    Copy the rows of ``content_full`` listed in ``row_indices`` into a new array.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Source matrix.
    row_indices : ndarray of int64
        Row indices to extract.
    n_cols : int
        Number of columns to copy from each selected row.

    Returns
    -------
    result : ndarray of shape (len(row_indices), n_cols)
        Dense copy of the selected rows.
    """
    out = np.empty((len(row_indices), n_cols))
    # Explicit element loop: numba compiles this to tight native code.
    for dst_row, src_row in enumerate(row_indices):
        for col in range(n_cols):
            out[dst_row, col] = content_full[src_row, col]
    return out
@numba.jit(nopython=True, cache=True)
def extract_cols(content_full, col_indices, n_rows):
    """
    Copy the columns of ``content_full`` listed in ``col_indices`` into a new array.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Source matrix.
    col_indices : ndarray of int64
        Column indices to extract.
    n_rows : int
        Number of rows to copy.

    Returns
    -------
    result : ndarray of shape (n_rows, len(col_indices))
        Dense copy of the selected columns.
    """
    out = np.empty((n_rows, len(col_indices)))
    for row in range(n_rows):
        # Explicit gather per column; numba turns this into native loops.
        for dst_col, src_col in enumerate(col_indices):
            out[row, dst_col] = content_full[row, src_col]
    return out
@numba.jit(nopython=True, cache=True)
def extract_submatrix(content_full, row_indices, col_indices):
    """
    Gather the submatrix of ``content_full`` at the given row/column indices.

    Serves as a numba-compiled replacement for the (expensive) ``np.ix_``
    fancy-indexing pattern.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Source matrix.
    row_indices : ndarray of int64
        Row indices to extract.
    col_indices : ndarray of int64
        Column indices to extract.

    Returns
    -------
    result : ndarray of shape (len(row_indices), len(col_indices))
        Dense copy of the selected submatrix.
    """
    out = np.empty((len(row_indices), len(col_indices)))
    for dst_row, src_row in enumerate(row_indices):
        for dst_col, src_col in enumerate(col_indices):
            out[dst_row, dst_col] = content_full[src_row, src_col]
    return out
@numba.jit(nopython=True, cache=True, parallel=False)
def extract_and_assign(content_full, row_indices, col_indices, K,
                       row_start, col_start, sign):
    """
    Gather a submatrix of ``content_full`` and write it into ``K`` in place,
    multiplied by ``sign`` — extraction and assignment fused into one pass.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols_full)
        Source matrix.
    row_indices : ndarray of int64
        Row indices to extract.
    col_indices : ndarray of int64
        Column indices to extract.
    K : ndarray
        Target matrix, written in place.
    row_start : int
        Row in ``K`` where the block's top-left corner goes.
    col_start : int
        Column in ``K`` where the block's top-left corner goes.
    sign : float
        Sign multiplier (+1.0 or -1.0).
    """
    for di, src_row in enumerate(row_indices):
        for dj, src_col in enumerate(col_indices):
            K[row_start + di, col_start + dj] = content_full[src_row, src_col] * sign
@numba.jit(nopython=True, cache=True)
def extract_rows_and_assign(content_full, row_indices, K,
                            row_start, col_start, n_cols, sign):
    """
    Gather the selected rows of ``content_full`` and write them into ``K``
    in place, multiplied by ``sign``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows_full, n_cols)
        Source matrix.
    row_indices : ndarray of int64
        Row indices to extract.
    K : ndarray
        Target matrix, written in place.
    row_start : int
        Row in ``K`` where the block's top-left corner goes.
    col_start : int
        Column in ``K`` where the block's top-left corner goes.
    n_cols : int
        Number of columns to copy from each selected row.
    sign : float
        Sign multiplier (+1.0 or -1.0).
    """
    for di, src_row in enumerate(row_indices):
        for col in range(n_cols):
            K[row_start + di, col_start + col] = content_full[src_row, col] * sign
@numba.jit(nopython=True, cache=True)
def extract_cols_and_assign(content_full, col_indices, K,
                            row_start, col_start, n_rows, sign):
    """
    Gather the selected columns of ``content_full`` and write them into ``K``
    in place, multiplied by ``sign``.

    Parameters
    ----------
    content_full : ndarray of shape (n_rows, n_cols_full)
        Source matrix.
    col_indices : ndarray of int64
        Column indices to extract.
    K : ndarray
        Target matrix, written in place.
    row_start : int
        Row in ``K`` where the block's top-left corner goes.
    col_start : int
        Column in ``K`` where the block's top-left corner goes.
    n_rows : int
        Number of rows to copy.
    sign : float
        Sign multiplier (+1.0 or -1.0).
    """
    for row in range(n_rows):
        for dj, src_col in enumerate(col_indices):
            K[row_start + row, col_start + dj] = content_full[row, src_col] * sign
def differences_by_dim_func(X1, X2, n_order, oti_module, return_deriv=True):
    """
    Compute pairwise per-dimension differences between X1 and X2, embedding
    hypercomplex units along each dimension for automatic differentiation.

    For each dimension k this builds:
        diff_k[i, j] = X1[i, k] + e_{k+1} - X2[j, k]
    where e_{k+1} is a hypercomplex unit of order 2 * n_order (or n_order when
    return_deriv is False). With n_order == 0 no unit is embedded at all.

    Parameters
    ----------
    X1 : array_like of shape (n1, d)
        First set of input points with n1 samples in d dimensions.
    X2 : array_like of shape (n2, d)
        Second set of input points with n2 samples in d dimensions.
    n_order : int
        Base order used to construct the hypercomplex units.
    oti_module : module
        The PyOTI static module (e.g., pyoti.static.onumm4n2).
    return_deriv : bool, optional
        If True, use 2*n_order for derivative predictions.

    Returns
    -------
    differences_by_dim : list of length d
        One (n1, n2) array per dimension holding the pairwise differences,
        augmented with the hypercomplex units as described above.
    """
    X1 = oti_module.array(X1)
    X2 = oti_module.array(X2)
    n1, d = X1.shape
    n2, d = X2.shape

    diffs_all = []
    if n_order == 0:
        # Purely real case: no hypercomplex unit is embedded.
        for dim in range(d):
            dim_diffs = oti_module.zeros((n1, n2))
            for row in range(n1):
                dim_diffs[row, :] = X1[row, dim] - (oti_module.transpose(X2[:, dim]))
            diffs_all.append(dim_diffs)
    elif not return_deriv:
        # Unit of order n_order, added element-wise while filling each row.
        for dim in range(d):
            dim_diffs = oti_module.zeros((n1, n2))
            for row in range(n1):
                dim_diffs[row, :] = (
                    X1[row, dim]
                    + oti_module.e(dim + 1, order=n_order)
                    - (X2[:, dim].T)
                )
            diffs_all.append(dim_diffs)
    else:
        # Unit of order 2*n_order, added once to the whole matrix.
        for dim in range(d):
            dim_diffs = oti_module.zeros((n1, n2))
            for row in range(n1):
                dim_diffs[row, :] = X1[row, dim] - (X2[:, dim].T)
            diffs_all.append(dim_diffs + oti_module.e(dim + 1, order=2 * n_order))

    return diffs_all
247# =============================================================================
248# Derivative mapping utilities
249# =============================================================================
def deriv_map(nbases, order):
    """
    Build a mapping from (order, index_within_order) to a single flat index
    over all derivative components.

    Returns a list indexed by derivative order; entry ``ordi`` is a list of
    consecutive flat indices, one per direction of that order, numbered
    contiguously across orders starting at 0.
    """
    map_deriv = []
    next_flat = 0
    for ordi in range(order + 1):
        # Number of derivative directions at this order, per pyoti.
        ndir = coti.ndir_order(nbases, ordi)
        map_deriv.append(list(range(next_flat, next_flat + ndir)))
        next_flat += ndir
    return map_deriv
def transform_der_indices(der_indices, der_map):
    """
    Translate user-facing derivative specifications into the internal
    (index, order) pairs and their final flattened indices.

    Returns
    -------
    (flattened, order_pairs) : tuple of lists
        ``flattened[k]`` is the flat index for ``der_indices[k]`` via
        ``der_map``; ``order_pairs[k]`` is the (index, order) pair returned
        by ``coti.imdir``.
    """
    flattened = []
    order_pairs = []
    for spec in der_indices:
        pair = coti.imdir(spec)
        idx, order = pair
        flattened.append(der_map[order][idx])
        order_pairs.append(pair)
    return flattened, order_pairs
283# =============================================================================
284# RBF Kernel Assembly Functions (Optimized with Numba)
285# =============================================================================
@profile
def rbf_kernel(
    phi,
    phi_exp,
    n_order,
    n_bases,
    der_indices,
    powers,
    index=-1,
):
    """
    Constructs the RBF kernel matrix with derivative entries using an
    efficient pre-allocation strategy combined with a single call to
    extract all derivative components.

    The result K has a block layout:
        [ K_ff  K_fd ]
        [ K_df  K_dd ]
    where f = function values and d = derivative values at the training
    points listed in ``index``.

    This version uses Numba-accelerated functions for efficient matrix slicing,
    replacing expensive np.ix_ operations.

    Parameters
    ----------
    phi : OTI array
        Base kernel matrix from kernel_func(differences, length_scales).
    phi_exp : ndarray
        Expanded derivative array from phi.get_all_derivs().
    n_order : int
        Maximum derivative order.
    n_bases : int
        Number of OTI bases.
    der_indices : list
        Derivative specifications.
    powers : list of int
        Sign powers for each derivative type.
        NOTE(review): signs[j + 1] is read for each of the n_deriv_types
        derivative types, so this appears to need len(powers) ==
        len(der_indices) + 1 — TODO confirm against callers.
    index : list of lists
        Training point indices for each derivative type.
        NOTE(review): the default -1 is not iterable; the sum() and list
        comprehension below would raise on it, so callers presumably always
        pass a list — verify.

    Returns
    -------
    K : ndarray
        Full RBF kernel matrix with mixed function and derivative entries.
    """
    dh = coti.get_dHelp()

    # Create maps to translate derivative specifications to flat indices.
    # der_map covers orders up to 2*n_order because dd blocks multiply two
    # derivative directions together (dh.mult_dir below).
    der_map = deriv_map(n_bases, 2 * n_order)
    der_indices_tr, der_ind_order = transform_der_indices(der_indices, der_map)

    # Determine block sizes and pre-allocate the full matrix.
    n_rows_func, n_cols_func = phi.shape
    n_deriv_types = len(der_indices)
    # Rows and columns contributed by derivative entries are the same here
    # (training kernel is square in the derivative blocks).
    n_pts_with_derivs_cols = sum(len(order_indices) for order_indices in index)
    n_pts_with_derivs_rows = sum(len(order_indices) for order_indices in index)
    total_rows = n_rows_func + n_pts_with_derivs_rows
    total_cols = n_cols_func + n_pts_with_derivs_cols

    K = np.zeros((total_rows, total_cols))
    base_shape = (n_rows_func, n_cols_func)

    # Pre-compute signs once (avoid repeated exponentiation in the loops).
    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)

    # Convert index lists to int64 numpy arrays so numba kernels accept them.
    index_arrays = [np.asarray(idx, dtype=np.int64) for idx in index]

    # Block (0,0): Function-Function (K_ff) — flat index 0 of phi_exp.
    content_full = phi_exp[0].reshape(base_shape)
    K[:n_rows_func, :n_cols_func] = content_full * signs[0]

    # First Block-Row: Function-Derivative (K_fd).
    col_offset = n_cols_func
    for j in range(n_deriv_types):
        flat_idx = der_indices_tr[j]
        content_full = phi_exp[flat_idx].reshape(base_shape)
        current_indices = index_arrays[j]
        n_pts_this_order = len(current_indices)

        # Numba kernel gathers the needed columns and writes into K in place.
        extract_cols_and_assign(content_full, current_indices, K,
                                0, col_offset, n_rows_func, signs[j + 1])
        col_offset += n_pts_this_order

    # First Block-Column: Derivative-Function (K_df).
    # Note: uses signs[0] (the function-block sign) for every row block.
    row_offset = n_rows_func
    for i in range(n_deriv_types):
        flat_idx = der_indices_tr[i]
        content_full = phi_exp[flat_idx].reshape(base_shape)
        current_indices = index_arrays[i]
        n_pts_this_order = len(current_indices)

        extract_rows_and_assign(content_full, current_indices, K,
                                row_offset, 0, n_cols_func, signs[0])
        row_offset += n_pts_this_order

    # Inner Blocks: Derivative-Derivative (K_dd).
    row_offset = n_rows_func
    for i in range(n_deriv_types):
        col_offset = n_cols_func
        row_indices = index_arrays[i]
        n_pts_row = len(row_indices)

        for j in range(n_deriv_types):
            col_indices = index_arrays[j]
            n_pts_col = len(col_indices)

            # Multiply the two derivative directions to find the flat index
            # of the mixed derivative in phi_exp.
            imdir1 = der_ind_order[j]
            imdir2 = der_ind_order[i]
            new_idx, new_ord = dh.mult_dir(
                imdir1[0], imdir1[1], imdir2[0], imdir2[1])
            flat_idx = der_map[new_ord][new_idx]
            content_full = phi_exp[flat_idx].reshape(base_shape)

            # Numba submatrix gather + signed assignment (replaces np.ix_).
            extract_and_assign(content_full, row_indices, col_indices, K,
                               row_offset, col_offset, signs[j + 1])
            col_offset += n_pts_col
        row_offset += n_pts_row

    return K
@numba.jit(nopython=True, cache=True)
def _assemble_kernel_numba(phi_exp_3d, K, n_rows_func, n_cols_func,
                           fd_flat_indices, df_flat_indices, dd_flat_indices,
                           idx_flat, idx_offsets, idx_sizes,
                           signs, n_deriv_types, row_offsets, col_offsets):
    """
    Fused numba kernel that assembles the entire K matrix in a single call.
    Handles ff, fd, df, and dd blocks without Python-level loop overhead.

    Layout conventions (see precompute_kernel_plan):
      - idx_flat packs all per-deriv-type index arrays back to back;
        slice j is idx_flat[idx_offsets[j] : idx_offsets[j] + idx_sizes[j]].
      - row_offsets/col_offsets are absolute positions in K of each
        derivative block (caller adds n_rows_func/n_cols_func beforehand).
      - phi_exp_3d[flat, r, c] holds derivative 'flat' of the base kernel.
    """
    # Block (0,0): Function-Function, scaled by the function sign.
    s0 = signs[0]
    for r in range(n_rows_func):
        for c in range(n_cols_func):
            K[r, c] = phi_exp_3d[0, r, c] * s0

    # First Block-Row: Function-Derivative (fd) — gather columns per type.
    for j in range(n_deriv_types):
        fi = fd_flat_indices[j]
        sj = signs[j + 1]
        co = col_offsets[j]
        off_j = idx_offsets[j]
        sz_j = idx_sizes[j]
        for r in range(n_rows_func):
            for k in range(sz_j):
                ci = idx_flat[off_j + k]
                K[r, co + k] = phi_exp_3d[fi, r, ci] * sj

    # First Block-Column: Derivative-Function (df) — gather rows per type.
    # Uses s0 (function-block sign), mirroring rbf_kernel's df block.
    for i in range(n_deriv_types):
        fi = df_flat_indices[i]
        ro = row_offsets[i]
        off_i = idx_offsets[i]
        sz_i = idx_sizes[i]
        for k in range(sz_i):
            ri = idx_flat[off_i + k]
            for c in range(n_cols_func):
                K[ro + k, c] = phi_exp_3d[fi, ri, c] * s0

    # Inner Blocks: Derivative-Derivative (dd) — submatrix gather per pair,
    # using the precomputed mixed-derivative flat index dd_flat_indices[i, j].
    for i in range(n_deriv_types):
        ro = row_offsets[i]
        off_i = idx_offsets[i]
        sz_i = idx_sizes[i]
        for j in range(n_deriv_types):
            fi = dd_flat_indices[i, j]
            sj = signs[j + 1]
            co = col_offsets[j]
            off_j = idx_offsets[j]
            sz_j = idx_sizes[j]
            for ki in range(sz_i):
                ri = idx_flat[off_i + ki]
                for kj in range(sz_j):
                    ci = idx_flat[off_j + kj]
                    K[ro + ki, co + kj] = phi_exp_3d[fi, ri, ci] * sj
@numba.jit(nopython=True, cache=True)
def _project_W_to_phi_space(W, W_proj, n_rows_func, n_cols_func,
                            fd_flat_indices, df_flat_indices, dd_flat_indices,
                            idx_flat, idx_offsets, idx_sizes,
                            signs, n_deriv_types, row_offsets, col_offsets):
    """
    Reverse of _assemble_kernel_numba: project W from K-space back into
    phi_exp-space so that vdot(W, assemble(dphi_exp)) == vdot(W_proj, dphi_exp).

    This is the exact adjoint of the assembly: each block below mirrors the
    corresponding block of _assemble_kernel_numba, with reads and writes
    swapped and assignment replaced by accumulation (+=), since several
    K-blocks can map onto the same phi_exp slot.
    """
    # Zero the accumulator explicitly (W_proj may be a reused buffer).
    for d in range(W_proj.shape[0]):
        for r in range(W_proj.shape[1]):
            for c in range(W_proj.shape[2]):
                W_proj[d, r, c] = 0.0

    # Adjoint of the ff block.
    s0 = signs[0]
    for r in range(n_rows_func):
        for c in range(n_cols_func):
            W_proj[0, r, c] += s0 * W[r, c]

    # Adjoint of the fd blocks.
    for j in range(n_deriv_types):
        fi = fd_flat_indices[j]
        sj = signs[j + 1]
        co = col_offsets[j]
        off_j = idx_offsets[j]
        sz_j = idx_sizes[j]
        for r in range(n_rows_func):
            for k in range(sz_j):
                ci = idx_flat[off_j + k]
                W_proj[fi, r, ci] += sj * W[r, co + k]

    # Adjoint of the df blocks (function-block sign s0, as in assembly).
    for i in range(n_deriv_types):
        fi = df_flat_indices[i]
        ro = row_offsets[i]
        off_i = idx_offsets[i]
        sz_i = idx_sizes[i]
        for k in range(sz_i):
            ri = idx_flat[off_i + k]
            for c in range(n_cols_func):
                W_proj[fi, ri, c] += s0 * W[ro + k, c]

    # Adjoint of the dd blocks.
    for i in range(n_deriv_types):
        ro = row_offsets[i]
        off_i = idx_offsets[i]
        sz_i = idx_sizes[i]
        for j in range(n_deriv_types):
            fi = dd_flat_indices[i, j]
            sj = signs[j + 1]
            co = col_offsets[j]
            off_j = idx_offsets[j]
            sz_j = idx_sizes[j]
            for ki in range(sz_i):
                ri = idx_flat[off_i + ki]
                for kj in range(sz_j):
                    ci = idx_flat[off_j + kj]
                    W_proj[fi, ri, ci] += sj * W[ro + ki, co + kj]
def precompute_kernel_plan(n_order, n_bases, der_indices, powers, index):
    """
    Precompute all structural information needed by rbf_kernel so it can be
    reused across repeated calls with different phi_exp values.

    Returns a dict containing flat indices, signs, index arrays, precomputed
    offsets/sizes, and mult_dir results for the dd block.
    """
    dh = coti.get_dHelp()
    der_map = deriv_map(n_bases, 2 * n_order)
    der_indices_tr, der_ind_order = transform_der_indices(der_indices, der_map)

    n_deriv_types = len(der_indices)
    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)
    index_arrays = [np.asarray(idx, dtype=np.int64) for idx in index]

    # Sizes per derivative type and the total derivative-point count.
    index_sizes = np.array([len(idx) for idx in index_arrays], dtype=np.int64)
    n_pts_with_derivs = int(index_sizes.sum())

    # Pack all index arrays into one flat array; slice j starts at
    # idx_offsets[j] and has length index_sizes[j].
    if n_deriv_types > 0:
        idx_flat = np.concatenate(index_arrays)
    else:
        idx_flat = np.array([], dtype=np.int64)

    # Exclusive prefix sums of the per-type sizes give both the packing
    # offsets and the block offsets within K (which are the same numbers:
    # block i starts after all earlier blocks).
    prefix = np.zeros(n_deriv_types, dtype=np.int64)
    prefix[1:] = np.cumsum(index_sizes[:-1])
    idx_offsets = prefix.copy()
    # Note: n_rows_func == n_cols_func for the training kernel; these are
    # stored relative to n_rows_func / n_cols_func, added at call time.
    row_offsets = prefix.copy()
    col_offsets = prefix.copy()

    # Precompute the mixed-derivative flat index for every (i, j) dd block.
    dd_flat_indices = np.empty((n_deriv_types, n_deriv_types), dtype=np.int64)
    for i in range(n_deriv_types):
        for j in range(n_deriv_types):
            imdir1 = der_ind_order[j]
            imdir2 = der_ind_order[i]
            new_idx, new_ord = dh.mult_dir(
                imdir1[0], imdir1[1], imdir2[0], imdir2[1])
            dd_flat_indices[i, j] = der_map[new_ord][new_idx]

    # fd and df blocks use the same per-type flat indices.
    fd_flat_indices = np.array(der_indices_tr, dtype=np.int64)
    df_flat_indices = np.array(der_indices_tr, dtype=np.int64)

    return {
        'der_indices_tr': der_indices_tr,
        'signs': signs,
        'index_arrays': index_arrays,
        'index_sizes': index_sizes,
        'n_pts_with_derivs': n_pts_with_derivs,
        'dd_flat_indices': dd_flat_indices,
        'n_deriv_types': n_deriv_types,
        # Fused kernel data
        'idx_flat': idx_flat,
        'idx_offsets': idx_offsets,
        'row_offsets': row_offsets,
        'col_offsets': col_offsets,
        'fd_flat_indices': fd_flat_indices,
        'df_flat_indices': df_flat_indices,
    }
def rbf_kernel_fast(phi_exp_3d, plan, out=None):
    """
    Fast kernel assembly using a precomputed plan and the fused numba kernel.

    Parameters
    ----------
    phi_exp_3d : ndarray of shape (n_derivs, n_rows_func, n_cols_func)
        Pre-reshaped expanded derivative array.
    plan : dict
        Precomputed plan from precompute_kernel_plan().
    out : ndarray, optional
        Pre-allocated output array. If None, a new array is allocated.

    Returns
    -------
    K : ndarray
        Full kernel matrix.
    """
    _, n_rows_func, n_cols_func = phi_exp_3d.shape
    size = n_rows_func + plan['n_pts_with_derivs']
    # Reuse the caller's buffer when provided; the fused kernel writes every
    # entry, so np.empty is sufficient for a fresh allocation.
    K = np.empty((size, size)) if out is None else out

    # Plans may carry absolute block offsets; otherwise offsets are stored
    # relative to the function block and shifted here.
    if 'row_offsets_abs' in plan:
        row_off = plan['row_offsets_abs']
        col_off = plan['col_offsets_abs']
    else:
        row_off = plan['row_offsets'] + n_rows_func
        col_off = plan['col_offsets'] + n_cols_func

    _assemble_kernel_numba(
        phi_exp_3d, K, n_rows_func, n_cols_func,
        plan['fd_flat_indices'], plan['df_flat_indices'], plan['dd_flat_indices'],
        plan['idx_flat'], plan['idx_offsets'], plan['index_sizes'],
        plan['signs'], plan['n_deriv_types'], row_off, col_off,
    )

    return K
@profile
def rbf_kernel_predictions(
    phi,
    phi_exp,
    n_order,
    n_bases,
    der_indices,
    powers,
    return_deriv,
    index=-1,
    common_derivs=None,
    calc_cov=False,
    powers_predict=None
):
    """
    Constructs the RBF kernel matrix for predictions with derivative entries.

    This version uses Numba-accelerated functions for efficient matrix slicing.

    Parameters
    ----------
    phi : OTI array
        Base kernel matrix between test and training points.
    phi_exp : ndarray
        Expanded derivative array from phi.get_all_derivs().
    n_order : int
        Maximum derivative order.
    n_bases : int
        Number of OTI bases.
    der_indices : list
        Derivative specifications for training data.
    powers : list of int
        Sign powers for each derivative type.
    return_deriv : bool
        If True, predict derivatives at test points.
    index : list of lists
        Training point indices for each derivative type.
        NOTE(review): the -1 default is handled — non-list index yields zero
        derivative rows via the isinstance checks below.
    common_derivs : list
        Common derivative indices to predict.
    calc_cov : bool
        If True, computing covariance (test-test kernel).
    powers_predict : list of int, optional
        Sign powers for prediction derivatives; falls back to ``powers``.

    Returns
    -------
    K : ndarray
        Prediction kernel matrix.
    """
    # Covariance without derivative predictions needs only the real part.
    if calc_cov and not return_deriv:
        return phi.real

    dh = coti.get_dHelp()

    n_rows_func, n_cols_func = phi.shape
    n_deriv_types = len(der_indices)
    n_deriv_types_pred = len(common_derivs) if common_derivs else 0

    # Pre-compute signs (training side and, separately, prediction side).
    signs = np.array([(-1.0) ** p for p in powers], dtype=np.float64)
    if powers_predict is not None:
        signs_predict = np.array([(-1.0) ** p for p in powers_predict], dtype=np.float64)
    else:
        signs_predict = signs

    if return_deriv:
        # Mixed derivatives of two units need orders up to 2*n_order.
        der_map = deriv_map(n_bases, 2 * n_order)
        # All test points get derivative columns.
        index_2 = np.arange(phi_exp.shape[-1], dtype=np.int64)
        if calc_cov:
            index_cov = np.arange(phi_exp.shape[-1], dtype=np.int64)
            # In covariance mode the row blocks use the prediction deriv set.
            n_deriv_types = n_deriv_types_pred
            # NOTE(review): len([i for i in range(n_cols_func) if i < len(index_2)])
            # is just min(n_cols_func, len(index_2)) — left as-is here.
            n_pts_with_derivs_rows = n_deriv_types * len([i for i in range(n_cols_func) if i < len(index_2)])
        else:
            n_pts_with_derivs_rows = sum(len(order_indices) for order_indices in index) if isinstance(index, list) else 0
    else:
        der_map = deriv_map(n_bases, n_order)
        index_2 = np.array([], dtype=np.int64)
        n_pts_with_derivs_rows = sum(len(order_indices) for order_indices in index) if isinstance(index, list) else 0

    der_indices_tr, der_ind_order = transform_der_indices(der_indices, der_map)
    der_indices_tr_pred, der_ind_order_pred = transform_der_indices(common_derivs, der_map) if common_derivs else ([], [])
    # Same min(...) pattern as above; zero when index_2 is empty.
    n_pts_with_derivs_cols = n_deriv_types_pred * len([i for i in range(n_cols_func) if i < len(index_2)])

    total_rows = n_rows_func + n_pts_with_derivs_rows
    total_cols = n_cols_func + n_pts_with_derivs_cols

    K = np.zeros((total_rows, total_cols))
    base_shape = (n_rows_func, n_cols_func)

    # Convert index lists to int64 numpy arrays so numba kernels accept them.
    if isinstance(index, list) and len(index) > 0:
        index_arrays = [np.asarray(idx, dtype=np.int64) for idx in index]
    else:
        index_arrays = []

    # Block (0,0): Function-Function (K_ff).
    content_full = phi_exp[0].reshape(base_shape)
    K[:n_rows_func, :n_cols_func] = content_full * signs[0]

    if not return_deriv:
        # Only the Derivative-Function block remains (K_df).
        row_offset = n_rows_func
        for i in range(n_deriv_types):
            if not index_arrays:
                break
            row_indices = index_arrays[i]
            n_pts_row = len(row_indices)

            flat_idx = der_indices_tr[i]
            content_full = phi_exp[flat_idx].reshape(base_shape)

            # Numba row gather + signed assignment into K.
            extract_rows_and_assign(content_full, row_indices, K,
                                    row_offset, 0, n_cols_func, signs[0])
            row_offset += n_pts_row
        return K

    # --- return_deriv=True case ---

    # First Block-Row: Function-Derivative (K_fd) over prediction derivs.
    col_offset = n_cols_func
    for j in range(n_deriv_types_pred):
        col_indices = index_2
        n_pts_col = len(col_indices)

        flat_idx = der_indices_tr_pred[j]
        content_full = phi_exp[flat_idx].reshape(base_shape)

        extract_cols_and_assign(content_full, col_indices, K,
                                0, col_offset, n_rows_func, signs_predict[j + 1])
        col_offset += n_pts_col

    # First Block-Column: Derivative-Function (K_df).
    row_offset = n_rows_func
    for i in range(n_deriv_types):
        if calc_cov:
            # Covariance mode: rows come from the prediction deriv set.
            row_indices = index_cov
            flat_idx = der_indices_tr_pred[i]
        else:
            if not index_arrays:
                break
            row_indices = index_arrays[i]
            flat_idx = der_indices_tr[i]
        n_pts_row = len(row_indices)

        content_full = phi_exp[flat_idx].reshape(base_shape)

        extract_rows_and_assign(content_full, row_indices, K,
                                row_offset, 0, n_cols_func, signs[0])
        row_offset += n_pts_row

    # Inner Blocks: Derivative-Derivative (K_dd).
    row_offset = n_rows_func
    for i in range(n_deriv_types):
        if calc_cov:
            row_indices = index_cov
        else:
            if not index_arrays:
                break
            row_indices = index_arrays[i]
        n_pts_row = len(row_indices)

        col_offset = n_cols_func
        for j in range(n_deriv_types_pred):
            col_indices = index_2
            n_pts_col = len(col_indices)

            # Multiply the two derivative directions to find the flat index.
            imdir1 = der_ind_order_pred[j]
            imdir2 = der_ind_order_pred[i] if calc_cov else der_ind_order[i]
            new_idx, new_ord = dh.mult_dir(
                imdir1[0], imdir1[1], imdir2[0], imdir2[1])
            flat_idx = der_map[new_ord][new_idx]

            content_full = phi_exp[flat_idx].reshape(base_shape)

            # Numba submatrix gather + signed assignment (replaces np.ix_).
            extract_and_assign(content_full, row_indices, col_indices, K,
                               row_offset, col_offset, signs_predict[j + 1])
            col_offset += n_pts_col
        row_offset += n_pts_row

    return K
817# =============================================================================
818# Utility functions
819# =============================================================================
def determine_weights(diffs_by_dim, diffs_test, length_scales, kernel_func, sigma_n):
    """
    Vectorized computation of interpolation weights for many test points.

    Solves the augmented (constrained) kriging-style system
        [ K  1 ] [ w  ]   [ r ]
        [ 1' 0 ] [ mu ] = [ 1 ]
    once for all test points simultaneously.

    Parameters
    ----------
    diffs_by_dim : list of ndarray
        Pairwise differences between training points (by dimension).
    diffs_test : list of ndarray
        Pairwise differences between test points and training points
        (by dimension); each array is (n_test, n_train) or similar.
    length_scales : array-like
        Kernel hyperparameters.
    kernel_func : callable
        Kernel function; its result's ``.real`` part is used.
    sigma_n : float
        Noise parameter. NOTE(review): currently unused in this function —
        confirm whether a nugget term on K was intended.

    Returns
    -------
    weights_matrix : ndarray of shape (n_test, n_train)
        Interpolation weights for each test point.
    """
    # Training covariance (shared by every test point).
    K = kernel_func(diffs_by_dim, length_scales).real
    n_train = K.shape[0]

    # Test-train covariances for all test points at once.
    r_all = kernel_func(diffs_test, length_scales).real
    n_test = r_all.shape[0]

    # Augmented system matrix with the sum-to-one constraint row/column.
    ones_col = np.ones((n_train, 1))
    M = np.block([
        [K, ones_col],
        [ones_col.T, np.zeros((1, 1))],
    ])

    # Augmented right-hand sides, one column per test point after transpose.
    rhs = np.hstack([r_all, np.ones((n_test, 1))])

    # Single solve for every test point.
    solution = np.linalg.solve(M, rhs.T)

    # Drop the Lagrange multiplier row; transpose back to (n_test, n_train).
    return solution[:n_train, :].T
def to_tuple(item):
    """Recursively convert nested lists into nested tuples; pass through otherwise."""
    if not isinstance(item, list):
        return item
    return tuple(map(to_tuple, item))
def to_list(x):
    """Recursively convert nested tuples into nested lists; pass through otherwise."""
    if not isinstance(x, tuple):
        return x
    return [to_list(item) for item in x]
def find_common_derivatives(all_indices):
    """
    Find derivative indices common to all submodels.

    Parameters
    ----------
    all_indices : list of list
        One list of derivative specifications per submodel; each
        specification may itself be a (nested) list and is normalized to a
        hashable tuple via ``to_tuple`` before comparison.

    Returns
    -------
    set of tuple
        Specifications present in every submodel's list. Empty set when
        ``all_indices`` is empty (previously this raised IndexError).
    """
    # Guard: no submodels means no common derivatives.
    if not all_indices:
        return set()
    sets = [set(to_tuple(elem) for elem in idx_list) for idx_list in all_indices]
    return sets[0].intersection(*sets[1:])