# Source code for kor.encoders.csv_data

"""Module that contains Kor flavored encoders/decoders for CSV data.

Eventually this module will need to support nested objects, either through
JSON-encoded column values or by flattening nested attributes into additional
columns (likely both approaches).
"""

from io import StringIO
from typing import Any, Dict, List

import pandas as pd

from kor.encoders.typedefs import SchemaBasedEncoder
from kor.encoders.utils import unwrap_tag, wrap_in_tag
from kor.exceptions import ParseError
from kor.nodes import AbstractSchemaNode, Object

# Column separator passed to pandas when writing (``to_csv``) and reading
# (``read_csv``) tables, and quoted to the model in the format instructions.
DELIMITER: str = "|"


def _extract_top_level_fieldnames(node: AbstractSchemaNode) -> List[str]:
    """Return the CSV column names for ``node``.

    An ``Object`` node yields one column per attribute, in attribute order.
    Any other node is treated as a single column named after its own id.
    """
    if not isinstance(node, Object):
        return [node.id]
    return [attribute.id for attribute in node.attributes]


# PUBLIC API


class CSVEncoder(SchemaBasedEncoder):
    """Encode and decode schema-shaped data as a delimited CSV table.

    Only flat schemas are supported: either an ``Object`` node whose attributes
    are neither lists (``many``) nor nested ``Object`` nodes, or a single
    non-object node. Columns are separated by the module-level ``DELIMITER``.
    """

    def __init__(self, node: AbstractSchemaNode, use_tags: bool = False) -> None:
        """Attach node to the encoder to allow the encoder to understand schema.

        Args:
            node: The schema node to attach to the encoder.
            use_tags: Whether to wrap the output in tags. This may help
                identify the table content in cases when the model attempts
                to add clarifying explanations.

        Raises:
            NotImplementedError: If ``node`` is an ``Object`` that has an
                attribute which is a list (``many``) or a nested ``Object``.
        """
        super().__init__(node)
        self.use_tags = use_tags

        # Verify that if we have an Object then none of its attributes are lists
        # or objects as that functionality is not yet supported.
        if isinstance(node, Object):
            for attribute in node.attributes:
                if attribute.many or isinstance(attribute, Object):
                    raise NotImplementedError(
                        "CSV Encoder does not yet support embedded lists or "
                        f"objects (attribute `{attribute.id}`)."
                    )

    def encode(self, data: Any) -> str:
        """Encode the data as a CSV table.

        Args:
            data: A dictionary whose ``self.node.id`` key holds either a single
                record or a list of records to render.

        Returns:
            The CSV text produced by ``DataFrame.to_csv`` (without the index),
            wrapped in ``csv`` tags when ``use_tags`` is set.

        Raises:
            TypeError: If ``data`` is not a dictionary.
            AssertionError: If ``data`` lacks the ``self.node.id`` key.
        """
        if not isinstance(data, dict):
            raise TypeError(f"Was expecting a dictionary got {type(data)}")

        expected_key = self.node.id
        if expected_key not in data:
            raise AssertionError(
                f"Expected a key: `{expected_key}` to appear in data."
            )

        # The helper already handles both Object and non-Object nodes.
        field_names = _extract_top_level_fieldnames(self.node)

        data_to_output = data[expected_key]
        if not isinstance(data_to_output, list):
            # Should always output records for pd.DataFrame
            data_to_output = [data_to_output]

        table_content = pd.DataFrame(data_to_output, columns=field_names).to_csv(
            index=False, sep=DELIMITER
        )

        if self.use_tags:
            return wrap_in_tag("csv", table_content)
        return table_content

    def decode(self, text: str) -> Dict[str, List[Dict[str, Any]]]:
        """Decode CSV text into records keyed by the schema node's id.

        Args:
            text: Model output containing the CSV table, expected inside
                ``csv`` tags when ``use_tags`` is set.

        Returns:
            ``{self.node.id: records}``. Each record maps a column name to its
            cell value, read as a string (``dtype=str`` with
            ``keep_default_na=False``). When no table content is found
            (a falsy or whitespace-only table string), the list is empty.

        Raises:
            ParseError: If pandas fails to parse the table.
        """
        # First get the content between the table tags
        if self.use_tags:
            table_str = unwrap_tag("csv", text)
        else:
            table_str = text

        # Whitespace-only content holds no table. pandas would raise
        # EmptyDataError on it, so treat it the same as empty content.
        if table_str and table_str.strip():
            with StringIO(table_str) as buffer:
                try:
                    df = pd.read_csv(
                        buffer,
                        dtype=str,
                        keep_default_na=False,
                        sep=DELIMITER,
                        skipinitialspace=True,
                    )
                except Exception as e:
                    # Chain the cause so the pandas traceback is preserved.
                    raise ParseError(e) from e
            records = df.to_dict(orient="records")
        else:
            records = []

        namespace = self.node.id
        return {namespace: records}

    def get_instruction_segment(self) -> str:
        """Format instructions telling the model how to emit the CSV table.

        Returns:
            The instruction sentences joined with single spaces.
        """
        instructions = [
            "Please output the extracted information in CSV format in Excel dialect.",
            f"Please use a {DELIMITER} as the delimiter.",
            # TODO(Eugene): Add this when we start supporting embedded columns.
            # "If a column corresponds to an array or an object,
            # use a JSON encoding to "
            # "encode its value.",
        ]

        if self.use_tags:
            instructions.append(
                "Please output a <csv> tag before and a closing </csv> after the table."
            )

        instructions.extend(
            [
                "\n",
                "Do NOT add any clarifying information.",
                "Output MUST follow the schema above.",
                "Do NOT add any additional columns that do not appear in the schema.",
            ]
        )

        return " ".join(instructions)