littlelittlecloud
commited on
Commit
·
2cc08ea
1
Parent(s):
7779efa
disposable pattern
Browse files- AutoencoderKL.cs +9 -2
- ClipEnocder.cs +9 -2
- DDPM.cs +7 -1
- Program.cs +10 -5
AutoencoderKL.cs
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
|
|
| 1 |
using TorchSharp;
|
| 2 |
|
| 3 |
-
public class AutoencoderKL
|
| 4 |
{
|
| 5 |
-
private
|
| 6 |
private readonly float _scale;
|
| 7 |
public torch.Device Device {get;}
|
| 8 |
|
|
@@ -21,4 +22,10 @@ public class AutoencoderKL
|
|
| 21 |
tokenTensor = 1.0f / _scale * tokenTensor;
|
| 22 |
return (torch.Tensor)_model.forward(tokenTensor);
|
| 23 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
}
|
|
|
|
| 1 |
+
using System;
|
| 2 |
using TorchSharp;
|
| 3 |
|
| 4 |
+
public class AutoencoderKL : IDisposable
|
| 5 |
{
|
| 6 |
+
private torch.jit.ScriptModule _model;
|
| 7 |
private readonly float _scale;
|
| 8 |
public torch.Device Device {get;}
|
| 9 |
|
|
|
|
| 22 |
tokenTensor = 1.0f / _scale * tokenTensor;
|
| 23 |
return (torch.Tensor)_model.forward(tokenTensor);
|
| 24 |
}
|
| 25 |
+
|
| 26 |
+
public void Dispose()
|
| 27 |
+
{
|
| 28 |
+
_model.Dispose();
|
| 29 |
+
_model = null;
|
| 30 |
+
}
|
| 31 |
}
|
ClipEnocder.cs
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
|
|
| 1 |
using TorchSharp;
|
| 2 |
|
| 3 |
-
public class ClipEncoder
|
| 4 |
{
|
| 5 |
-
private
|
| 6 |
public torch.Device Device {get;}
|
| 7 |
|
| 8 |
public ClipEncoder(string modelPath, torch.Device device)
|
|
@@ -17,4 +18,10 @@ public class ClipEncoder
|
|
| 17 |
{
|
| 18 |
return (torch.Tensor)_model.forward(tokenTensor);
|
| 19 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
}
|
|
|
|
| 1 |
+
using System;
|
| 2 |
using TorchSharp;
|
| 3 |
|
| 4 |
+
public class ClipEncoder : IDisposable
|
| 5 |
{
|
| 6 |
+
private torch.jit.ScriptModule _model;
|
| 7 |
public torch.Device Device {get;}
|
| 8 |
|
| 9 |
public ClipEncoder(string modelPath, torch.Device device)
|
|
|
|
| 18 |
{
|
| 19 |
return (torch.Tensor)_model.forward(tokenTensor);
|
| 20 |
}
|
| 21 |
+
|
| 22 |
+
public void Dispose()
|
| 23 |
+
{
|
| 24 |
+
_model.Dispose();
|
| 25 |
+
_model = null;
|
| 26 |
+
}
|
| 27 |
}
|
DDPM.cs
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
using System;
|
| 2 |
using TorchSharp;
|
| 3 |
|
| 4 |
-
public class DDPM
|
| 5 |
{
|
| 6 |
private readonly torch.jit.ScriptModule _model;
|
| 7 |
public torch.Device Device {get;}
|
|
@@ -36,4 +36,10 @@ public class DDPM
|
|
| 36 |
{
|
| 37 |
return _model.invoke<torch.Tensor>("predict_start_from_z_and_v", z, t, v);
|
| 38 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
}
|
|
|
|
| 1 |
using System;
|
| 2 |
using TorchSharp;
|
| 3 |
|
| 4 |
+
public class DDPM : IDisposable
|
| 5 |
{
|
| 6 |
private readonly torch.jit.ScriptModule _model;
|
| 7 |
public torch.Device Device {get;}
|
|
|
|
| 36 |
{
|
| 37 |
return _model.invoke<torch.Tensor>("predict_start_from_z_and_v", z, t, v);
|
| 38 |
}
|
| 39 |
+
|
| 40 |
+
public void Dispose()
|
| 41 |
+
{
|
| 42 |
+
_model.Dispose();
|
| 43 |
+
_model = null;
|
| 44 |
+
}
|
| 45 |
}
|
Program.cs
CHANGED
|
@@ -6,9 +6,6 @@ using TorchSharp;
|
|
| 6 |
|
| 7 |
torchvision.io.DefaultImager = new torchvision.io.SkiaImager();
|
| 8 |
var device = TorchSharp.torch.device("cuda:0");
|
| 9 |
-
var ddpm = new DDPM("ddim_v_sampler.ckpt", device);
|
| 10 |
-
var ddimSampler = new DDIMSampler(ddpm);
|
| 11 |
-
var autoencoderKL = new AutoencoderKL("autoencoder_kl.ckpt", device);
|
| 12 |
var clipEncoder = new ClipEncoder("clip_encoder.ckpt", device);
|
| 13 |
var start_token = 49406;
|
| 14 |
var end_token = 49407;
|
|
@@ -21,7 +18,7 @@ var dictionary = new Dictionary<string, long>(){
|
|
| 21 |
{"green", 1901},
|
| 22 |
};
|
| 23 |
|
| 24 |
-
var batch =
|
| 25 |
|
| 26 |
var prompt = "a wild cute green cat";
|
| 27 |
var tokens = prompt.Split(' ').Select(x => dictionary[x]).ToList();
|
|
@@ -33,15 +30,23 @@ var tokenTensor = torch.tensor(tokens.ToArray(), dtype: torch.ScalarType.Int64,
|
|
| 33 |
tokenTensor = tokenTensor.repeat(batch, 1);
|
| 34 |
var unconditional_tokenTensor = torch.tensor(uncontional_tokens.ToArray(), dtype: torch.ScalarType.Int64, device: device);
|
| 35 |
unconditional_tokenTensor = unconditional_tokenTensor.repeat(batch, 1);
|
| 36 |
-
var img = torch.randn(batch, 4,
|
| 37 |
var t = torch.full(new[]{batch, 1L}, value: batch, dtype: torch.ScalarType.Int32, device: device);
|
| 38 |
var condition = clipEncoder.Forward(tokenTensor);
|
| 39 |
var unconditional_condition = clipEncoder.Forward(unconditional_tokenTensor);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
var ddim_steps = 50;
|
| 41 |
img = ddimSampler.Sample(img, condition, unconditional_condition, ddim_steps);
|
|
|
|
|
|
|
|
|
|
| 42 |
var decoded_images = (torch.Tensor)autoencoderKL.Forward(img);
|
| 43 |
decoded_images = torch.clamp((decoded_images + 1.0) / 2.0, 0.0, 1.0);
|
| 44 |
|
|
|
|
| 45 |
for(int i = 0; i!= batch; ++i)
|
| 46 |
{
|
| 47 |
var image = decoded_images[i];
|
|
|
|
| 6 |
|
| 7 |
torchvision.io.DefaultImager = new torchvision.io.SkiaImager();
|
| 8 |
var device = TorchSharp.torch.device("cuda:0");
|
|
|
|
|
|
|
|
|
|
| 9 |
var clipEncoder = new ClipEncoder("clip_encoder.ckpt", device);
|
| 10 |
var start_token = 49406;
|
| 11 |
var end_token = 49407;
|
|
|
|
| 18 |
{"green", 1901},
|
| 19 |
};
|
| 20 |
|
| 21 |
+
var batch = 1;
|
| 22 |
|
| 23 |
var prompt = "a wild cute green cat";
|
| 24 |
var tokens = prompt.Split(' ').Select(x => dictionary[x]).ToList();
|
|
|
|
| 30 |
tokenTensor = tokenTensor.repeat(batch, 1);
|
| 31 |
var unconditional_tokenTensor = torch.tensor(uncontional_tokens.ToArray(), dtype: torch.ScalarType.Int64, device: device);
|
| 32 |
unconditional_tokenTensor = unconditional_tokenTensor.repeat(batch, 1);
|
| 33 |
+
var img = torch.randn(batch, 4, 64, 64, dtype: torch.ScalarType.Float32, device: device);
|
| 34 |
var t = torch.full(new[]{batch, 1L}, value: batch, dtype: torch.ScalarType.Int32, device: device);
|
| 35 |
var condition = clipEncoder.Forward(tokenTensor);
|
| 36 |
var unconditional_condition = clipEncoder.Forward(unconditional_tokenTensor);
|
| 37 |
+
|
| 38 |
+
clipEncoder.Dispose();
|
| 39 |
+
var ddpm = new DDPM("ddim_v_sampler.ckpt", device);
|
| 40 |
+
var ddimSampler = new DDIMSampler(ddpm);
|
| 41 |
var ddim_steps = 50;
|
| 42 |
img = ddimSampler.Sample(img, condition, unconditional_condition, ddim_steps);
|
| 43 |
+
ddpm.Dispose();
|
| 44 |
+
|
| 45 |
+
var autoencoderKL = new AutoencoderKL("autoencoder_kl.ckpt", device);
|
| 46 |
var decoded_images = (torch.Tensor)autoencoderKL.Forward(img);
|
| 47 |
decoded_images = torch.clamp((decoded_images + 1.0) / 2.0, 0.0, 1.0);
|
| 48 |
|
| 49 |
+
|
| 50 |
for(int i = 0; i!= batch; ++i)
|
| 51 |
{
|
| 52 |
var image = decoded_images[i];
|