Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 85 additions & 7 deletions pyscf/tools/trexio.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
'''

import numpy as np
import math
from pyscf import lib
from pyscf import gto
from pyscf import scf
from pyscf import fci
import trexio

def to_trexio(obj, filename, backend='h5'):
Expand Down Expand Up @@ -91,9 +93,9 @@ def _cc_to_trexio(cc_obj, trexio_file):
def _mcscf_to_trexio(cas_obj, trexio_file):
raise NotImplementedError

def mol_from_trexio(filename, backend='h5'):
def mol_from_trexio(filename):
mol = gto.Mole()
with trexio.File(filename, 'r', back_end=_mode(backend)) as tf:
with trexio.File(filename, 'r', back_end=trexio.TREXIO_AUTO) as tf:
assert trexio.read_basis_type(tf) == 'Gaussian'
if trexio.has_ecp(tf):
raise NotImplementedError
Expand Down Expand Up @@ -130,9 +132,9 @@ def mol_from_trexio(filename, backend='h5'):
mol._basis = basis
return mol.build()

def scf_from_trexio(filename, backend='h5'):
mol = mol_from_trexio(filename, backend)
with trexio.File(filename, 'r', back_end=_mode(backend)) as tf:
def scf_from_trexio(filename):
mol = mol_from_trexio(filename)
with trexio.File(filename, 'r', back_end=trexio.TREXIO_AUTO) as tf:
mo_energy = trexio.read_mo_energy(tf)
mo = trexio.read_mo_coefficient(tf)
mo_occ = trexio.read_mo_occupation(tf)
Expand Down Expand Up @@ -172,9 +174,9 @@ def write_eri(eri, filename, backend='h5'):
with trexio.File(filename, 'w', back_end=_mode(backend)) as tf:
trexio.write_mo_2e_int_eri(tf, 0, num_integrals, idx, eri.ravel())

def read_eri(filename, backend='h5'):
def read_eri(filename):
'''Read ERIs in AO basis, 8-fold symmetry is assumed'''
with trexio.File(filename, 'r', back_end=_mode(backend)) as tf:
with trexio.File(filename, 'r', back_end=trexio.TREXIO_AUTO) as tf:
nmo = trexio.read_mo_num(tf)
nao_pair = nmo * (nmo+1) // 2
eri_size = nao_pair * (nao_pair+1) // 2
Expand Down Expand Up @@ -223,3 +225,79 @@ def _group_by(a, keys):
assert all(keys[:-1] <= keys[1:])
idx = np.unique(keys, return_index=True)[1]
return np.split(a, idx[1:])

def get_occsa_and_occsb(mcscf, norb, nelec, ci_threshold=0.):
ci_coeff = mcscf.ci
num_determinants = int(np.sum(np.abs(ci_coeff) > ci_threshold))
occslst = fci.cistring.gen_occslst(range(norb), nelec // 2)
selected_occslst = occslst[:num_determinants]

occsa = []
occsb = []
ci_values = []

for i in range(min(len(selected_occslst), mcscf.ci.shape[0])):
for j in range(min(len(selected_occslst), mcscf.ci.shape[1])):
ci_coeff = mcscf.ci[i, j]
if np.abs(ci_coeff) > ci_threshold: # Check if CI coefficient is significant compared to user defined value
occsa.append(selected_occslst[i])
occsb.append(selected_occslst[j])
ci_values.append(ci_coeff)

# Sort by the absolute value of the CI coefficients in descending order
sorted_indices = np.argsort(-np.abs(ci_values))
occsa_sorted = [occsa[idx] for idx in sorted_indices]
occsb_sorted = [occsb[idx] for idx in sorted_indices]
ci_values_sorted = [ci_values[idx] for idx in sorted_indices]

return occsa_sorted, occsb_sorted, ci_values_sorted, num_determinants

def det_to_trexio(mcscf, norb, nelec, filename, backend='h5', ci_threshold=0., chunk_size=100000):
from trexio_tools.group_tools import determinant as trexio_det

mo_num = norb
int64_num = int((mo_num - 1) / 64) + 1
occsa, occsb, ci_values, num_determinants = get_occsa_and_occsb(mcscf, norb, nelec, ci_threshold)

det_list = []
for a, b, coeff in zip(occsa, occsb, ci_values):
occsa_upshifted = [orb + 1 for orb in a]
occsb_upshifted = [orb + 1 for orb in b]
det_tmp = []
det_tmp += trexio_det.to_determinant_list(occsa_upshifted, int64_num)
det_tmp += trexio_det.to_determinant_list(occsb_upshifted, int64_num)
det_list.append(det_tmp)

if num_determinants > chunk_size:
n_chunks = math.ceil(num_determinants / chunk_size)
else:
n_chunks = 1

with trexio.File(filename, 'u', back_end=_mode(backend)) as tf:
if trexio.has_determinant(tf):
trexio.delete_determinant(tf)
trexio.write_mo_num(tf, mo_num)
trexio.write_electron_up_num(tf, len(a))
trexio.write_electron_dn_num(tf, len(b))
trexio.write_electron_num(tf, len(a) + len(b))

offset_file = 0
for i in range(n_chunks):
start = i * chunk_size
end = min((i + 1) * chunk_size, num_determinants)
current_chunk_size = end - start

if current_chunk_size > 0:
trexio.write_determinant_list(tf, offset_file, current_chunk_size, det_list[start:end])
trexio.write_determinant_coefficient(tf, offset_file, current_chunk_size, ci_values[start:end])
offset_file += current_chunk_size

def read_det_trexio(filename):
with trexio.File(filename, 'r', back_end=trexio.TREXIO_AUTO) as tf:
offset_file = 0

num_det = trexio.read_determinant_num(tf)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is more of a suggestion (I do not insist for it to be implemented here): If one day num_det becomes huge - the user can run out of memory while reading the CI info. This is why we allow to read/write determinants in buffers (or chunks) of fixed size. It might be a good idea to define some relatively big buffer_size (e.g. 100k) and if num_det is larger than that value - read the data in chunks. Example of writing in chunks (using offset_file to advance the data pointers) can be found in trexio-tutorials.

Copy link
Contributor Author

@NastaMauger NastaMauger Nov 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done
Let me know if it meets your requirement.

coeff = trexio.read_determinant_coefficient(tf, offset_file, num_det)
det = trexio.read_determinant_list(tf, offset_file, num_det)
return num_det, coeff, det