diff --git a/jdaviz/configs/cubeviz/plugins/parsers.py b/jdaviz/configs/cubeviz/plugins/parsers.py
index 6e8f58c938..502965531f 100644
--- a/jdaviz/configs/cubeviz/plugins/parsers.py
+++ b/jdaviz/configs/cubeviz/plugins/parsers.py
@@ -8,13 +8,13 @@
 from astropy.nddata import StdDevUncertainty
 from astropy.time import Time
 from astropy.wcs import WCS
-from specutils import Spectrum1D, SpectralAxis
+from specutils import Spectrum1D
 
 from jdaviz.core.custom_units import PIX2
 from jdaviz.core.registries import data_parser_registry
 from jdaviz.core.validunits import check_if_unit_is_per_solid_angle
-from jdaviz.utils import standardize_metadata, PRIHDR_KEY, download_uri_to_path
-
+from jdaviz.utils import (standardize_metadata, PRIHDR_KEY, download_uri_to_path,
+                          _eqv_flux_to_sb_pixel)
 
 __all__ = ['parse_data']
 
@@ -188,23 +188,22 @@ def _return_spectrum_with_correct_units(flux, wcs, metadata, data_type=None,
     Also converts flux units to flux/pix2 solid angle units, if `flux` is not a surface
     brightness and `apply_pix2` is True.
     """
+    # handle scale factors when they are included in the unit
+    # (has to be done before Spectrum1D creation)
+    if not np.isclose(flux.unit.scale, 1, rtol=1e-5):
+        flux = flux.to(flux.unit / flux.unit.scale)
+
     with warnings.catch_warnings():
         warnings.filterwarnings(
             'ignore', message='Input WCS indicates that the spectral axis is not last',
             category=UserWarning)
+        sc = Spectrum1D(flux=flux, wcs=wcs, meta=metadata, uncertainty=uncertainty, mask=mask)
 
-        # convert flux and uncertainty to per-pix2 if input is not a surface brightness
-        if apply_pix2:
-            if not check_if_unit_is_per_solid_angle(flux.unit):
-                flux = flux / PIX2
-                if uncertainty is not None:
-                    uncertainty = uncertainty / PIX2
-
-        # handle scale factors when they are included in the unit
-        if not np.isclose(flux.unit.scale, 1.0, rtol=1e-5):
-            flux = flux.to(flux.unit / flux.unit.scale)
-
-        sc = Spectrum1D(flux=flux, wcs=wcs, uncertainty=uncertainty, mask=mask)
+    # convert flux and uncertainty to per-pix2 if input is not a surface brightness
+    target_flux_unit = None
+    if (apply_pix2 and (data_type != "mask") and
+            (not check_if_unit_is_per_solid_angle(flux.unit))):
+        target_flux_unit = flux.unit / PIX2
 
     if target_wave_unit is None and hdulist is not None:
         found_target = False
@@ -223,23 +222,22 @@ def _return_spectrum_with_correct_units(flux, wcs, metadata, data_type=None,
                     found_target = True
                     break
 
-    if (data_type == 'flux' and target_wave_unit is not None
-            and target_wave_unit != sc.spectral_axis.unit):
-        metadata['_orig_spec'] = sc
-        with warnings.catch_warnings():
-            warnings.filterwarnings(
-                'ignore', message='Input WCS indicates that the spectral axis is not last',
-                category=UserWarning)
-            new_sc = Spectrum1D(
-                flux=sc.flux,
-                spectral_axis=sc.spectral_axis.to(target_wave_unit, u.spectral()),
-                meta=metadata,
-                uncertainty=sc.uncertainty,
-                mask=sc.mask
-            )
-    else:
-        sc.meta = metadata
+    if target_wave_unit == sc.spectral_axis.unit:
+        target_wave_unit = None
+
+    if (target_wave_unit is None) and (target_flux_unit is None):  # Nothing to convert
         new_sc = sc
+    elif target_flux_unit is None:  # Convert wavelength only
+        new_sc = sc.with_spectral_axis_unit(target_wave_unit)
+    elif target_wave_unit is None:  # Convert flux only and only PIX2 stuff
+        new_sc = sc.with_flux_unit(target_flux_unit, equivalencies=_eqv_flux_to_sb_pixel())
+    else:  # Convert both
+        new_sc = sc.with_spectral_axis_and_flux_units(
+            target_wave_unit, target_flux_unit, flux_equivalencies=_eqv_flux_to_sb_pixel())
+
+    if target_wave_unit is not None:
+        new_sc.meta['_orig_spec'] = sc  # Need this for later
+
     return new_sc
 
 
@@ -300,7 +298,7 @@ def _parse_hdulist(app, hdulist, file_name=None,
         metadata['_orig_spatial_wcs'] = _get_celestial_wcs(wcs)
 
         apply_pix2 = data_type in ['flux', 'uncert']
-        sc = _return_spectrum_with_correct_units(flux, wcs, metadata, data_type,
+        sc = _return_spectrum_with_correct_units(flux, wcs, metadata, data_type=data_type,
                                                  hdulist=hdulist, apply_pix2=apply_pix2)
 
         app.add_data(sc, data_label)
@@ -358,7 +356,8 @@ def _parse_jwst_s3d(app, hdulist, data_label, ext='SCI',
     if hdu.name != 'PRIMARY' and 'PRIMARY' in hdulist:
         metadata[PRIHDR_KEY] = standardize_metadata(hdulist['PRIMARY'].header)
 
-    data = _return_spectrum_with_correct_units(flux, wcs, metadata, data_type, hdulist=hdulist)
+    data = _return_spectrum_with_correct_units(
+        flux, wcs, metadata, data_type=data_type, hdulist=hdulist)
     app.add_data(data, data_label, parent=parent)
 
     # get glue data and update if DQ:
@@ -418,7 +417,8 @@ def _parse_esa_s3d(app, hdulist, data_label, ext='DATA', flux_viewer_reference_n
     # to sky regions, where the parent data of the subset might have dropped spatial WCS info
     metadata['_orig_spatial_wcs'] = _get_celestial_wcs(wcs)
 
-    data = _return_spectrum_with_correct_units(flux, wcs, metadata, data_type, hdulist=hdulist)
+    data = _return_spectrum_with_correct_units(
+        flux, wcs, metadata, data_type=data_type, hdulist=hdulist)
 
     app.add_data(data, data_label)
 
@@ -466,12 +466,10 @@ def _parse_spectrum1d_3d(app, file_obj, data_label=None,
             if hasattr(file_obj, 'wcs'):
                 meta['_orig_spatial_wcs'] = _get_celestial_wcs(file_obj.wcs)
 
-            s1d = _return_spectrum_with_correct_units(flux, wcs=file_obj.wcs, metadata=meta)
-
-            # convert data loaded in flux units to a per-square-pixel surface
+            # Also convert data loaded in flux units to a per-square-pixel surface
             # brightness unit (e.g Jy to Jy/pix**2)
-            if (attr != "mask") and (not check_if_unit_is_per_solid_angle(flux.unit)):
-                s1d = convert_spectrum1d_from_flux_to_flux_per_pixel(s1d)
+            s1d = _return_spectrum_with_correct_units(
+                flux, file_obj.wcs, meta, data_type=attr, apply_pix2=True)
 
         cur_data_label = app.return_data_label(data_label, attr.upper())
         app.add_data(s1d, cur_data_label)
@@ -502,7 +500,8 @@ def _parse_spectrum1d(app, file_obj, data_label=None, spectrum_viewer_reference_
     # convert data loaded in flux units to a per-square-pixel surface
     # brightness unit (e.g Jy to Jy/pix**2)
     if not check_if_unit_is_per_solid_angle(file_obj.flux.unit):
-        file_obj = convert_spectrum1d_from_flux_to_flux_per_pixel(file_obj)
+        file_obj = file_obj.with_flux_unit(
+            file_obj.flux.unit / PIX2, equivalencies=_eqv_flux_to_sb_pixel())
 
     app.add_data(file_obj, data_label)
     app.add_data_to_viewer(spectrum_viewer_reference_name, data_label)
@@ -522,15 +521,15 @@ def _parse_ndarray(app, file_obj, data_label=None, data_type=None,
     flux = file_obj
 
     if not hasattr(flux, 'unit'):
-        flux = flux << u.count
+        flux = flux << (u.count / PIX2)
 
     meta = standardize_metadata({'_orig_spatial_wcs': None})
     s3d = Spectrum1D(flux=flux, meta=meta)
 
     # convert data loaded in flux units to a per-square-pixel surface
     # brightness unit (e.g Jy to Jy/pix**2)
-    if not check_if_unit_is_per_solid_angle(s3d.unit):
-        file_obj = convert_spectrum1d_from_flux_to_flux_per_pixel(s3d)
+    if not check_if_unit_is_per_solid_angle(s3d.flux.unit):
+        s3d = s3d.with_flux_unit(s3d.flux.unit / PIX2, equivalencies=_eqv_flux_to_sb_pixel())
 
     app.add_data(s3d, data_label)
 
@@ -556,12 +555,7 @@ def _parse_gif(app, file_obj, data_label=None, flux_viewer_reference_name=None):
     flux = np.rot90(np.moveaxis(flux, 0, 2), k=-1, axes=(0, 1))
 
     meta = {'filename': file_name, '_orig_spatial_wcs': None}
-    s3d = Spectrum1D(flux=flux * u.count, meta=standardize_metadata(meta))
-
-    # convert data loaded in flux units to a per-square-pixel surface
-    # brightness unit (e.g Jy to Jy/pix**2)
-    if not check_if_unit_is_per_solid_angle(s3d):
-        file_obj = convert_spectrum1d_from_flux_to_flux_per_pixel(s3d)
+    s3d = Spectrum1D(flux=flux * (u.count / PIX2), meta=standardize_metadata(meta))
 
     app.add_data(s3d, data_label)
     app.add_data_to_viewer(flux_viewer_reference_name, data_label)
@@ -580,67 +574,3 @@ def _get_data_type_by_hdu(hdu):
     else:
         data_type = ''
     return data_type
-
-
-def convert_spectrum1d_from_flux_to_flux_per_pixel(spectrum):
-    """
-    Converts a Spectrum1D object's flux units to flux per square pixel.
-
-    This function takes a `specutils.Spectrum1D` object with flux units and converts the
-    flux (and optionally, uncertainty) to a surface brightness per square pixel
-    (e.g., from Jy to Jy/pix**2). This is done by updating the units of spectrum.flux
-    and (if present) spectrum.uncertainty, and creating a new `specutils.Spectrum1D`
-    object with the modified flux and uncertainty.
-
-    Parameters
-    ----------
-    spectrum : Spectrum1D
-        A `specutils.Spectrum1D` object containing flux data, which is assumed to be in
-        flux units without any angular component in the denominator.
-
-    Returns
-    -------
-    Spectrum1D
-        A new `specutils.Spectrum1D` object with flux and uncertainty (if present)
-        converted to units of flux per square pixel.
-    """
-
-    # convert flux, which is always populated
-    flux = getattr(spectrum, 'flux')
-    flux = flux / PIX2
-
-    # and uncerts, if present
-    uncerts = getattr(spectrum, 'uncertainty')
-    if uncerts is not None:
-        # enforce common uncert type.
-        uncerts = uncerts.represent_as(StdDevUncertainty)
-        uncerts = StdDevUncertainty(uncerts.quantity / PIX2)
-
-    # create a new spectrum 1d with all the info from the input spectrum 1d,
-    # and the flux / uncerts converted from flux to SB per square pixel
-
-    # if there is a spectral axis that is a SpectralAxis, you cant also set
-    # redshift or radial_velocity
-    spectral_axis = getattr(spectrum, 'spectral_axis', None)
-    if spectral_axis is not None:
-        if isinstance(spectral_axis, SpectralAxis):
-            redshift = None
-            radial_velocity = None
-        else:
-            redshift = spectrum.redshift
-            radial_velocity = spectrum.radial_velocity
-
-    # initialize new spectrum1d with new flux, uncerts, and all other init parameters
-    # from old input spectrum as well as any 'meta'. any more missing information
-    # not in init signature that might be present in `spectrum`?
-    new_spec1d = Spectrum1D(flux=flux, uncertainty=uncerts,
-                            spectral_axis=spectrum.spectral_axis,
-                            mask=spectrum.mask,
-                            wcs=spectrum.wcs,
-                            velocity_convention=spectrum.velocity_convention,
-                            rest_value=spectrum.rest_value, redshift=redshift,
-                            radial_velocity=radial_velocity,
-                            bin_specification=getattr(spectrum, 'bin_specification', None),
-                            meta=spectrum.meta)
-
-    return new_spec1d
diff --git a/jdaviz/configs/cubeviz/plugins/spectral_extraction/spectral_extraction.py b/jdaviz/configs/cubeviz/plugins/spectral_extraction/spectral_extraction.py
index 0d3b61e727..f45d8db878 100644
--- a/jdaviz/configs/cubeviz/plugins/spectral_extraction/spectral_extraction.py
+++ b/jdaviz/configs/cubeviz/plugins/spectral_extraction/spectral_extraction.py
@@ -551,7 +551,7 @@ def _return_extracted(self, cube, wcs, collapsed_nddata):
         uncertainty = collapsed_nddata.uncertainty
 
         collapsed_spec = _return_spectrum_with_correct_units(
-            flux, wcs, collapsed_nddata.meta, 'flux',
+            flux, wcs, collapsed_nddata.meta, data_type='flux',
             target_wave_unit=target_wave_unit,
             uncertainty=uncertainty,
             mask=mask
diff --git a/jdaviz/configs/cubeviz/plugins/tests/test_parsers.py b/jdaviz/configs/cubeviz/plugins/tests/test_parsers.py
index 9c3d715548..0949f02c06 100644
--- a/jdaviz/configs/cubeviz/plugins/tests/test_parsers.py
+++ b/jdaviz/configs/cubeviz/plugins/tests/test_parsers.py
@@ -194,7 +194,7 @@ def test_numpy_cube(cubeviz_helper):
     assert data.label == 'Array'
     assert data.shape == (4, 3, 2)  # x, y, z
     assert isinstance(data.coords, PaddedSpectrumWCS)
-    assert flux.units == 'ct'
+    assert flux.units == 'ct / pix2'
 
     # Check context of second cube.
     data = cubeviz_helper.app.data_collection[1]
@@ -202,7 +202,7 @@ def test_numpy_cube(cubeviz_helper):
     assert data.label == 'uncert_array'
     assert data.shape == (4, 3, 2)  # x, y, z
     assert isinstance(data.coords, PaddedSpectrumWCS)
-    assert flux.units == 'ct'
+    assert flux.units == 'ct / pix2'
 
 
 def test_invalid_data_types(cubeviz_helper):
diff --git a/jdaviz/utils.py b/jdaviz/utils.py
index 9d04a9a752..64e91739d1 100644
--- a/jdaviz/utils.py
+++ b/jdaviz/utils.py
@@ -549,9 +549,11 @@ def _eqv_flux_to_sb_pixel():
 
     # generate an equivalency for each flux type that would need
     # another equivalency for converting to/from
-    flux_units = [u.MJy, u.erg / (u.s * u.cm**2 * u.Angstrom),
+    flux_units = [u.MJy,
+                  u.erg / (u.s * u.cm**2 * u.Angstrom),
                   u.ph / (u.Angstrom * u.s * u.cm**2),
-                  u.ph / (u.Hz * u.s * u.cm**2)]
+                  u.ph / (u.Hz * u.s * u.cm**2),
+                  u.ct]
     return [(flux_unit, flux_unit / PIX2, lambda x: x, lambda x: x)
             for flux_unit in flux_units]