Skip to content

Commit

Permalink
added energon dataloader for neva training (#10451)
Browse files Browse the repository at this point in the history
* added energon dataloader for neva training

* Apply isort and black reformatting

Signed-off-by: yashaswikarnati <[email protected]>

* specify global batch size to support grad accumulation

* adding neva pretrain example

* Apply isort and black reformatting

Signed-off-by: yashaswikarnati <[email protected]>

* change pretraine example to handle new ckpt reloading

* fixed code quality warnings and unused imports

Signed-off-by: ykarnati <[email protected]>

* minor changes for PR comments

* Apply isort and black reformatting

Signed-off-by: yashaswikarnati <[email protected]>

* refactor conversation template config

* Apply isort and black reformatting

Signed-off-by: yashaswikarnati <[email protected]>

* remove optional import

---------

Signed-off-by: yashaswikarnati <[email protected]>
Signed-off-by: ykarnati <[email protected]>
Co-authored-by: yashaswikarnati <[email protected]>
  • Loading branch information
yashaswikarnati and yashaswikarnati authored Sep 19, 2024
1 parent 3653bed commit 7354740
Show file tree
Hide file tree
Showing 10 changed files with 1,907 additions and 0 deletions.
5 changes: 5 additions & 0 deletions nemo/collections/multimodal/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from nemo.collections.multimodal.data.energon import SimpleMultiModalDataModule

__all__ = ["SimpleMultiModalDataModule"]
40 changes: 40 additions & 0 deletions nemo/collections/multimodal/data/energon/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from nemo.collections.multimodal.data.energon.base import SimpleMultiModalDataModule
from nemo.collections.multimodal.data.energon.config import (
ImageTextSample,
ImageToken,
LLaVATemplateConfig,
MultiModalSampleConfig,
)
from nemo.collections.multimodal.data.energon.sample_encoder import (
BaseSampleEncoder,
InterleavedSampleEncoder,
SimilarityInterleavedEncoder,
VQASampleEncoder,
)

__all__ = [
"SimpleMultiModalDataModule",
"ImageToken",
"ImageTextSample",
"MultiModalSampleConfig",
"LLaVATemplateConfig",
"BaseSampleEncoder",
"VQASampleEncoder",
"InterleavedSampleEncoder",
"SimilarityInterleavedEncoder",
]
367 changes: 367 additions & 0 deletions nemo/collections/multimodal/data/energon/base.py

Large diffs are not rendered by default.

70 changes: 70 additions & 0 deletions nemo/collections/multimodal/data/energon/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import List
import torch
from nemo.collections.multimodal.data.energon.conversation import BaseConversationTemplateConfig


@dataclass
class MultiModalToken:
token_str: str
token_id: int
media_type: str


@dataclass
class ImageToken(MultiModalToken):
token_str: str = "<image>"
token_id: int = -200
media_type: str = "image"


@dataclass
class ImageTextSample:
'''Sample type for template formatted raw image text sample'''

__key__: str = ''
images: torch.Tensor = field(default_factory=lambda: torch.empty(0))
tokens: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.long))
labels: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.long))
loss_mask: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float))


@dataclass
class ImageTextRawBatch:
"""Sample type for image text raw batch"""

__keys__: List[str] = field(default_factory=list)
#: Input images (N, C, H, W)
images: torch.Tensor = field(default_factory=lambda: torch.empty(0))
#: Context string
tokens: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.long))
labels: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.long))
loss_mask: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float))


class LLaVATemplateConfig(BaseConversationTemplateConfig):
"""LLava specific template configuration which extends the base config"""

pass


@dataclass
class MultiModalSampleConfig:
image_token: ImageToken = ImageToken()
ignore_place_holder: int = -100
conversation_template_config: LLaVATemplateConfig = LLaVATemplateConfig()
image_following_text: bool = True
38 changes: 38 additions & 0 deletions nemo/collections/multimodal/data/energon/conversation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class BaseConversationTemplateConfig:
"""Conversation template config related parameters"""

system: Optional[str] = (
"A chat between a curious user and artificial assistant agent. The assistant gives helpful, detailed and polite answers to user's questions.".format()
) # fmt: off
roles: List[str] = field(default_factory=lambda: ['user', 'assistant'])
stop_string: str = "</s>"
chat_template = """
{%- for message in messages %}
{%- if message['role'] == 'system' %}
{{- message['content'].strip() + ' ' -}}
{%- elif message['role'] == 'user' %}
{{- 'USER: ' -}} {{- message['content'].strip() + ' ' -}}
{%- elif message['role'] == 'assistant' %}
{{- 'ASSISTANT: ' -}} {{- message['content'].strip() -}}
{{- '</s>' -}}
{%- endif %}
{%- endfor -%}
"""
Loading

0 comments on commit 7354740

Please sign in to comment.