Comments (4)

yaoyaoding commented on July 27, 2024

Hi @eric8607242,

Yes, that is right, and it is a good question.

This is a temporary limitation of our current IR and runtime system. The direct cause is that we do not have an operator like "to_device". We currently do not have a C++ runtime; instead, we rely on CUDA graphs to get rid of the framework-level overhead. It is not trivial to track both CPU kernels and GPU kernels in the same CUDA graph. So, until we have an efficient C++ runtime, we will not support mixing CPU and GPU kernels in a single computation graph.

Of course, if some important DNNs rely on this feature, we would give it a higher priority. Currently, we are focusing on dynamic shape support.

from hidet.

yaoyaoding commented on July 27, 2024

Hi @eric8607242,

Thanks for bringing this up. We have partially fixed this issue in #214. With this PR, we can run your example:

import torch
from torch import nn
import hidet


class TestMode(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv = nn.Linear(10, 10)

    def forward(self, x):
        # x is already on CUDA here, so .to("cuda") is a same-device cast
        z = x.unsqueeze(0).expand(4, 4, 512).to(torch.device("cuda"))
        return z

if __name__ == "__main__":
    model = TestMode()
    model = model.eval().half()
    device = torch.device("cpu")
    model = model.to(device)
    hidet.torch.dynamo_config.search_space(2)
    hidet.torch.dynamo_config.use_fp16()
    model_opt = torch.compile(model, backend='hidet')

    # The input lives on the GPU while the model weights stay on the CPU
    tokens = torch.zeros(4, 512).cuda()
    model_opt(tokens)

The limitation is this: a tensor that depends on the model input (e.g., x.unsqueeze(0).expand(4, 4, 512) in your example) can only be cast to the device it is already on, using .cuda(), .cpu(), or .to(device=...). Weight tensors do not have this limitation.
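To make the rule above concrete, here is a minimal pure-Python sketch of it. This is not hidet code: TensorInfo and is_cast_supported are hypothetical names invented for illustration, and devices are assumed to be plain strings such as "cpu" and "cuda" with no device index.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorInfo:
    """Hypothetical stand-in for a tensor seen while tracing (not a hidet class)."""
    device: str             # assumed to be "cpu" or "cuda"
    depends_on_input: bool  # True if computed from a model input, False for weights


def is_cast_supported(tensor: TensorInfo, target_device: str) -> bool:
    """Return True if casting `tensor` to `target_device` fits the rule described above."""
    if not tensor.depends_on_input:
        # Weight tensors are not subject to the limitation.
        return True
    # Input-dependent tensors may only be cast to the device they are already on.
    return tensor.device == target_device


cuda_activation = TensorInfo(device="cuda", depends_on_input=True)
cpu_activation = TensorInfo(device="cpu", depends_on_input=True)
cpu_weight = TensorInfo(device="cpu", depends_on_input=False)

print(is_cast_supported(cuda_activation, "cuda"))  # True: same-device cast
print(is_cast_supported(cpu_activation, "cuda"))   # False: cross-device cast of an input-dependent tensor
print(is_cast_supported(cpu_weight, "cuda"))       # True: weights are exempt
```

In the example script, tokens is created on CUDA, so the .to(torch.device("cuda")) call in forward falls into the first case.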

See the tests for more examples of what is and is not supported.

eric8607242 commented on July 27, 2024

Hi @yaoyaoding,

Thanks for your kind response and quick fix. It is very helpful.

Sorry, I have two more silly questions.
Do you mean that if a model input is on the CPU, we cannot cast it to CUDA with .cuda() or .to(torch.device("cuda"))?
Why is there such a limitation? Many thanks for your help.

eric8607242 commented on July 27, 2024

Hi @yaoyaoding,

Thanks for the very clear answer.
I have no more questions and the issue is also solved.

Thanks again for this amazing work.
Closing the issue.
