57 commits
e059dd8
Tomo base plotting fixes
Sep 16, 2025
993ba86
Semi-working NeRF implementation; bug where it's not converging corre…
Sep 17, 2025
66502b2
NeRF kinda working; bugs with convergence
Sep 17, 2025
bcb7d1d
TV loss squeeze axis
Sep 17, 2025
0bb881b
NeRF works, barebones. Start writing into quantem style
Sep 18, 2025
a6a7014
Pre-transferring to quantem code
Sep 22, 2025
e162dde
Baseline working Nerf reconstructions; created a TomoDDP class
Sep 23, 2025
bb40c4b
Tomo optimizers + schedulers working; need to do object_models
Sep 23, 2025
c704ce3
Working DDP with objects and stuff
Sep 23, 2025
1960f35
Soft constraints also work
Sep 23, 2025
b0ccd8c
Multi-step training with different schedulers kind of working. Need t…
Sep 23, 2025
f9a8afa
Implemented ObjectINN with create volume in it (create_volume new ver…
Sep 23, 2025
f3b5b86
Tomo-NeRF working fully; just need to clean up a little bit
Sep 23, 2025
9c0685e
Merge pull request #103 from electronmicroscopy/tomography-hpc-2
cedriclim1 Sep 23, 2025
867bb8e
Initial commit for background subtraction
cophus Oct 7, 2025
0d09a39
Updating plots for background subtraction
cophus Oct 7, 2025
f34b9e1
Adding function overloading
cophus Oct 7, 2025
9f9ef88
More conservative fitting order
cophus Oct 7, 2025
da10551
Switching from @overload to TypeVar
cophus Oct 10, 2025
bc46375
Merge pull request #109 from cophus/tomo-background
cedriclim1 Oct 14, 2025
f1dc677
Tomo changes
Oct 14, 2025
0136682
Merge branch 'dev' into tomography
Oct 14, 2025
a57d885
Enforcing positivity optional in tomography_dataset with clamp flag
cedriclim1 Oct 16, 2025
078f351
Added SIREN, HSIREN, and Finer with allowed complex inputs
Oct 14, 2025
d83deb1
Added FinerActivation to get_activation_function
Oct 14, 2025
4227e58
Fixed PtychoLite to call from cnn.py instead of cnn2d (now removed)
Oct 14, 2025
2e70de3
Siren + HSiren need np.sqrt instead of torch.sqrt for .uniform_
Oct 14, 2025
9bbb54e
net_list fix on Siren models
Oct 14, 2025
0aec975
Softplus missing from self.net_list
Oct 14, 2025
35eae44
Print pred on cross_Correlation align stack
Oct 23, 2025
0a99af5
Adding back in validation set
Nov 12, 2025
3971489
Added pretraining functionality, in core ML added custom loss functi…
Nov 21, 2025
69efcfd
Updates
Dec 6, 2025
9937bc5
Test
cedriclim1 Dec 15, 2025
700f205
Outlining dataset_models.py, and the top level Tomography class
cedriclim1 Dec 16, 2025
b90499d
Need to think about tomography_ddp a little bit more, also what shoul…
cedriclim1 Dec 17, 2025
d73c498
SIRT Reconstructions working. TomographyConventional looks a little c…
cedriclim1 Dec 19, 2025
bc98f4b
Implemented tomography_opt.py
cedriclim1 Dec 19, 2025
0fa3740
Starting to write the ML methods for Tomography; Need to figure out h…
cedriclim1 Dec 19, 2025
31bc2eb
Starting DDP stuff
cedriclim1 Jan 20, 2026
acb998e
Pulling from dev
cedriclim1 Jan 20, 2026
70ffb31
Implementing pretraining for object_models, along with added function…
cedriclim1 Jan 20, 2026
0ed758a
Object pretraining INR working
cedriclim1 Jan 20, 2026
c90984c
DDPMixin in ML, pretraining working for objects
cedriclim1 Jan 21, 2026
97c80d6
Added cosine annealing to set_scheduler in OptimizerMixin
cedriclim1 Jan 21, 2026
9900fc3
Reworking TomographyINRDatasets; need to figure out what to do for au…
cedriclim1 Jan 22, 2026
4bda57f
Some device switching bugs that need to be addressed.
cedriclim1 Jan 23, 2026
cc9188f
Working reconstruction loop, need to figure out this device stuff and…
cedriclim1 Jan 23, 2026
d08c334
Various updates
cedriclim1 Jan 26, 2026
9a37e05
Logger implementation
cedriclim1 Jan 27, 2026
cfee59a
DDP bug where some projection idx's don't get optimized.
Jan 27, 2026
370934a
DDP projection indices fixed; added hard constraints to the forward m…
Jan 28, 2026
b80cd68
NVIDIA Profiling testing added in the reconstruction loop in tomograp…
cedriclim1 Jan 28, 2026
a61cce3
Starting profiling of the reconstruction loop; need to move stuff ove…
cedriclim1 Jan 29, 2026
8f6cb02
Val + train test split implemented - cuBLAS error after adding this n…
cedriclim1 Jan 30, 2026
791a644
Small updates
cedriclim1 Jan 30, 2026
4b7f108
Implemented a working TomographyLite, need to test AutoSerialize, and…
cedriclim1 Jan 31, 2026
95 changes: 95 additions & 0 deletions src/quantem/core/ml/constraints.py
@@ -0,0 +1,95 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, ClassVar, Self

import numpy as np
import torch


@dataclass(slots=True)
class Constraints(ABC):
"""
Container for constraint settings. A subclass must be defined for every object model that inherits from BaseConstraints.
"""

# class-level key lists, overridden by subclasses (not dataclass fields)
soft_constraint_keys: ClassVar[list[str]] = []
hard_constraint_keys: ClassVar[list[str]] = []

@property
def allowed_keys(self) -> list[str]:
"""
List of all allowed keys.
"""
return self.hard_constraint_keys + self.soft_constraint_keys

def copy(self) -> Self:
"""
Copy the constraints.
"""
return deepcopy(self)

def __str__(self) -> str:
    # join with an indented newline (avoids backslashes inside f-string expressions,
    # which are a syntax error before Python 3.12)
    sep = "\n    "
    hard = sep.join(f"{key}: {getattr(self, key)}" for key in self.hard_constraint_keys)
    soft = sep.join(f"{key}: {getattr(self, key)}" for key in self.soft_constraint_keys)

    return (
        "Constraints:\n"
        "  Hard constraints:\n"
        f"    {hard}\n"
        "  Soft constraints:\n"
        f"    {soft}"
    )


class BaseConstraints(ABC):
"""
Base class for constraints.
"""

# Subclasses override this with an instance of their own Constraints dataclass.
DEFAULT_CONSTRAINTS = Constraints()

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._soft_constraint_losses = []
self.constraints = self.DEFAULT_CONSTRAINTS.copy()

@property
def soft_constraint_losses(self) -> np.ndarray:
    """History of recorded soft-constraint losses, as an array."""
    return np.array(self._soft_constraint_losses)

@property
def constraints(self) -> Constraints:
"""
Constraints for the object model.
"""
return self._constraints

@constraints.setter
def constraints(self, constraints: Constraints | dict[str, Any]):
"""
Setter for constraints class, can be a Constraints instance or a dictionary.
"""
if isinstance(constraints, Constraints):
self._constraints = constraints
elif isinstance(constraints, dict):
for key, value in constraints.items():
setattr(self._constraints, key, value)
else:
raise ValueError(f"Invalid constraints type: {type(constraints)}")

# --- Required methods that need to be implemented in subclasses ---
@abstractmethod
def apply_hard_constraints(self, *args, **kwargs) -> torch.Tensor:
"""
Apply hard constraints to the object model.
"""
raise NotImplementedError

@abstractmethod
def apply_soft_constraints(self, *args, **kwargs) -> torch.Tensor:
"""
Apply soft constraints to the object model.
"""
raise NotImplementedError
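
For context, here is a minimal, hypothetical sketch (not part of this PR) of how the new Constraints dataclass and BaseConstraints ABC are meant to be subclassed together; the class and field names (VolumeConstraints, positivity, tv_weight) and the import path are illustrative assumptions.

from dataclasses import dataclass
from typing import ClassVar

import torch

from quantem.core.ml.constraints import BaseConstraints, Constraints  # path assumed from this diff


@dataclass(slots=True)
class VolumeConstraints(Constraints):
    # illustrative constraint fields
    positivity: bool = True   # hard constraint: clamp negative densities
    tv_weight: float = 0.0    # soft constraint: total-variation penalty weight

    hard_constraint_keys: ClassVar[list[str]] = ["positivity"]
    soft_constraint_keys: ClassVar[list[str]] = ["tv_weight"]


class VolumeModel(BaseConstraints):
    DEFAULT_CONSTRAINTS = VolumeConstraints()

    def apply_hard_constraints(self, volume: torch.Tensor) -> torch.Tensor:
        # hard constraints modify the tensor directly
        if self.constraints.positivity:
            volume = torch.clamp(volume, min=0.0)
        return volume

    def apply_soft_constraints(self, volume: torch.Tensor) -> torch.Tensor:
        # soft constraints return a differentiable penalty added to the training loss
        tv = sum(volume.diff(dim=d).abs().mean() for d in range(volume.ndim))
        loss = self.constraints.tv_weight * tv
        self._soft_constraint_losses.append(float(loss))
        return loss


model = VolumeModel()
model.constraints = {"tv_weight": 1e-3}  # the dict setter updates individual fields
print(model.constraints)

The dict form of the setter only touches the named fields, while assigning a Constraints instance replaces the whole container.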
171 changes: 171 additions & 0 deletions src/quantem/core/ml/ddp.py
@@ -0,0 +1,171 @@
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, DistributedSampler, random_split


class DDPMixin:
"""
Mixin providing distributed (DDP), single-GPU, and CPU training setup: process-group initialization, dataloader/sampler construction, and model wrapping.
"""

def __init__(
self,
):
self.setup_distributed()

def setup_distributed(self, device: str | None = None):
"""
Initializes the process group and device depending on whether training is multi-GPU, single-GPU, or CPU.
"""
if "RANK" in os.environ:
if not dist.is_initialized():
dist.init_process_group(
backend="nccl" if torch.cuda.is_available() else "gloo", init_method="env://"
)

self.world_size = dist.get_world_size()
self.global_rank = dist.get_rank()
self.local_rank = int(os.environ["LOCAL_RANK"])

torch.cuda.set_device(self.local_rank)
device = torch.device("cuda", self.local_rank)
else:
self.world_size = 1
self.global_rank = 0
self.local_rank = 0

if torch.cuda.is_available():
device = torch.device("cuda:0" if device is None else device)
torch.cuda.set_device(device.index)
print("Single GPU training")
else:
device = torch.device("cpu")
print("CPU training")

if device.type == "cuda":
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

self.device = device

def setup_dataloader(
self,
dataset: Dataset,
batch_size: int,
num_workers: int = 0,
val_fraction: float = 0.0,
):
pin_mem = self.device.type == "cuda"
persist = num_workers > 0

if val_fraction > 0.0:
train_dataset, val_dataset = random_split(dataset, [1 - val_fraction, val_fraction])
else:
train_dataset = dataset
val_dataset = None

if self.world_size > 1:
    train_sampler = DistributedSampler(
        train_dataset,
        num_replicas=self.world_size,
        rank=self.global_rank,
        shuffle=True,
    )

    if val_dataset:
        val_sampler = DistributedSampler(
            val_dataset,
            num_replicas=self.world_size,
            rank=self.global_rank,
            shuffle=False,
        )
    else:
        val_sampler = None

    # the samplers handle shuffling, so the DataLoaders must not shuffle themselves
    shuffle = False
else:
    train_sampler = None
    val_sampler = None
    shuffle = True

train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
num_workers=num_workers,
sampler=train_sampler,
shuffle=shuffle,
pin_memory=pin_mem,
drop_last=True,
persistent_workers=persist,
)

if val_dataset:
val_dataloader = DataLoader(
val_dataset,
batch_size=batch_size * 4,
num_workers=num_workers,
sampler=val_sampler,
shuffle=False,
pin_memory=pin_mem,
drop_last=False,
persistent_workers=persist,
)
else:
val_dataloader = None

if self.global_rank == 0:
print("Dataloader setup complete:")
print(f" Total train samples: {len(train_dataset)}")
print(f" Local batch size: {batch_size}")
print(f" Global batch size: {batch_size * self.world_size}")
print(f" Train batches per GPU per epoch: {len(train_dataloader)}")

if val_dataset:
print(f" Total val samples: {len(val_dataset)}")
print(f" Val batches per GPU per epoch: {len(val_dataloader)}")

return train_dataloader, train_sampler, val_dataloader, val_sampler

def build_model(
self,
model: nn.Module,
pretrained_weights: dict[str, torch.Tensor] | None = None,
) -> nn.Module | nn.parallel.DistributedDataParallel:
"""
Wraps the model with DistributedDataParallel if multiple GPUs are available.

Returns the model.
"""

model = model.to(self.device)
if pretrained_weights is not None:
model.load_state_dict(pretrained_weights.copy())

if self.world_size > 1:
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[self.local_rank],
        output_device=self.local_rank,
        find_unused_parameters=False,
        broadcast_buffers=True,
        bucket_cap_mb=100,
        gradient_as_bucket_view=True,
    )
    if self.global_rank == 0:
        print("Model built, distributed, and wrapped with DDP")
else:
    print("Model built successfully")
return model
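
As a rough usage sketch (not part of this PR, with illustrative names such as ToyTrainer), a reconstruction class might combine the three DDPMixin helpers like this; the import path is assumed from this diff.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

from quantem.core.ml.ddp import DDPMixin  # path assumed from this diff


class ToyTrainer(DDPMixin):
    def __init__(self, model: nn.Module, dataset, batch_size: int = 32):
        super().__init__()  # runs setup_distributed(): sets self.device, ranks, world_size
        self.model = self.build_model(model)
        (
            self.train_loader,
            self.train_sampler,
            self.val_loader,
            self.val_sampler,
        ) = self.setup_dataloader(dataset, batch_size=batch_size, val_fraction=0.1)

    def train_one_epoch(self, optimizer: torch.optim.Optimizer, epoch: int = 0):
        if self.train_sampler is not None:
            self.train_sampler.set_epoch(epoch)  # reshuffle across ranks each epoch
        for x, y in self.train_loader:
            x, y = x.to(self.device), y.to(self.device)
            optimizer.zero_grad()
            loss = nn.functional.mse_loss(self.model(x), y)
            loss.backward()  # DDP synchronizes gradients here when world_size > 1
            optimizer.step()


# single-process run; multi-GPU runs would be launched with e.g. `torchrun --nproc_per_node=4 ...`
dataset = TensorDataset(torch.randn(256, 3), torch.randn(256, 1))
trainer = ToyTrainer(nn.Linear(3, 1), dataset)
trainer.train_one_epoch(torch.optim.Adam(trainer.model.parameters(), lr=1e-3))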
14 changes: 13 additions & 1 deletion src/quantem/core/ml/inr.py
@@ -23,6 +23,7 @@ def __init__(
hsiren: bool = False,
dtype: torch.dtype = torch.float32,
final_activation: str | Callable = "identity",
winner_initialization: bool = False,
) -> None:
"""Initialize Siren.

@@ -59,7 +60,7 @@ def __init__(
self.alpha = alpha
self.hsiren = hsiren
self.dtype = dtype

self.winner_initialization = winner_initialization
self.final_activation = final_activation

self._build()
@@ -109,6 +110,15 @@ def _build(self) -> None:
net_list.append(self._final_activation)
self.net = nn.Sequential(*net_list)

if self.winner_initialization:
    # perturb the first two sine layers' weights, scaled inversely by their omega_0
    with torch.no_grad():
        self.net[0].linear.weight += (
            torch.randn_like(self.net[0].linear.weight) * 5 / self.first_omega_0
        )
        self.net[1].linear.weight += (
            torch.randn_like(self.net[1].linear.weight) * 0.1 / self.hidden_omega_0
        )

def forward(self, coords: torch.Tensor) -> torch.Tensor:
output = self.net(coords)
return output
@@ -182,6 +192,7 @@ def __init__(
alpha: float = 1.0,
dtype: torch.dtype = torch.float32,
final_activation: str | Callable = "identity",
winner_initialization: bool = False,
) -> None:
"""Initialize HSiren.

@@ -217,4 +228,5 @@
hsiren=True,
dtype=dtype,
final_activation=final_activation,
winner_initialization=winner_initialization,
)
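
A small, hypothetical sketch of enabling the new winner_initialization flag when constructing a Siren; everything except winner_initialization (the other argument names, the import path, and the coordinate convention) is an assumption, since the rest of the signature is collapsed above.

import torch

from quantem.core.ml.inr import Siren  # path assumed from this diff

net = Siren(
    in_features=3,               # e.g. (x, y, z) volume coordinates (name assumed)
    hidden_features=256,         # assumed
    hidden_layers=4,             # assumed
    out_features=1,              # assumed
    winner_initialization=True,  # perturbs the first two sine layers' weights at build time
)
coords = 2 * torch.rand(1024, 3) - 1  # SIREN inputs are typically normalized to [-1, 1]
values = net(coords)                  # shape (1024, 1)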