qml.labs.phox

Phase optimization with JAX (PHOX)
- CircuitConfig: Configuration data for an IQP circuit simulation.
- build_expval_func: Factory that returns a function for computing expectation values.
- Compute expectation value for the Bitflip noise model.
- train: Main training function.
- Generator that yields training results in batches of size 'unroll_steps'.
- TrainingOptions: Configuration options for training.
- Container for final training results.
- Result from a single batch (unrolled chunk) of training steps.
Circuit construction utilities

- create_lattice_gates: Generates gates based on nearest-neighbor interactions on a 2D lattice.
- Generates a gate dictionary for the Phox simulator containing all gates whose generators have Pauli weight less than or equal to max_weight.
- Generates a dictionary of random gates.
- generate_pauli_observables: Generates a batch of Pauli observables.
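The exact objects returned by generate_pauli_observables are not shown on this page. As a rough, standalone illustration of what "a batch of Pauli observables" with orders=[2] and bases=["Z"] could look like, here is a hypothetical helper (pauli_z_strings is not part of phox):

```python
from itertools import combinations

def pauli_z_strings(n_qubits, order):
    """Illustrative only: all Pauli strings with `order` Z operators
    (identity elsewhere) acting on n_qubits wires."""
    strings = []
    for wires in combinations(range(n_qubits), order):
        label = ["I"] * n_qubits
        for w in wires:
            label[w] = "Z"
        strings.append("".join(label))
    return strings

# On a 9-qubit (3x3) lattice, all two-body Z correlators: C(9, 2) = 36 strings.
obs = pauli_z_strings(9, 2)
print(len(obs))  # 36
print(obs[0])    # ZZIIIIIII
```

This mirrors the combinatorial size you should expect: the observables list in the workflow below covers every pair of qubits once.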
Workflow

pennylane.labs.phox provides a compact toolkit for constructing and
simulating phase optimization circuits with JAX. The usual workflow is:

1. Use helpers in pennylane.labs.phox.utils to assemble gates and observables.
2. Configure the circuit with CircuitConfig.
3. Build an expectation-value function with build_expval_func() and evaluate it for different parameter sets.
import jax

from pennylane.labs.phox import (
    CircuitConfig,
    build_expval_func,
    create_lattice_gates,
    generate_pauli_observables,
)

n_rows, n_cols = 3, 3
n_qubits = n_rows * n_cols

gates = create_lattice_gates(n_rows, n_cols, distance=1, max_weight=2)
observables = generate_pauli_observables(n_qubits, orders=[2], bases=["Z"])

key = jax.random.PRNGKey(0)
params = jax.random.uniform(key, shape=(len(gates),))

config = CircuitConfig(
    gates=gates,
    observables=observables,
    n_samples=4000,
    key=key,
    n_qubits=n_qubits,
)

expval_fn = jax.jit(build_expval_func(config))
expvals, std_errs = expval_fn(params)
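The function returns per-observable estimates alongside their standard errors, consistent with a sampling-based simulation controlled by n_samples. As a phox-free sketch of what such an (expvals, std_errs) pair means, assuming ±1-valued measurement outcomes (this is not the phox implementation):

```python
import numpy as np

rng = np.random.default_rng(0)

# Pretend these are n_samples measurement outcomes (+1/-1 eigenvalues)
# for each of 3 observables.
n_samples = 4000
outcomes = rng.choice([-1.0, 1.0], size=(3, n_samples))

# Monte Carlo estimate of each expectation value ...
expvals = outcomes.mean(axis=1)
# ... and the standard error of that estimate, which shrinks as
# 1/sqrt(n_samples).
std_errs = outcomes.std(axis=1, ddof=1) / np.sqrt(n_samples)

print(expvals.shape, std_errs.shape)  # (3,) (3,)
```

Larger n_samples in CircuitConfig therefore trades runtime for tighter error bars on expvals.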
Training

Below is a small training loop that minimizes the sum of all two-body Z
correlators on the same 3x3 lattice. The loss function reuses the
compiled expval_fn from above.
import jax.numpy as jnp

from pennylane.labs.phox import TrainingOptions, train

def loss_fn(current_params):
    expvals, _ = expval_fn(current_params)
    return jnp.sum(expvals)

result = train(
    optimizer="Adam",
    loss=loss_fn,
    stepsize=0.05,
    n_iters=200,
    loss_kwargs={"params": params},
    options=TrainingOptions(unroll_steps=10, random_state=1234),
)

print("Final loss:", float(result.losses[-1]))
print("Optimized parameters:", result.final_params)
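The summary table above also mentions a generator that yields training results in batches of size 'unroll_steps'. As a hypothetical plain-Python sketch of that chunked-training pattern (train_in_chunks is not the phox API; it just illustrates why results arrive unroll_steps at a time):

```python
def train_in_chunks(loss_grad, params, stepsize, n_iters, unroll_steps):
    """Illustrative: run `unroll_steps` gradient-descent updates at a
    time and yield the current params plus the losses from each chunk."""
    for start in range(0, n_iters, unroll_steps):
        chunk_losses = []
        for _ in range(min(unroll_steps, n_iters - start)):
            loss, grad = loss_grad(params)
            params = params - stepsize * grad
            chunk_losses.append(loss)
        yield params, chunk_losses

# Toy loss f(p) = p**2 with its analytic gradient 2p.
loss_grad = lambda p: (p * p, 2 * p)

chunks = list(train_in_chunks(loss_grad, 10.0, 0.1, 20, 5))
print(len(chunks))  # 4 chunks of 5 steps each
final_params, last_losses = chunks[-1]
```

Unrolling several steps per chunk amortizes per-call overhead (in phox's JAX setting, compilation and dispatch) while still letting the caller inspect intermediate losses between chunks.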