Commits (51)
da41d09
initialize diffractgpt script
crhysc Oct 28, 2025
ea4bfa5
Merge branch 'main' into feature/RamanGPT
crhysc Oct 28, 2025
bea7d59
initialize correct dataset script
crhysc Oct 28, 2025
5e6633b
initialize code to make train test alpaca jsons
crhysc Oct 29, 2025
a6d7842
version 1. should make sentences
crhysc Oct 29, 2025
fbb0c3a
initialize runner
crhysc Oct 29, 2025
fae7eee
name change
crhysc Oct 29, 2025
b7a9cde
update prompt to include ()
crhysc Oct 29, 2025
ff64e83
add cm^-1 to prompt
crhysc Oct 29, 2025
3b987f8
remove rounding bug
crhysc Oct 31, 2025
2425df9
rest of the script
crhysc Oct 31, 2025
9f0b3c2
freq normal and niggli reduce
crhysc Oct 31, 2025
2771adf
add freq upper bound
crhysc Nov 4, 2025
69c3463
initial commit
crhysc Nov 13, 2025
1d15157
Update loader.py
crhysc Nov 13, 2025
8ff5cae
patch if load_in_4bit
crhysc Nov 19, 2025
1cb88ef
upgrade the required transformers version to 4.57.1
Nov 19, 2025
69226ee
add _get_dtype(dtype) to line 466
crhysc Nov 19, 2025
f6efda5
add gpt_oss to def patch_peft_model()
crhysc Nov 19, 2025
f1bab0e
patch num_logits_to_keep
crhysc Nov 19, 2025
7d3d37b
strip num_logits_to_keep in unsloth_fast_generate() for gpt-oss models
crhysc Nov 20, 2025
c02d6a5
Update gpt_oss.py
crhysc Nov 20, 2025
078175d
patch pre_patch()
crhysc Nov 20, 2025
1be754c
force progress bar
crhysc Nov 21, 2025
85d2f82
add invalid structures error handling
crhysc Oct 2, 2025
9552f4e
add error handling for inverse_predict.py
crhysc Nov 18, 2025
8902e70
num_proc
crhysc Nov 5, 2025
431e3f9
if gen_mat = None: ...
crhysc Nov 18, 2025
3bbd9d2
rm "Here is the output"
crhysc Nov 18, 2025
e3d22e1
terminate string literal
crhysc Nov 18, 2025
d9c94af
let tokenizers be >= 0.22.0
Dec 1, 2025
5b4ef60
hf hub >= 0.32.0
Dec 1, 2025
32e831d
hf-xet>=1.1.2
Dec 1, 2025
006cf16
print target and predicted structures if PRINT_STRUCTURES=1
Dec 4, 2025
786ddac
mv print statements before validation checks
Dec 4, 2025
76602a6
let the raw LLM output be printed if PRINT_STRUCTURES=1
Dec 4, 2025
33e52c8
initialize abs factory for loading. add chat template stubs
crhysc Dec 9, 2025
d0dd793
add kwargs to format()
crhysc Dec 9, 2025
1970a36
get harmony template in factory
crhysc Dec 14, 2025
11b2ab6
from typing import Any
crhysc Dec 14, 2025
52228ed
remove relative import
Dec 14, 2025
1464ce1
import callable
Dec 14, 2025
9793ad8
remove import chattemplate
Dec 14, 2025
3cfc59a
remove imports to non-interface objects for model loading and chat te…
Dec 14, 2025
6e8771a
add type checking if statement for the trainingpropconfig import
Dec 14, 2025
9a41780
add unsloth>=2024.10,<2025.3
Dec 14, 2025
25661fe
arrange imports to debug
Dec 14, 2025
c8790f8
from atomgpt.inverse_models.dataset_utils import make_alpaca_json
Dec 14, 2025
780d334
add imports for make_alpaca_json()
Dec 14, 2025
58bac34
mv get_input() to dataset_utils
Dec 14, 2025
73e3096
rm resume from chkpt=true
Dec 14, 2025
106 changes: 105 additions & 1 deletion atomgpt/inverse_models/dataset_utils.py
@@ -6,6 +6,16 @@

from typing import Union, Callable, Optional, List, Dict
import torch
from typing import Any
from jarvis.core.atoms import Atoms
from jarvis.io.vasp.inputs import Poscar
from jarvis.core.composition import Composition
from atomgpt.inverse_models.utils import (
gen_atoms,
text2atoms,
get_crystal_string_t,
get_figlet,
)


# From https://www.geeksforgeeks.org/longest-common-substring-array-strings/
@@ -753,6 +763,100 @@ def _tokenize(example):
)
pass
return dataset
pass

def get_input(config=None, chem="", val=10):
if config.chem_info == "none":
prefix = ""
elif config.chem_info == "element_list":
prefix = (
"The chemical elements are "
+ chem # atoms.composition.search_string
+ " . "
)
elif config.chem_info == "element_dict":
prefix = (
"The chemical contents are "
+ chem # atoms.composition.search_string
+ " . "
)
elif config.chem_info == "formula":
prefix = (
"The chemical formula is "
+ chem # atoms.composition.reduced_formula
+ " . "
)

    else:
        raise ValueError("Unsupported chem_info: " + str(config.chem_info))
inp = (
prefix
+ "The "
+ config.prop
+ " is "
+ str(val)
+ "."
+ config.output_prompt
)
return inp

def make_alpaca_json(
dataset=[],
jids=[],
# prop="Tc_supercon",
# instruction="",
include_jid=False,
# chem_info="",
# output_prompt="",
config=None,
):
mem = []
print("config.prop", config.prop)
for i in dataset:
if i[config.prop] != "na" and i[config.id_tag] in jids:
atoms = Atoms.from_dict(i["atoms"])
info = {}
if include_jid:
info["id"] = i[config.id_tag]
info["instruction"] = config.instruction
if config.chem_info == "none":
chem = ""
elif config.chem_info == "element_list":
chem = atoms.composition.search_string
elif config.chem_info == "element_dict":
comp = Composition.from_string(
atoms.composition.reduced_formula
)
chem = comp.to_dict()
chem = str(dict(sorted(chem.items())))
elif config.chem_info == "formula":
chem = atoms.composition.reduced_formula

inp = get_input(config=config, val=i[config.prop], chem=chem)
info["input"] = inp

info["output"] = get_crystal_string_t(atoms)
mem.append(info)
return mem

def alpaca_formatting_prompts_func(examples: Dict[str, Any], alpaca_prompt: str, eos_token: str) -> Dict[str, List[str]]:
inst = examples["instruction"]
inp = examples["input"]
out = examples["output"]
texts = [alpaca_prompt.format(i, x, y) + eos_token for i, x, y in zip(inst, inp, out)]
return {"text": texts}

def harmony_formatting_prompts_func(examples: Dict[str, Any], tokenizer) -> Dict[str, List[str]]:
inst = examples["instruction"]
inp = examples["input"]
out = examples["output"]
texts: List[str] = []
for i, x, y in zip(inst, inp, out):
messages = []
i = (i or "").strip()
x = (x or "").strip()
y = (y or "").strip()
if i:
messages.append({"role": "developer", "content": i})
messages.append({"role": "user", "content": x})
messages.append({"role": "assistant", "content": y})
texts.append(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False))
return {"text": texts}
149 changes: 149 additions & 0 deletions atomgpt/inverse_models/factories.py
@@ -0,0 +1,149 @@
# factories.py

from __future__ import annotations

from abc import ABC, abstractmethod
from functools import partial
from typing import TYPE_CHECKING, Callable, Dict, List

from peft import PeftModel

from atomgpt.inverse_models.dataset_utils import (
    alpaca_formatting_prompts_func,
    harmony_formatting_prompts_func,
)
from atomgpt.inverse_models.products import LoadedModel

if TYPE_CHECKING:
    from atomgpt.inverse_models.inverse_models import TrainingPropConfig


class LanguageModelFactory(ABC):
@abstractmethod
def load_for_training(self, config: TrainingPropConfig) -> LoadedModel:
pass

@abstractmethod
def load_for_inference(self, checkpoint_path: str, config: TrainingPropConfig) -> LoadedModel:
pass

@abstractmethod
def get_formatting_prompts_func(self, config, model, tokenizer) -> Callable:
pass


class AtomGPTFactory(LanguageModelFactory):
def load_for_training(self, config: TrainingPropConfig) -> LoadedModel:
from atomgpt.inverse_models.loader import FastLanguageModel as AtomGPTFastLanguageModel
model, tokenizer = AtomGPTFastLanguageModel.from_pretrained(
model_name=config.model_name,
max_seq_length=config.max_seq_length,
dtype=config.dtype,
load_in_4bit=config.load_in_4bit
)
if not isinstance(model, PeftModel):
            print("Not yet a peft model, converting into peft model")
model = AtomGPTFastLanguageModel.get_peft_model(
model,
r=config.lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
lora_alpha=config.lora_alpha,
lora_dropout=0, # Supports any, but = 0 is optimized
bias="none", # Supports any, but = "none" is optimized
use_gradient_checkpointing=True,
random_state=3407,
use_rslora=False, # We support rank stabilized LoRA
loftq_config=None, # And LoftQ
)
print("Peft model created")
return LoadedModel(model=model, tokenizer=tokenizer)

    def load_for_inference(self, checkpoint_path: str, config: TrainingPropConfig) -> LoadedModel:
        from atomgpt.inverse_models.loader import FastLanguageModel as AtomGPTFastLanguageModel  # deferred import, as in load_for_training
        model, tokenizer = AtomGPTFastLanguageModel.from_pretrained(
model_name=checkpoint_path,
max_seq_length=config.max_seq_length,
dtype=config.dtype,
load_in_4bit=config.load_in_4bit,
)
AtomGPTFastLanguageModel.for_inference(model)
return LoadedModel(model=model, tokenizer=tokenizer)

def get_formatting_prompts_func(self, config, model, tokenizer) -> Callable:
eos = tokenizer.eos_token or "</s>"
return partial(alpaca_formatting_prompts_func, alpaca_prompt=config.alpaca_prompt, eos_token=eos)


class GPTOSSFactory(LanguageModelFactory):
def load_for_training(self, config: TrainingPropConfig) -> LoadedModel:
from unsloth import FastLanguageModel as UnslothFastLanguageModel
model, tokenizer = UnslothFastLanguageModel.from_pretrained(
model_name=config.model_name,
max_seq_length=config.max_seq_length,
dtype=config.dtype,
load_in_4bit=config.load_in_4bit,
            full_finetuning=False,
)
if not isinstance(model, PeftModel):
print("Not yet a peft model, converting into peft model")
model = UnslothFastLanguageModel.get_peft_model(
model,
r=config.lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
lora_alpha=config.lora_alpha,
lora_dropout=0, # Supports any, but = 0 is optimized
bias="none", # Supports any, but = "none" is optimized
use_gradient_checkpointing=True,
random_state=3407,
use_rslora=False, # We support rank stabilized LoRA
loftq_config=None, # And LoftQ
)
print("Peft model created")
return LoadedModel(model=model, tokenizer=tokenizer)

    def load_for_inference(self, checkpoint_path: str, config: TrainingPropConfig) -> LoadedModel:
        from unsloth import FastLanguageModel as UnslothFastLanguageModel  # deferred import, as in load_for_training
        model, tokenizer = UnslothFastLanguageModel.from_pretrained(
model_name=checkpoint_path,
max_seq_length=config.max_seq_length,
dtype=config.dtype,
load_in_4bit=config.load_in_4bit,
)
UnslothFastLanguageModel.for_inference(model)
return LoadedModel(model=model, tokenizer=tokenizer)

def get_formatting_prompts_func(self, config, model, tokenizer) -> Callable:
return partial(harmony_formatting_prompts_func, tokenizer=tokenizer)

FACTORY_REGISTRY: Dict[str, type[LanguageModelFactory]] = {
"gemma": AtomGPTFactory,
"qwen": AtomGPTFactory,
"Meta": AtomGPTFactory,
"Llama": AtomGPTFactory,
"llama": AtomGPTFactory,
"Mistral": AtomGPTFactory,
"mistral": AtomGPTFactory,
"gpt-oss": GPTOSSFactory,
}

def get_lm_factory(config: TrainingPropConfig) -> LanguageModelFactory:
model_name = config.model_name
if "gpt-oss" in model_name:
return GPTOSSFactory()
else:
return AtomGPTFactory()
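
For context, a minimal caller-side sketch (not part of this diff) of how a factory would be selected and used, assuming a TrainingPropConfig-like config and a Hugging Face datasets object named train_dataset:

# hypothetical usage sketch; `config` and `train_dataset` are assumed to exist
from atomgpt.inverse_models.factories import get_lm_factory

factory = get_lm_factory(config)            # GPTOSSFactory for "gpt-oss" models, else AtomGPTFactory
loaded = factory.load_for_training(config)  # LoadedModel(model=..., tokenizer=...)
formatting_func = factory.get_formatting_prompts_func(
    config, loaded.model, loaded.tokenizer
)
train_dataset = train_dataset.map(formatting_func, batched=True)  # adds a "text" column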