Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 72 additions & 26 deletions geos-mesh/src/geos/mesh/utils/arrayHelpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
import numpy as np
import numpy.typing as npt
import pandas as pd # type: ignore[import-untyped]
from typing import Union, Any

import vtkmodules.util.numpy_support as vnp
from typing import Optional, Union, Any
from vtkmodules.util.numpy_support import vtk_to_numpy
from vtkmodules.vtkCommonCore import vtkDataArray, vtkPoints
from vtkmodules.vtkCommonDataModel import ( vtkUnstructuredGrid, vtkFieldData, vtkMultiBlockDataSet, vtkDataSet,
vtkCompositeDataSet, vtkDataObject, vtkPointData, vtkCellData, vtkPolyData,
vtkCell )
from vtkmodules.vtkFiltersCore import vtkCellCenters
from geos.mesh.utils.multiblockHelpers import getBlockElementIndexesFlatten

from geos.mesh.utils.multiblockHelpers import getBlockElementIndexesFlatten
from geos.utils.pieceEnum import Piece

__doc__ = """
Expand Down Expand Up @@ -46,7 +47,7 @@
"""


def getCellDimension( mesh: Union[ vtkMultiBlockDataSet, vtkDataSet ] ) -> set[ int ]:
def getCellDimension( mesh: Union[ vtkMultiBlockDataSet, vtkDataSet ], ) -> set[ int ]:
"""Get the set of the different cells dimension of a mesh.

Args:
Expand Down Expand Up @@ -183,7 +184,10 @@ def computeElementMapping(
return elementMap


def hasArray( mesh: vtkUnstructuredGrid, arrayNames: list[ str ] ) -> bool:
def hasArray(
mesh: vtkUnstructuredGrid,
arrayNames: list[ str ],
) -> bool:
"""Checks if input mesh contains at least one of input data arrays.

Args:
Expand All @@ -196,7 +200,6 @@ def hasArray( mesh: vtkUnstructuredGrid, arrayNames: list[ str ] ) -> bool:
for piece in [ Piece.CELLS, Piece.POINTS, Piece.FIELD ]:
for arrayName in arrayNames:
if isAttributeInObject( mesh, arrayName, piece ):
logging.error( f"The mesh contains the array named '{ arrayName }'." )
return True

return False
Expand Down Expand Up @@ -282,11 +285,16 @@ def checkValidValuesInObject(
return ( validValues, invalidValues )


def getNumpyGlobalIdsArray( data: Union[ vtkCellData, vtkPointData ] ) -> npt.NDArray:
def getNumpyGlobalIdsArray(
data: Union[ vtkCellData, vtkPointData ],
globalIdName: str | None = None,
) -> npt.NDArray:
"""Get a numpy array of the GlobalIds if it exist.

Args:
data (Union[ vtkCellData, vtkPointData ]): Cell or point array.
data (Union[vtkCellData, vtkPointData]): Cell or point array.
globalIdName (str|None, optional): The name of the attribute to consider as the one with the globalIds if it is not the default one.
Default to None.

Returns:
(npt.NDArray): The numpy array of GlobalIds.
Expand All @@ -298,23 +306,36 @@ def getNumpyGlobalIdsArray( data: Union[ vtkCellData, vtkPointData ] ) -> npt.ND
if not isinstance( data, vtkFieldData ):
raise TypeError( f"data '{ data }' entered is not a vtkFieldData object." )

global_ids: Optional[ vtkDataArray ] = data.GetGlobalIds()
if global_ids is None:
raise AttributeError( "There is no GlobalIds in the given fieldData." )
globalIds = data.GetGlobalIds( globalIdName ) # type: ignore[arg-type]

if globalIds is None:
mess: str
if globalIdName is None:
mess = "There is no GlobalIds in the fieldData."
else:
mess = f"The attribute { globalIdName } to consider as GlobalIds is not in the fieldData."
raise AttributeError( mess )

return vtk_to_numpy( global_ids )
return vtk_to_numpy( globalIds )


def getNumpyArrayByName( data: Union[ vtkCellData, vtkPointData ], name: str, sorted: bool = False ) -> npt.NDArray:
def getNumpyArrayByName(
data: Union[ vtkCellData, vtkPointData ],
name: str,
sorted: bool = False,
globalIdName: str | None = None,
) -> npt.NDArray:
"""Get the numpy array of a given vtkDataArray found by its name.

If sorted is selected, this allows the option to reorder the values wrt GlobalIds. If not GlobalIds was found,
no reordering will be perform.
If sorted is selected, this allows the option to reorder the values wrt GlobalIds.

Args:
data (Union[vtkCellData, vtkPointData]): Vtk field data.
name (str): Array name to sort.
sorted (bool, optional): Sort the output array with the help of GlobalIds. Defaults to False.
sorted (bool, optional): Sort the output array with the help of GlobalIds.
Defaults to False.
globalIdName (str|None, optional): The name of the attribute to consider as the one with the globalIds if it is not the default one.
Default to None.

Returns:
npt.NDArray: Sorted array.
Expand All @@ -327,13 +348,16 @@ def getNumpyArrayByName( data: Union[ vtkCellData, vtkPointData ], name: str, so

npArray: npt.NDArray = vtk_to_numpy( data.GetArray( name ) )
if sorted and ( data.IsA( "vtkCellData" ) or data.IsA( "vtkPointData" ) ):
globalids: npt.NDArray = getNumpyGlobalIdsArray( data )
globalids: npt.NDArray = getNumpyGlobalIdsArray( data, globalIdName )
npArray = npArray[ np.argsort( globalids ) ]

return npArray


def getAttributeSet( mesh: Union[ vtkMultiBlockDataSet, vtkDataSet ], piece: Piece ) -> set[ str ]:
def getAttributeSet(
mesh: Union[ vtkMultiBlockDataSet, vtkDataSet ],
piece: Piece,
) -> set[ str ]:
"""Get the set of all attributes from an mesh on points or on cells.

Args:
Expand Down Expand Up @@ -400,7 +424,11 @@ def getAttributesWithNumberOfComponents(
return attributes


def isAttributeInObject( mesh: Union[ vtkMultiBlockDataSet, vtkDataSet ], attributeName: str, piece: Piece ) -> bool:
def isAttributeInObject(
mesh: Union[ vtkMultiBlockDataSet, vtkDataSet ],
attributeName: str,
piece: Piece,
) -> bool:
"""Check if an attribute is in the input mesh for the given piece.

Args:
Expand Down Expand Up @@ -437,7 +465,11 @@ def isAttributeInObject( mesh: Union[ vtkMultiBlockDataSet, vtkDataSet ], attrib
return False


def isAttributeGlobal( multiBlockDataSet: vtkMultiBlockDataSet, attributeName: str, piece: Piece ) -> bool:
def isAttributeGlobal(
multiBlockDataSet: vtkMultiBlockDataSet,
attributeName: str,
piece: Piece,
) -> bool:
"""Check if an attribute is global in the input multiBlockDataSet.

Args:
Expand All @@ -457,7 +489,11 @@ def isAttributeGlobal( multiBlockDataSet: vtkMultiBlockDataSet, attributeName: s
return True


def getArrayInObject( dataSet: vtkDataSet, attributeName: str, piece: Piece ) -> npt.NDArray[ Any ]:
def getArrayInObject(
dataSet: vtkDataSet,
attributeName: str,
piece: Piece,
) -> npt.NDArray[ Any ]:
"""Return the numpy array corresponding to input attribute name in the mesh.

Args:
Expand All @@ -474,7 +510,11 @@ def getArrayInObject( dataSet: vtkDataSet, attributeName: str, piece: Piece ) ->
return npArray


def getVtkArrayTypeInObject( mesh: Union[ vtkDataSet, vtkMultiBlockDataSet ], attributeName: str, piece: Piece ) -> int:
def getVtkArrayTypeInObject(
mesh: Union[ vtkDataSet, vtkMultiBlockDataSet ],
attributeName: str,
piece: Piece,
) -> int:
"""Return VTK type of requested array from input mesh.

Args:
Expand Down Expand Up @@ -506,7 +546,11 @@ def getVtkArrayTypeInObject( mesh: Union[ vtkDataSet, vtkMultiBlockDataSet ], at
raise TypeError( "Input object must be a vtkDataSet or vtkMultiBlockDataSet." )


def getVtkArrayInObject( dataSet: vtkDataSet, attributeName: str, piece: Piece ) -> vtkDataArray:
def getVtkArrayInObject(
dataSet: vtkDataSet,
attributeName: str,
piece: Piece,
) -> vtkDataArray:
"""Return the array corresponding to input attribute name in the mesh.

Args:
Expand Down Expand Up @@ -621,9 +665,11 @@ def getComponentNames(
return tuple( componentNames )


def getAttributeValuesAsDF( surface: vtkPolyData,
attributeNames: tuple[ str, ...],
piece: Piece = Piece.CELLS ) -> pd.DataFrame:
def getAttributeValuesAsDF(
surface: vtkPolyData,
attributeNames: tuple[ str, ...],
piece: Piece = Piece.CELLS,
) -> pd.DataFrame:
"""Get attribute values from input surface.

Args:
Expand Down Expand Up @@ -660,7 +706,7 @@ def getAttributeValuesAsDF( surface: vtkPolyData,
return data


def computeCellCenterCoordinates( mesh: vtkDataSet ) -> vtkDataArray:
def computeCellCenterCoordinates( mesh: vtkDataSet, ) -> vtkDataArray:
"""Get the coordinates of Cell center.

Args:
Expand Down
Loading